From dc0e67af21457ba062cf1e26191da1b4bd8ceb47 Mon Sep 17 00:00:00 2001
From: Aron Carroll
Date: Thu, 14 Mar 2024 11:40:50 +0000
Subject: [PATCH 1/3] Add support for new deployment endpoints

---
 index.d.ts         | 19 ++++++++++++
 index.js           |  3 ++
 lib/deployments.js | 73 ++++++++++++++++++++++++++++++++++++++++++++++
 tsconfig.json      |  6 +---
 4 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/index.d.ts b/index.d.ts
index abf68dce..0092f786 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -194,6 +194,25 @@ declare module "replicate" {
         deployment_owner: string,
         deployment_name: string
       ): Promise<Deployment>;
+      create(deployment_config: {
+        name: string;
+        model: string;
+        version: string;
+        hardware: string;
+        min_instances: number;
+        max_instances: number;
+      }): Promise<Deployment>;
+      update(
+        deployment_owner: string,
+        deployment_name: string,
+        deployment_config: {
+          version: string;
+          hardware: string;
+          min_instances: number;
+          max_instances: number;
+        }
+      ): Promise<Deployment>;
+      list(): Promise<Page<Deployment>>;
     };
 
     hardware: {
diff --git a/index.js b/index.js
index cd299f4b..6a3a41aa 100644
--- a/index.js
+++ b/index.js
@@ -66,6 +66,9 @@ class Replicate {
 
     this.deployments = {
       get: deployments.get.bind(this),
+      create: deployments.create.bind(this),
+      update: deployments.update.bind(this),
+      list: deployments.list.bind(this),
       predictions: {
         create: deployments.predictions.create.bind(this),
       },
diff --git a/lib/deployments.js b/lib/deployments.js
index 3e1ceebd..4f6f3c6b 100644
--- a/lib/deployments.js
+++ b/lib/deployments.js
@@ -57,9 +57,82 @@ async function getDeployment(deployment_owner, deployment_name) {
   return response.json();
 }
 
+/**
+ * @typedef {Object} DeploymentCreateRequest - Request body for `deployments.create`
+ * @property {string} name - the name of the deployment
+ * @property {string} model - the full name of the model that you want to deploy e.g. stability-ai/sdxl
+ * @property {string} version - the 64-character string ID of the model version that you want to deploy
+ * @property {string} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()`
+ * @property {number} min_instances - the minimum number of instances for scaling
+ * @property {number} max_instances - the maximum number of instances for scaling
+ */
+
+/**
+ * Create a deployment
+ *
+ * @param {DeploymentCreateRequest} deployment_config - Required. The deployment config.
+ * @returns {Promise<object>} Resolves with the deployment data
+ */
+async function createDeployment(deployment_config) {
+  const response = await this.request("/deployments", {
+    method: "POST",
+    data: deployment_config,
+  });
+
+  return response.json();
+}
+
+/**
+ * @typedef {Object} DeploymentUpdateRequest - Request body for `deployments.update`
+ * @property {string} version - the 64-character string ID of the model version that you want to deploy
+ * @property {string} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()`
+ * @property {number} min_instances - the minimum number of instances for scaling
+ * @property {number} max_instances - the maximum number of instances for scaling
+ */
+
+/**
+ * Update an existing deployment
+ *
+ * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
+ * @param {string} deployment_name - Required. The name of the deployment
+ * @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes.
+ * @returns {Promise<object>} Resolves with the deployment data
+ */
+async function updateDeployment(
+  deployment_owner,
+  deployment_name,
+  deployment_config
+) {
+  const response = await this.request(
+    `/deployments/${deployment_owner}/${deployment_name}`,
+    {
+      method: "PATCH",
+      data: deployment_config,
+    }
+  );
+
+  return response.json();
+}
+
+/**
+ * List all deployments
+ *
+ * @returns {Promise<object>} - Resolves with a page of deployments
+ */
+async function listDeployments() {
+  const response = await this.request("/deployments", {
+    method: "GET",
+  });
+
+  return response.json();
+}
+
 module.exports = {
   predictions: {
     create: createPrediction,
   },
   get: getDeployment,
+  create: createDeployment,
+  update: updateDeployment,
+  list: listDeployments,
 };
diff --git a/tsconfig.json b/tsconfig.json
index e6b4ed6c..d77efdc5 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -5,9 +5,5 @@
     "strict": true,
     "allowJs": true
   },
-  "exclude": [
-    "**/node_modules",
-    "integration/**"
-  ]
+  "exclude": ["**/node_modules", "integration"]
 }
-

From 726334afb353df6a08f18a8df74fe6d101c3d644 Mon Sep 17 00:00:00 2001
From: Mattt Zmuda
Date: Fri, 15 Mar 2024 06:09:58 -0700
Subject: [PATCH 2/3] Align definition for Deployment type to OpenAPI
 specification

---
 index.d.ts | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/index.d.ts b/index.d.ts
index 0092f786..d1602adf 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -33,8 +33,10 @@ declare module "replicate" {
       created_by: Account;
       configuration: {
         hardware: string;
-        min_instances: number;
-        max_instances: number;
+        scaling: {
+          min_instances: number;
+          max_instances: number;
+        };
       };
     };
   }
