From 0975aa2e735f0f1100828ce8c7829fceb2448525 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 7 Aug 2023 05:44:52 -0700 Subject: [PATCH] Add automatic retry policy --- index.js | 6 ++++- index.test.ts | 74 +++++++++++++++++++++++++++++++++++++++++++++++++-- lib/util.js | 69 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 lib/util.js diff --git a/index.js b/index.js index 84806a0..6a23eff 100644 --- a/index.js +++ b/index.js @@ -1,4 +1,5 @@ const ApiError = require('./lib/error'); +const { withAutomaticRetries } = require('./lib/util'); const collections = require('./lib/collections'); const models = require('./lib/models'); @@ -201,7 +202,10 @@ class Replicate { body: data ? JSON.stringify(data) : undefined, }; - const response = await this.fetch(url, init); + const shouldRetry = method === 'GET' ? + (response) => (response.status === 429 || response.status >= 500) : + (response) => (response.status === 429); + const response = await withAutomaticRetries(async () => this.fetch(url, init), { shouldRetry }); if (!response.ok) { const request = new Request(url, init); diff --git a/index.test.ts b/index.test.ts index 8a139b8..818f874 100644 --- a/index.test.ts +++ b/index.test.ts @@ -196,7 +196,44 @@ describe('Replicate client', () => { expect((error as ApiError).message).toContain("Invalid input") } }) - // Add more tests for error handling, edge cases, etc. + + test('Automatically retries on 429', async () => { + nock(BASE_URL) + .post('/predictions') + .reply(429, { + detail: "Too many requests", + }, { "Content-Type": "application/json", "Retry-After": "1" }) + .post('/predictions') + .reply(201, { + id: 'ufawqhfynnddngldkgtslldrkq', + }); + const prediction = await client.predictions.create({ + version: + '5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + input: { + text: 'Alice', + }, + }); + expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq'); + }); + + test('Does not automatically retry on 500', async () => { + nock(BASE_URL) + .post('/predictions') + .reply(500, { + detail: "Internal server error", + }, { "Content-Type": "application/json" }); + + await expect( + client.predictions.create({ + version: + '5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + input: { + text: 'Alice', + }, + }) + ).rejects.toThrow(`Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.`) + }); }); describe('predictions.get', () => { @@ -234,7 +271,40 @@ describe('Replicate client', () => { ); expect(prediction.id).toBe('rrr4z55ocneqzikepnug6xezpe'); }); - // Add more tests for error handling, edge cases, etc. + + test('Automatically retries on 429', async () => { + nock(BASE_URL) + .get('/predictions/rrr4z55ocneqzikepnug6xezpe') + .reply(429, { + detail: "Too many requests", + }, { "Content-Type": "application/json", "Retry-After": "1" }) + .get('/predictions/rrr4z55ocneqzikepnug6xezpe') + .reply(200, { + id: 'rrr4z55ocneqzikepnug6xezpe', + }); + + const prediction = await client.predictions.get( + 'rrr4z55ocneqzikepnug6xezpe' + ); + expect(prediction.id).toBe('rrr4z55ocneqzikepnug6xezpe'); + }); + + test('Automatically retries on 500', async () => { + nock(BASE_URL) + .get('/predictions/rrr4z55ocneqzikepnug6xezpe') + .reply(500, { + detail: "Internal server error", + }, { "Content-Type": "application/json" }) + .get('/predictions/rrr4z55ocneqzikepnug6xezpe') + .reply(200, { + id: 'rrr4z55ocneqzikepnug6xezpe', + }); + + const prediction = await client.predictions.get( + 'rrr4z55ocneqzikepnug6xezpe' + ); + expect(prediction.id).toBe('rrr4z55ocneqzikepnug6xezpe'); + }); }); describe('predictions.cancel', () => { diff --git a/lib/util.js b/lib/util.js new file mode 100644 index 0000000..1b9cf9d --- /dev/null +++ b/lib/util.js @@ -0,0 +1,69 @@ +const ApiError = require('./error'); + +/** + * Automatically retry a request if it fails with an appropriate status code. + * + * A GET request is retried if it fails with a 429 or 5xx status code. + * A non-GET request is retried only if it fails with a 429 status code. + * + * If the response sets a Retry-After header, + * the request is retried after the number of seconds specified in the header. + * Otherwise, the request is retried after the specified interval, + * with exponential backoff and jitter. + * + * @param {Function} request - A function that returns a Promise that resolves with a Response object + * @param {object} options + * @param {Function} [options.shouldRetry] - A function that returns true if the request should be retried + * @param {number} [options.maxRetries] - Maximum number of retries. Defaults to 5 + * @param {number} [options.interval] - Interval between retries in milliseconds. Defaults to 500 + * @returns {Promise} - Resolves with the response object + * @throws {ApiError} If the request failed + */ +async function withAutomaticRetries(request, options = {}) { + const shouldRetry = options.shouldRetry || (() => (false)); + const maxRetries = options.maxRetries || 5; + const interval = options.interval || 500; + const jitter = options.jitter || 100; + + // eslint-disable-next-line no-promise-executor-return + const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); + + let attempts = 0; + do { + let delay = (interval * (2 ** attempts)) + (Math.random() * jitter); + + /* eslint-disable no-await-in-loop */ + try { + const response = await request(); + if (response.ok || !shouldRetry(response)) { + return response; + } + } catch (error) { + if (error instanceof ApiError) { + const retryAfter = error.response.headers.get('Retry-After'); + if (retryAfter) { + if (!Number.isInteger(retryAfter)) { // Retry-After is a date + const date = new Date(retryAfter); + if (!Number.isNaN(date.getTime())) { + delay = date.getTime() - new Date().getTime(); + } + } else { // Retry-After is a number of seconds + delay = retryAfter * 1000; + } + } + } + } + + if (Number.isInteger(maxRetries) && maxRetries > 0) { + if (Number.isInteger(delay) && delay > 0) { + await sleep(interval * 2 ** (options.maxRetries - maxRetries)); + } + attempts += 1; + } + /* eslint-enable no-await-in-loop */ + } while (attempts < maxRetries); + + return request(); +} + +module.exports = { withAutomaticRetries };