From c79e7030894f25ee6cf0e5f490e7ca97097a76d1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 11 Sep 2023 10:58:08 -0700 Subject: [PATCH] Add deployment endpoints Document replicate.deployments.predictions.create in README --- README.md | 17 +++++++++++++++++ index.d.ts | 17 ++++++++++++++++- index.js | 7 +++++++ index.test.ts | 37 +++++++++++++++++++++++++++++++++++++ lib/deployments.js | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 lib/deployments.js diff --git a/README.md b/README.md index 20e11ba..13968ca 100644 --- a/README.md +++ b/README.md @@ -552,6 +552,23 @@ const response = await replicate.trainings.list(); } ``` +### `replicate.deployments.predictions.create` + +```js +const response = await replicate.deployments.predictions.create(deployment_owner, deployment_name, options); +``` + +| name | type | description | +| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------- | +| `deployment_owner` | string | **Required**. The name of the user or organization that owns the deployment | +| `deployment_name` | string | **Required**. The name of the deployment | +| `options.input` | object | **Required**. An object with the model's inputs | +| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output | +| `options.webhook_events_filter` | string[] | You can change which events trigger webhook requests by specifying webhook events (`start` \| `output` \| `logs` \| `completed`) | + +Use `replicate.wait` to wait for a prediction to finish, +or `replicate.predictions.cancel` to cancel a prediction before it finishes. 
+ ### `replicate.paginate` Pass another method as an argument to iterate over results diff --git a/index.d.ts b/index.d.ts index 5f94ce9..7f4461c 100644 --- a/index.d.ts +++ b/index.d.ts @@ -47,7 +47,7 @@ declare module 'replicate' { logs?: string; metrics?: { predict_time?: number; - } + }; webhook?: string; webhook_events_filter?: WebhookEventType[]; created_at: string; @@ -156,5 +156,20 @@ declare module 'replicate' { cancel(training_id: string): Promise; list(): Promise>; }; + + deployments: { + predictions: { + create( + deployment_owner: string, + deployment_name: string, + options: { + input: object; + stream?: boolean; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + } + ): Promise; + }; + }; } } diff --git a/index.js b/index.js index 293f447..9662235 100644 --- a/index.js +++ b/index.js @@ -2,6 +2,7 @@ const ApiError = require('./lib/error'); const { withAutomaticRetries } = require('./lib/util'); const collections = require('./lib/collections'); +const deployments = require('./lib/deployments'); const models = require('./lib/models'); const predictions = require('./lib/predictions'); const trainings = require('./lib/trainings'); @@ -69,6 +70,12 @@ class Replicate { cancel: trainings.cancel.bind(this), list: trainings.list.bind(this), }; + + this.deployments = { + predictions: { + create: deployments.predictions.create.bind(this), + } + }; } /** diff --git a/index.test.ts b/index.test.ts index 442e36e..fb65f29 100644 --- a/index.test.ts +++ b/index.test.ts @@ -582,6 +582,43 @@ describe('Replicate client', () => { }); }); + describe('deployments.predictions.create', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post('/deployments/replicate/greeter/predictions') + .reply(200, { + id: 'mfrgcyzzme2wkmbwgzrgmntcg', + version: + '5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + urls: { + get: 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq', 
cancel: + 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel', + }, + created_at: '2022-09-10T09:44:22.165836Z', + started_at: null, + completed_at: null, + status: 'starting', + input: { + text: 'Alice', + }, + output: null, + error: null, + logs: null, + metrics: {}, + }); + const prediction = await client.deployments.predictions.create("replicate", "greeter", { + input: { + text: 'Alice', + }, + webhook: 'http://test.host/webhook', + webhook_events_filter: [ 'output', 'completed' ], + }); + expect(prediction.id).toBe('mfrgcyzzme2wkmbwgzrgmntcg'); + }); + // Add more tests for error handling, edge cases, etc. + }); + describe('run', () => { test('Calls the correct API routes', async () => { let firstPollingRequest = true; diff --git a/lib/deployments.js b/lib/deployments.js new file mode 100644 index 0000000..4682c9b --- /dev/null +++ b/lib/deployments.js @@ -0,0 +1,37 @@ +/** + * Create a new prediction with a deployment + * + * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment + * @param {string} deployment_name - Required. The name of the deployment + * @param {object} options + * @param {object} options.input - Required. An object with the model inputs + * @param {boolean} [options.stream] - Whether to stream the prediction output. 
Defaults to false + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @returns {Promise} Resolves with the created prediction data + */ +async function createPrediction(deployment_owner, deployment_name, options) { + const { stream, ...data } = options; + + if (data.webhook) { + try { + // eslint-disable-next-line no-new + new URL(data.webhook); + } catch (err) { + throw new Error('Invalid webhook URL'); + } + } + + const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}/predictions`, { + method: 'POST', + data: { ...data, stream }, + }); + + return response.json(); +} + +module.exports = { + predictions: { + create: createPrediction, + } +};