@@ -206,11 +208,16 @@ declare module "replicate" {
         deployment_owner: string,
         deployment_name: string,
         deployment_config: {
-          version: string;
-          hardware: string;
-          min_instances: number;
-          max_instances: number;
-        }
+          version?: string;
+          hardware?: string;
+          min_instances?: number;
+          max_instances?: number;
+        } & (
+          | { version: string }
+          | { hardware: string }
+          | { min_instances: number }
+          | { max_instances: number }
+        )
       ): Promise<Deployment>;
       list(): Promise<Page<Deployment>>;
     };

From aabaa0de74a57ea9842585fe7c56489dfa011a02 Mon Sep 17 00:00:00 2001
From: Mattt Zmuda
Date: Fri, 15 Mar 2024 06:12:03 -0700
Subject: [PATCH 3/3] Add test coverage for deployment endpoints

---
 index.test.ts | 129 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 129 insertions(+)

diff --git a/index.test.ts b/index.test.ts
index 7e0ae22e..55adee59 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -811,6 +811,135 @@ describe("Replicate client", () => {
     // Add more tests for error handling, edge cases, etc.
   });
 
+  describe("deployments.create", () => {
+    test("Calls the correct API route with the correct payload", async () => {
+      nock(BASE_URL)
+        .post("/deployments")
+        .reply(200, {
+          owner: "acme",
+          name: "my-app-image-generator",
+          current_release: {
+            number: 1,
+            model: "stability-ai/sdxl",
+            version:
+              "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+            created_at: "2024-02-15T16:32:57.018467Z",
+            created_by: {
+              type: "organization",
+              username: "acme",
+              name: "Acme Corp, Inc.",
+              github_url: "https://github.com/acme",
+            },
+            configuration: {
+              hardware: "gpu-t4",
+              scaling: {
+                min_instances: 1,
+                max_instances: 5,
+              },
+            },
+          },
+        });
+
+      const deployment = await client.deployments.create({
+        name: "my-app-image-generator",
+        model: "stability-ai/sdxl",
+        version:
+          "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+        hardware: "gpu-t4",
+        min_instances: 1,
+        max_instances: 5,
+      });
+
+      expect(deployment.owner).toBe("acme");
+      expect(deployment.name).toBe("my-app-image-generator");
+      expect(deployment.current_release.model).toBe("stability-ai/sdxl");
+    });
+    // Add more tests for error handling, edge cases, etc.
+  });
+
+  describe("deployments.update", () => {
+    test("Calls the correct API route with the correct payload", async () => {
+      nock(BASE_URL)
+        .patch("/deployments/acme/my-app-image-generator")
+        .reply(200, {
+          owner: "acme",
+          name: "my-app-image-generator",
+          current_release: {
+            number: 2,
+            model: "stability-ai/sdxl",
+            version:
+              "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
+            created_at: "2024-02-16T08:14:22.345678Z",
+            created_by: {
+              type: "organization",
+              username: "acme",
+              name: "Acme Corp, Inc.",
+              github_url: "https://github.com/acme",
+            },
+            configuration: {
+              hardware: "gpu-a40-large",
+              scaling: {
+                min_instances: 3,
+                max_instances: 10,
+              },
+            },
+          },
+        });
+
+      const deployment = await client.deployments.update(
+        "acme",
+        "my-app-image-generator",
+        {
+          version:
+            "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532",
+          hardware: "gpu-a40-large",
+          min_instances: 3,
+          max_instances: 10,
+        }
+      );
+
+      expect(deployment.current_release.number).toBe(2);
+      expect(deployment.current_release.version).toBe(
+        "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532"
+      );
+      expect(deployment.current_release.configuration.hardware).toBe(
+        "gpu-a40-large"
+      );
+      expect(
+        deployment.current_release.configuration.scaling?.min_instances
+      ).toBe(3);
+      expect(
+        deployment.current_release.configuration.scaling?.max_instances
+      ).toBe(10);
+    });
+    // Add more tests for error handling, edge cases, etc.
+  });
+
+  describe("deployments.list", () => {
+    test("Calls the correct API route", async () => {
+      nock(BASE_URL)
+        .get("/deployments")
+        .reply(200, {
+          next: null,
+          previous: null,
+          results: [
+            {
+              owner: "acme",
+              name: "my-app-image-generator",
+              current_release: {
+                // ...
+              },
+            },
+            // ...
+          ],
+        });
+
+      const deployments = await client.deployments.list();
+      expect(deployments.results.length).toBe(1);
+    });
+    // Add more tests for pagination, error handling, edge cases, etc.
+  });
+
   describe("predictions.create with model", () => {
     test("Calls the correct API route with the correct payload", async () => {
       nock(BASE_URL)