Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,43 @@ const response = await replicate.models.list();
}
```

### `replicate.models.create`

Create a new public or private model.

```js
const response = await replicate.models.create(model_owner, model_name, options);
```

| name | type | description |
| ------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `model_owner` | string | **Required**. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. |
| `model_name` | string | **Required**. The name of the model. This must be unique among all models owned by the user or organization. |
| `options.visibility` | string | **Required**. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. |
| `options.hardware` | string | **Required**. The SKU for the hardware used to run the model. Possible values can be found by calling [`replicate.hardware.list()](#replicatehardwarelist)`. |
| `options.description` | string | A description of the model. |
| `options.github_url` | string | A URL for the model's source code on GitHub. |
| `options.paper_url` | string | A URL for the model's paper. |
| `options.license_url` | string | A URL for the model's license. |
| `options.cover_image_url` | string | A URL for the model's cover image. This should be an image file. |

### `replicate.hardware.list`

List available hardware for running models on Replicate.

```js
const response = await replicate.hardware.list()
```

```jsonc
[
{"name": "CPU", "sku": "cpu" },
{"name": "Nvidia T4 GPU", "sku": "gpu-t4" },
{"name": "Nvidia A40 GPU", "sku": "gpu-a40-small" },
{"name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" },
]
```

### `replicate.models.versions.list`

Get a list of all published versions of a model, including input and output schemas for each version.
Expand Down
52 changes: 37 additions & 15 deletions index.d.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
declare module 'replicate' {
type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled';
type Visibility = 'public' | 'private';
type WebhookEventType = 'start' | 'output' | 'logs' | 'completed';

export interface ApiError extends Error {
Expand All @@ -14,6 +15,11 @@ declare module 'replicate' {
models?: Model[];
}

export interface Hardware {
sku: string;
name: string
}

export interface Model {
url: string;
owner: string;
Expand Down Expand Up @@ -115,9 +121,40 @@ declare module 'replicate' {
get(collection_slug: string): Promise<Collection>;
};

deployments: {
predictions: {
create(
deployment_owner: string,
deployment_name: string,
options: {
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}
): Promise<Prediction>;
};
};

hardware: {
list(): Promise<Hardware[]>
}

models: {
get(model_owner: string, model_name: string): Promise<Model>;
list(): Promise<Page<Model>>;
create(
model_owner: string,
model_name: string,
options: {
visibility: Visibility;
hardware: string;
description?: string;
github_url?: string;
paper_url?: string;
license_url?: string;
cover_image_url?: string;
}): Promise<Model>;
versions: {
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
get(
Expand Down Expand Up @@ -157,20 +194,5 @@ declare module 'replicate' {
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
};

deployments: {
predictions: {
create(
deployment_owner: string,
deployment_name: string,
options: {
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}
): Promise<Prediction>;
};
};
}
}
18 changes: 12 additions & 6 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const { withAutomaticRetries } = require('./lib/util');

const collections = require('./lib/collections');
const deployments = require('./lib/deployments');
const hardware = require('./lib/hardware');
const models = require('./lib/models');
const predictions = require('./lib/predictions');
const trainings = require('./lib/trainings');
Expand Down Expand Up @@ -49,9 +50,20 @@ class Replicate {
get: collections.get.bind(this),
};

this.deployments = {
predictions: {
create: deployments.predictions.create.bind(this),
}
};

this.hardware = {
list: hardware.list.bind(this),
};

this.models = {
get: models.get.bind(this),
list: models.list.bind(this),
create: models.create.bind(this),
versions: {
list: models.versions.list.bind(this),
get: models.versions.get.bind(this),
Expand All @@ -71,12 +83,6 @@ class Replicate {
cancel: trainings.cancel.bind(this),
list: trainings.list.bind(this),
};

this.deployments = {
predictions: {
create: deployments.predictions.create.bind(this),
}
};
}

/**
Expand Down
54 changes: 51 additions & 3 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,20 @@ describe('Replicate client', () => {
nock(BASE_URL)
.get('/models')
.reply(200, {
results: [{ url: 'https://replicate.com/some-user/model-1' }],
results: [ { url: 'https://replicate.com/some-user/model-1' } ],
next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
})
.get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw')
.reply(200, {
results: [{ url: 'https://replicate.com/some-user/model-2' }],
results: [ { url: 'https://replicate.com/some-user/model-2' } ],
next: null,
});

const results: Model[] = [];
for await (const batch of client.paginate(client.models.list)) {
results.push(...batch);
}
expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]);
expect(results).toEqual([ { url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' } ]);

// Add more tests for error handling, edge cases, etc.
});
Expand Down Expand Up @@ -662,6 +662,54 @@ describe('Replicate client', () => {
// Add more tests for error handling, edge cases, etc.
});

describe('hardware.list', () => {
test('Calls the correct API route', async () => {
nock(BASE_URL)
.get('/hardware')
.reply(200, [
{ name: "CPU", sku: "cpu" },
{ name: "Nvidia T4 GPU", sku: "gpu-t4" },
{ name: "Nvidia A40 GPU", sku: "gpu-a40-small" },
{ name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" },
]);

const hardware = await client.hardware.list();
expect(hardware.length).toBe(4);
expect(hardware[ 0 ].name).toBe('CPU');
expect(hardware[ 0 ].sku).toBe('cpu');
});
// Add more tests for error handling, edge cases, etc.
});

describe('models.create', () => {
test('Calls the correct API route with the correct payload', async () => {
nock(BASE_URL)
.post('/models')
.reply(200, {
owner: 'test-owner',
name: 'test-model',
visibility: 'public',
hardware: 'cpu',
description: 'A test model',
});

const model = await client.models.create(
'test-owner',
'test-model',
{
visibility: 'public',
hardware: 'cpu',
description: 'A test model',
});

expect(model.owner).toBe('test-owner');
expect(model.name).toBe('test-model');
expect(model.visibility).toBe('public');
// expect(model.hardware).toBe('cpu');
expect(model.description).toBe('A test model');
});
});

describe('run', () => {
test('Calls the correct API routes', async () => {
let firstPollingRequest = true;
Expand Down
16 changes: 16 additions & 0 deletions lib/hardware.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/**
* List hardware
*
* @returns {Promise<object[]>} Resolves with the array of hardware
*/
async function listHardware() {
const response = await this.request('/hardware', {
method: 'GET',
});

return response.json();
}

module.exports = {
list: listHardware,
};
27 changes: 27 additions & 0 deletions lib/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,35 @@ async function listModels() {
return response.json();
}

/**
* Create a new model
*
* @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
* @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization.
* @param {object} options
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
* @param {string} options.description - A description of the model.
* @param {string} options.github_url - A URL for the model's source code on GitHub.
* @param {string} options.paper_url - A URL for the model's paper.
* @param {string} options.license_url - A URL for the model's license.
* @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file.
* @returns {Promise<object>} Resolves with the model version data
*/
async function createModel(model_owner, model_name, options) {
const data = { owner: model_owner, name: model_name, ...options };

const response = await this.request('/models', {
method: 'POST',
data,
});

return response.json();
}

module.exports = {
get: getModel,
list: listModels,
create: createModel,
versions: { list: listModelVersions, get: getModelVersion },
};