diff --git a/README.md b/README.md index 06b7436..0e17557 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,7 @@ const response = await replicate.predictions.create(options); | ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------- | | `options.version` | string | **Required**. The model version | | `options.input` | object | **Required**. An object with the model's inputs | +| `options.stream` | boolean | Requests a URL for streaming output | | `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output | | `options.webhook_events_filter` | string[] | You can change which events trigger webhook requests by specifying webhook events (`start` \| `output` \| `logs` \| `completed`) | @@ -258,10 +259,6 @@ const response = await replicate.predictions.create(options); { "id": "ufawqhfynnddngldkgtslldrkq", "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "urls": { - "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" - }, "status": "succeeded", "input": { "text": "Alice" @@ -272,10 +269,56 @@ const response = await replicate.predictions.create(options); "metrics": {}, "created_at": "2022-04-26T22:13:06.224088Z", "started_at": null, - "completed_at": null + "completed_at": null, + "urls": { + "get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + "cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + "stream": "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" // Present only if `options.stream` is `true` + } } ``` +#### Streaming + +Specify the `stream` option when creating a prediction +to request a URL to receive streaming output using +[server-sent events 
(SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events). + +If the requested model version supports streaming, +then the returned prediction will have a `stream` entry in its `urls` property +with a URL that you can use to construct an +[`EventSource`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource). + +```js +if (prediction && prediction.urls && prediction.urls.stream) { + const source = new EventSource(prediction.urls.stream, { withCredentials: true }); + + source.addEventListener("output", (e) => { + console.log("output", e.data); + }); + + source.addEventListener("error", (e) => { + console.error("error", JSON.parse(e.data)); + }); + + source.addEventListener("done", (e) => { + source.close(); + console.log("done", JSON.parse(e.data)); + }); +} +``` + +A prediction's event stream consists of the following event types: + +| event | format | description | +| -------- | ---------- | ---------------------------------------------- | +| `output` | plain text | Emitted when the prediction returns new output | +| `error` | JSON | Emitted when the prediction returns an error | +| `done` | JSON | Emitted when the prediction finishes | + +A `done` event is emitted when a prediction finishes successfully, +is cancelled, or produces an error. 
+ ### `replicate.predictions.get` ```js diff --git a/index.d.ts b/index.d.ts index 03efee0..4f6e49f 100644 --- a/index.d.ts +++ b/index.d.ts @@ -53,6 +53,11 @@ declare module 'replicate' { created_at: string; updated_at: string; completed_at?: string; + urls: { + get: string; + cancel: string; + stream?: string; + }; } export type Training = Prediction; @@ -80,18 +85,26 @@ declare module 'replicate' { identifier: `${string}/${string}:${string}`, options: { input: object; - wait?: boolean | { interval?: number; maxAttempts?: number }; + wait?: { interval?: number; max_attempts?: number }; webhook?: string; webhook_events_filter?: WebhookEventType[]; } ): Promise; - request(route: string, parameters: any): Promise; + + request(route: string | URL, options: { + method?: string; + headers?: object | Headers; + params?: object; + data?: object; + }): Promise; + paginate(endpoint: () => Promise>): AsyncGenerator<[ T ]>; + wait( prediction: Prediction, options: { interval?: number; - maxAttempts?: number; + max_attempts?: number; } ): Promise; @@ -116,6 +129,7 @@ declare module 'replicate' { create(options: { version: string; input: object; + stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; }): Promise; diff --git a/index.js b/index.js index de75431..e6d50d3 100644 --- a/index.js +++ b/index.js @@ -80,15 +80,17 @@ class Replicate { * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" * @param {object} options * @param {object} options.input - Required. An object with the model inputs - * @param {object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false + * @param {object} [options.wait] - Options for waiting for the prediction to finish * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250 - * @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. 
Defaults to no limit + * @param {number} [options.wait.max_attempts] - Maximum number of polling attempts. Defaults to no limit * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @throws {Error} If the prediction failed * @returns {Promise} - Resolves with the output of running the model */ async run(identifier, options) { + const { wait, ...data } = options; + // Define a pattern for owner and model names that allows // letters, digits, and certain special characters. // Example: "user123", "abc__123", "user.name" @@ -108,12 +110,14 @@ class Replicate { } const { version } = match.groups; - const prediction = await this.predictions.create({ - wait: true, - ...options, + + let prediction = await this.predictions.create({ + ...data, version, }); + prediction = await this.wait(prediction, wait || {}); + if (prediction.status === 'failed') { throw new Error(`Prediction failed: ${prediction.error}`); } @@ -125,43 +129,53 @@ class Replicate { * Make a request to the Replicate API. * * @param {string} route - REST API endpoint path - * @param {object} parameters - Request parameters - * @param {string} [parameters.method] - HTTP method. Defaults to GET - * @param {object} [parameters.params] - Query parameters - * @param {object} [parameters.data] - Body parameters - * @returns {Promise} - Resolves with the API response data + * @param {object} options - Request parameters + * @param {string} [options.method] - HTTP method. 
Defaults to GET + * @param {object} [options.params] - Query parameters + * @param {object|Headers} [options.headers] - HTTP headers + * @param {object} [options.data] - Body parameters + * @returns {Promise} - Resolves with the response object * @throws {ApiError} If the request failed */ - async request(route, parameters) { + async request(route, options) { const { auth, baseUrl, userAgent } = this; - const url = new URL( - route.startsWith('/') ? route.slice(1) : route, - baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/` - ); + let url; + if (route instanceof URL) { + url = route; + } else { + url = new URL( + route.startsWith('/') ? route.slice(1) : route, + baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/` + ); + } - const { method = 'GET', params = {}, data } = parameters; + const { method = 'GET', params = {}, data } = options; Object.entries(params).forEach(([key, value]) => { url.searchParams.append(key, value); }); - const headers = { - Authorization: `Token ${auth}`, - 'Content-Type': 'application/json', - 'User-Agent': userAgent, - }; + const headers = new Headers(); + headers.append('Authorization', `Token ${auth}`); + headers.append('Content-Type', 'application/json'); + headers.append('User-Agent', userAgent); + if (options.headers) { + options.headers.forEach((value, key) => { + headers.append(key, value); + }); + } - const options = { + const init = { method, headers, body: data ? 
JSON.stringify(data) : undefined, }; - const response = await this.fetch(url, options); + const response = await this.fetch(url, init); if (!response.ok) { - const request = new Request(url, options); + const request = new Request(url, init); const responseText = await response.text(); throw new ApiError( `Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`, @@ -170,7 +184,7 @@ class Replicate { ); } - return response.json(); + return response; } /** @@ -188,7 +202,7 @@ class Replicate { const response = await endpoint(); yield response.results; if (response.next) { - const nextPage = () => this.request(response.next, { method: 'GET' }); + const nextPage = () => this.request(response.next, { method: 'GET' }).then((r) => r.json()); yield* this.paginate(nextPage); } } @@ -204,7 +218,7 @@ class Replicate { * @param {object} prediction - Prediction object * @param {object} options - Options * @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250 - * @param {number} [options.maxAttempts] - Maximum number of polling attempts. Defaults to no limit + * @param {number} [options.max_attempts] - Maximum number of polling attempts. 
Defaults to no limit * @throws {Error} If the prediction doesn't complete within the maximum number of attempts * @throws {Error} If the prediction failed * @returns {Promise} Resolves with the completed prediction object @@ -230,7 +244,7 @@ class Replicate { let attempts = 0; const interval = options.interval || 250; - const maxAttempts = options.maxAttempts || null; + const max_attempts = options.max_attempts || null; while ( updatedPrediction.status !== 'succeeded' && @@ -238,9 +252,9 @@ class Replicate { updatedPrediction.status !== 'canceled' ) { attempts += 1; - if (maxAttempts && attempts > maxAttempts) { + if (max_attempts && attempts > max_attempts) { throw new Error( - `Prediction ${id} did not finish after ${maxAttempts} attempts` + `Prediction ${id} did not finish after ${max_attempts} attempts` ); } diff --git a/index.test.ts b/index.test.ts index 0538afd..b726137 100644 --- a/index.test.ts +++ b/index.test.ts @@ -152,6 +152,24 @@ describe('Replicate client', () => { expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq'); }); + test('Passes stream parameter to API endpoint', async () => { + nock(BASE_URL) + .post('/predictions') + .reply(201, (_uri, body) => { + expect(body[ 'stream' ]).toBe(true); + return body + }) + + await client.predictions.create({ + version: + '5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + input: { + prompt: 'Tell me a story', + }, + stream: true + }); + }); + test('Throws an error if webhook URL is invalid', async () => { await expect(async () => { await client.predictions.create({ @@ -506,7 +524,7 @@ describe('Replicate client', () => { status: 'processing', }) .get('/predictions/ufawqhfynnddngldkgtslldrkq') - .reply(200, { + .reply(201, { id: 'ufawqhfynnddngldkgtslldrkq', status: 'succeeded', output: 'foobar', diff --git a/lib/collections.js b/lib/collections.js index de6f814..b40d091 100644 --- a/lib/collections.js +++ b/lib/collections.js @@ -5,9 +5,11 @@ * @returns {Promise} - Resolves with the 
collection data */ async function getCollection(collection_slug) { - return this.request(`/collections/${collection_slug}`, { + const response = await this.request(`/collections/${collection_slug}`, { method: 'GET', }); + + return response.json(); } /** @@ -16,9 +18,11 @@ async function getCollection(collection_slug) { * @returns {Promise} - Resolves with the collections data */ async function listCollections() { - return this.request('/collections', { + const response = await this.request('/collections', { method: 'GET', }); + + return response.json(); } module.exports = { get: getCollection, list: listCollections }; diff --git a/lib/error.js b/lib/error.js index fc3b184..81dbe47 100644 --- a/lib/error.js +++ b/lib/error.js @@ -1,3 +1,6 @@ +/** + * A representation of an API error. + */ class ApiError extends Error { /** * Creates a representation of an API error. diff --git a/lib/models.js b/lib/models.js index 55c9b43..373ed23 100644 --- a/lib/models.js +++ b/lib/models.js @@ -6,9 +6,11 @@ * @returns {Promise} Resolves with the model data */ async function getModel(model_owner, model_name) { - return this.request(`/models/${model_owner}/${model_name}`, { + const response = await this.request(`/models/${model_owner}/${model_name}`, { method: 'GET', }); + + return response.json(); } /** @@ -19,9 +21,11 @@ async function getModel(model_owner, model_name) { * @returns {Promise} Resolves with the list of model versions */ async function listModelVersions(model_owner, model_name) { - return this.request(`/models/${model_owner}/${model_name}/versions`, { + const response = await this.request(`/models/${model_owner}/${model_name}/versions`, { method: 'GET', }); + + return response.json(); } /** @@ -33,12 +37,14 @@ async function listModelVersions(model_owner, model_name) { * @returns {Promise} Resolves with the model version data */ async function getModelVersion(model_owner, model_name, version_id) { - return this.request( + const response = await this.request( 
`/models/${model_owner}/${model_name}/versions/${version_id}`, { method: 'GET', } ); + + return response.json(); } module.exports = { diff --git a/lib/predictions.js b/lib/predictions.js index ba46ef8..32b18ad 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -4,15 +4,13 @@ * @param {object} options * @param {string} options.version - Required. The model version * @param {object} options.input - Required. An object with the model inputs - * @param {boolean|object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false - * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250 - * @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the created prediction data + * @param {boolean} [options.stream] - Whether to stream the prediction output. 
Defaults to false + * @returns {Promise} Resolves with the created prediction */ async function createPrediction(options) { - const { wait, ...data } = options; + const { stream, ...data } = options; if (data.webhook) { try { @@ -23,22 +21,12 @@ async function createPrediction(options) { } } - const prediction = this.request('/predictions', { + const response = await this.request('/predictions', { method: 'POST', - data, + data: { ...data, stream }, }); - if (wait) { - const { maxAttempts, interval } = wait; - - // eslint-disable-next-line no-promise-executor-return - const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); - await sleep(interval || 250); - - return this.wait(await prediction, { maxAttempts, interval }); - } - - return prediction; + return response.json(); } /** @@ -48,9 +36,11 @@ async function createPrediction(options) { * @returns {Promise} Resolves with the prediction data */ async function getPrediction(prediction_id) { - return this.request(`/predictions/${prediction_id}`, { + const response = await this.request(`/predictions/${prediction_id}`, { method: 'GET', }); + + return response.json(); } /** @@ -60,9 +50,11 @@ async function getPrediction(prediction_id) { * @returns {Promise} Resolves with the data for the training */ async function cancelPrediction(prediction_id) { - return this.request(`/predictions/${prediction_id}/cancel`, { + const response = await this.request(`/predictions/${prediction_id}/cancel`, { method: 'POST', }); + + return response.json(); } /** @@ -71,9 +63,11 @@ async function cancelPrediction(prediction_id) { * @returns {Promise} - Resolves with a page of predictions */ async function listPredictions() { - return this.request('/predictions', { + const response = await this.request('/predictions', { method: 'GET', }); + + return response.json(); } module.exports = { diff --git a/lib/trainings.js b/lib/trainings.js index 688e919..edb037c 100644 --- a/lib/trainings.js +++ b/lib/trainings.js @@ -23,12 
+23,12 @@ async function createTraining(model_owner, model_name, version_id, options) { } } - const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, { + const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, { method: 'POST', data, }); - return training; + return response.json(); } /** @@ -38,9 +38,11 @@ async function createTraining(model_owner, model_name, version_id, options) { * @returns {Promise} Resolves with the data for the training */ async function getTraining(training_id) { - return this.request(`/trainings/${training_id}`, { + const response = await this.request(`/trainings/${training_id}`, { method: 'GET', }); + + return response.json(); } /** @@ -50,9 +52,11 @@ async function getTraining(training_id) { * @returns {Promise} Resolves with the data for the training */ async function cancelTraining(training_id) { - return this.request(`/trainings/${training_id}/cancel`, { + const response = await this.request(`/trainings/${training_id}/cancel`, { method: 'POST', }); + + return response.json(); } /** @@ -61,9 +65,11 @@ async function cancelTraining(training_id) { * @returns {Promise} - Resolves with a page of trainings */ async function listTrainings() { - return this.request('/trainings', { + const response = await this.request('/trainings', { method: 'GET', }); + + return response.json(); } module.exports = {