diff --git a/index.d.ts b/index.d.ts index 4f6e49f..6b56beb 100644 --- a/index.d.ts +++ b/index.d.ts @@ -88,6 +88,7 @@ declare module 'replicate' { wait?: { interval?: number; max_attempts?: number }; webhook?: string; webhook_events_filter?: WebhookEventType[]; + signal?: AbortSignal; } ): Promise; @@ -105,7 +106,8 @@ declare module 'replicate' { options: { interval?: number; max_attempts?: number; - } + }, + stop?: (Prediction) => Promise ): Promise; collections: { diff --git a/index.js b/index.js index e6d50d3..b1803a1 100644 --- a/index.js +++ b/index.js @@ -85,6 +85,7 @@ class Replicate { * @param {number} [options.wait.max_attempts] - Maximum number of polling attempts. Defaults to no limit * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction * @throws {Error} If the prediction failed * @returns {Promise} - Resolves with the output of running the model */ @@ -116,7 +117,16 @@ class Replicate { version, }); - prediction = await this.wait(prediction, wait || {}); + const { signal } = options; + + prediction = await this.wait(prediction, wait || {}, async ({ id }) => { + if (signal && signal.aborted) { + await this.predictions.cancel(id); + return true; // stop polling + } + + return false; // continue polling + }); if (prediction.status === 'failed') { throw new Error(`Prediction failed: ${prediction.error}`); @@ -150,7 +160,11 @@ class Replicate { ); } - const { method = 'GET', params = {}, data } = options; + const { + method = 'GET', + params = {}, + data, + } = options; Object.entries(params).forEach(([key, value]) => { url.searchParams.append(key, value); @@ -219,11 +233,12 @@ class Replicate { * @param {object} options - Options * @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250 * @param {number} [options.max_attempts] - Maximum number of polling attempts. Defaults to no limit + * @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling. * @throws {Error} If the prediction doesn't complete within the maximum number of attempts * @throws {Error} If the prediction failed * @returns {Promise} Resolves with the completed prediction object */ - async wait(prediction, options) { + async wait(prediction, options, stop) { const { id } = prediction; if (!id) { throw new Error('Invalid prediction'); @@ -261,6 +276,9 @@ class Replicate { /* eslint-disable no-await-in-loop */ await sleep(interval); updatedPrediction = await this.predictions.get(prediction.id); + if (stop && await stop(updatedPrediction) === true) { + break; + } /* eslint-enable no-await-in-loop */ } diff --git a/index.test.ts b/index.test.ts index b726137..9147b81 100644 --- a/index.test.ts +++ b/index.test.ts @@ -581,6 +581,44 @@ describe('Replicate client', () => { }); }).rejects.toThrow('Invalid webhook URL'); }); + + test('Aborts the operation when abort signal is invoked', async () => { + const controller = new AbortController(); + const { signal } = controller; + + const scope = nock(BASE_URL) + .post('/predictions', (body) => { + controller.abort(); + return body; + }) + .reply(201, { + id: 'ufawqhfynnddngldkgtslldrkq', + status: 'processing', + }) + .persist() + .get('/predictions/ufawqhfynnddngldkgtslldrkq') + .reply(200, { + id: 'ufawqhfynnddngldkgtslldrkq', + status: 'processing', + }) + .post('/predictions/ufawqhfynnddngldkgtslldrkq/cancel') + .reply(200, { + id: 'ufawqhfynnddngldkgtslldrkq', + status: 'canceled', + });; + + await client.run( + 'owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + { + input: { text: 'Hello, world!' }, + signal, + } + ) + + expect(signal.aborted).toBe(true); + + scope.done(); + }); }); // Continue with tests for other methods