diff --git a/README.md b/README.md index 36fbcb0..1eb7a2b 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,43 @@ const response = await replicate.models.list(); } ``` +### `replicate.models.create` + +Create a new public or private model. + +```js +const response = await replicate.models.create(model_owner, model_name, options); +``` + +| name | type | description | +| ------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model_owner` | string | **Required**. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. | +| `model_name` | string | **Required**. The name of the model. This must be unique among all models owned by the user or organization. | +| `options.visibility` | string | **Required**. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. | +| `options.hardware` | string | **Required**. The SKU for the hardware used to run the model. Possible values can be found by calling [`replicate.hardware.list()](#replicatehardwarelist)`. | +| `options.description` | string | A description of the model. | +| `options.github_url` | string | A URL for the model's source code on GitHub. | +| `options.paper_url` | string | A URL for the model's paper. | +| `options.license_url` | string | A URL for the model's license. | +| `options.cover_image_url` | string | A URL for the model's cover image. This should be an image file. | + +### `replicate.hardware.list` + +List available hardware for running models on Replicate. + +```js +const response = await replicate.hardware.list() +``` + +```jsonc +[ + {"name": "CPU", "sku": "cpu" }, + {"name": "Nvidia T4 GPU", "sku": "gpu-t4" }, + {"name": "Nvidia A40 GPU", "sku": "gpu-a40-small" }, + {"name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" }, +] +``` + ### `replicate.models.versions.list` Get a list of all published versions of a model, including input and output schemas for each version. diff --git a/index.d.ts b/index.d.ts index 601e15b..a3e2ee0 100644 --- a/index.d.ts +++ b/index.d.ts @@ -1,5 +1,6 @@ declare module 'replicate' { type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'; + type Visibility = 'public' | 'private'; type WebhookEventType = 'start' | 'output' | 'logs' | 'completed'; export interface ApiError extends Error { @@ -14,6 +15,11 @@ declare module 'replicate' { models?: Model[]; } + export interface Hardware { + sku: string; + name: string + } + export interface Model { url: string; owner: string; @@ -115,9 +121,40 @@ declare module 'replicate' { get(collection_slug: string): Promise; }; + deployments: { + predictions: { + create( + deployment_owner: string, + deployment_name: string, + options: { + input: object; + stream?: boolean; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + } + ): Promise; + }; + }; + + hardware: { + list(): Promise + } + models: { get(model_owner: string, model_name: string): Promise; list(): Promise>; + create( + model_owner: string, + model_name: string, + options: { + visibility: Visibility; + hardware: string; + description?: string; + github_url?: string; + paper_url?: string; + license_url?: string; + cover_image_url?: string; + }): Promise; versions: { list(model_owner: string, model_name: string): Promise; get( @@ -157,20 +194,5 @@ declare module 'replicate' { cancel(training_id: string): Promise; list(): Promise>; }; - - deployments: { - predictions: { - create( - deployment_owner: string, - deployment_name: string, - options: { - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - }; - }; } } diff --git a/index.js b/index.js index 902ba57..acb07eb 100644 --- a/index.js +++ b/index.js @@ -3,6 +3,7 @@ const { withAutomaticRetries } = require('./lib/util'); const collections = require('./lib/collections'); const deployments = require('./lib/deployments'); +const hardware = require('./lib/hardware'); const models = require('./lib/models'); const predictions = require('./lib/predictions'); const trainings = require('./lib/trainings'); @@ -49,9 +50,20 @@ class Replicate { get: collections.get.bind(this), }; + this.deployments = { + predictions: { + create: deployments.predictions.create.bind(this), + } + }; + + this.hardware = { + list: hardware.list.bind(this), + }; + this.models = { get: models.get.bind(this), list: models.list.bind(this), + create: models.create.bind(this), versions: { list: models.versions.list.bind(this), get: models.versions.get.bind(this), @@ -71,12 +83,6 @@ class Replicate { cancel: trainings.cancel.bind(this), list: trainings.list.bind(this), }; - - this.deployments = { - predictions: { - create: deployments.predictions.create.bind(this), - } - }; } /** diff --git a/index.test.ts b/index.test.ts index ab4e9d6..377357b 100644 --- a/index.test.ts +++ b/index.test.ts @@ -136,12 +136,12 @@ describe('Replicate client', () => { nock(BASE_URL) .get('/models') .reply(200, { - results: [{ url: 'https://replicate.com/some-user/model-1' }], + results: [ { url: 'https://replicate.com/some-user/model-1' } ], next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', }) .get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') .reply(200, { - results: [{ url: 'https://replicate.com/some-user/model-2' }], + results: [ { url: 'https://replicate.com/some-user/model-2' } ], next: null, }); @@ -149,7 +149,7 @@ describe('Replicate client', () => { for await (const batch of client.paginate(client.models.list)) { results.push(...batch); } - expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]); + expect(results).toEqual([ { url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' } ]); // Add more tests for error handling, edge cases, etc. }); @@ -662,6 +662,54 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('hardware.list', () => { + test('Calls the correct API route', async () => { + nock(BASE_URL) + .get('/hardware') + .reply(200, [ + { name: "CPU", sku: "cpu" }, + { name: "Nvidia T4 GPU", sku: "gpu-t4" }, + { name: "Nvidia A40 GPU", sku: "gpu-a40-small" }, + { name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" }, + ]); + + const hardware = await client.hardware.list(); + expect(hardware.length).toBe(4); + expect(hardware[ 0 ].name).toBe('CPU'); + expect(hardware[ 0 ].sku).toBe('cpu'); + }); + // Add more tests for error handling, edge cases, etc. + }); + + describe('models.create', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post('/models') + .reply(200, { + owner: 'test-owner', + name: 'test-model', + visibility: 'public', + hardware: 'cpu', + description: 'A test model', + }); + + const model = await client.models.create( + 'test-owner', + 'test-model', + { + visibility: 'public', + hardware: 'cpu', + description: 'A test model', + }); + + expect(model.owner).toBe('test-owner'); + expect(model.name).toBe('test-model'); + expect(model.visibility).toBe('public'); + // expect(model.hardware).toBe('cpu'); + expect(model.description).toBe('A test model'); + }); + }); + describe('run', () => { test('Calls the correct API routes', async () => { let firstPollingRequest = true; diff --git a/lib/hardware.js b/lib/hardware.js new file mode 100644 index 0000000..487f3b8 --- /dev/null +++ b/lib/hardware.js @@ -0,0 +1,16 @@ +/** + * List hardware + * + * @returns {Promise} Resolves with the array of hardware + */ +async function listHardware() { + const response = await this.request('/hardware', { + method: 'GET', + }); + + return response.json(); +} + +module.exports = { + list: listHardware, +}; diff --git a/lib/models.js b/lib/models.js index be05750..3c4e5b1 100644 --- a/lib/models.js +++ b/lib/models.js @@ -57,8 +57,35 @@ async function listModels() { return response.json(); } +/** + * Create a new model + * + * @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. + * @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization. + * @param {object} options + * @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. + * @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`. + * @param {string} options.description - A description of the model. + * @param {string} options.github_url - A URL for the model's source code on GitHub. + * @param {string} options.paper_url - A URL for the model's paper. + * @param {string} options.license_url - A URL for the model's license. + * @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file. + * @returns {Promise} Resolves with the model version data + */ +async function createModel(model_owner, model_name, options) { + const data = { owner: model_owner, name: model_name, ...options }; + + const response = await this.request('/models', { + method: 'POST', + data, + }); + + return response.json(); +} + module.exports = { get: getModel, list: listModels, + create: createModel, versions: { list: listModelVersions, get: getModelVersion }, };