Skip to content

Commit 83f7973

Browse files
committed
Add support for models.create endpoint
1 parent 14d13d9 commit 83f7973

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

index.d.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
declare module 'replicate' {
22
type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled';
3+
type Visibility = 'public' | 'private';
34
type WebhookEventType = 'start' | 'output' | 'logs' | 'completed';
45

56
export interface ApiError extends Error {
@@ -142,6 +143,17 @@ declare module 'replicate' {
142143
models: {
143144
get(model_owner: string, model_name: string): Promise<Model>;
144145
list(): Promise<Page<Model>>;
146+
create(options: {
147+
owner: string;
148+
name: string;
149+
visibility: Visibility;
150+
hardware: string;
151+
description?: string;
152+
github_url?: string;
153+
paper_url?: string;
154+
license_url?: string;
155+
cover_image_url?: string;
156+
}): Promise<Model>;
145157
versions: {
146158
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
147159
get(

index.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class Replicate {
6363
this.models = {
6464
get: models.get.bind(this),
6565
list: models.list.bind(this),
66+
create: models.create.bind(this),
6667
versions: {
6768
list: models.versions.list.bind(this),
6869
get: models.versions.get.bind(this),

index.test.ts

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,34 @@ describe('Replicate client', () => {
684684
// Add more tests for error handling, edge cases, etc.
685685
});
686686

687+
describe('models.create', () => {
688+
test('Calls the correct API route with the correct payload', async () => {
689+
nock(BASE_URL)
690+
.post('/models')
691+
.reply(200, {
692+
owner: 'test-owner',
693+
name: 'test-model',
694+
visibility: 'public',
695+
hardware: 'cpu',
696+
description: 'A test model',
697+
});
698+
699+
const model = await client.models.create({
700+
owner: 'test-owner',
701+
name: 'test-model',
702+
visibility: 'public',
703+
hardware: 'cpu',
704+
description: 'A test model',
705+
});
706+
707+
expect(model.owner).toBe('test-owner');
708+
expect(model.name).toBe('test-model');
709+
expect(model.visibility).toBe('public');
710+
// expect(model.hardware).toBe('test-hardware');
711+
expect(model.description).toBe('A test model');
712+
});
713+
});
714+
687715
describe('run', () => {
688716
test('Calls the correct API routes', async () => {
689717
let firstPollingRequest = true;

lib/models.js

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,35 @@ async function listModels() {
5757
return response.json();
5858
}
5959

60+
/**
61+
* Create a new model
62+
*
63+
* @param {object} options
64+
* @param {string} options.owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
65+
* @param {string} options.name - Required. The name of the model. This must be unique among all models owned by the user or organization.
66+
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
67+
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
68+
* @param {string} options.description - A description of the model.
69+
* @param {string} options.github_url - A URL for the model's source code on GitHub.
70+
* @param {string} options.paper_url - A URL for the model's paper.
71+
* @param {string} options.license_url - A URL for the model's license.
72+
* @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file.
73+
* @returns {Promise<object>} Resolves with the model version data
74+
*/
75+
async function createModel(options) {
76+
const data = { ...options };
77+
78+
const response = await this.request('/models', {
79+
method: 'POST',
80+
data,
81+
});
82+
83+
return response.json();
84+
}
85+
6086
module.exports = {
6187
get: getModel,
6288
list: listModels,
89+
create: createModel,
6390
versions: { list: listModelVersions, get: getModelVersion },
6491
};

0 commit comments

Comments
 (0)