Skip to content

Commit 563f380

Browse files
committed
Add support for training endpoints
1 parent 11e1ffb commit 563f380

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

index.d.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,18 @@ declare module "replicate" {
4343
updated: string;
4444
}
4545

46+
export interface Training {
47+
id: string;
48+
destination: string;
49+
version: string;
50+
input: any;
51+
output: any;
52+
webhook?: string;
53+
webhook_events_filter?: WebhookEventType[];
54+
created: string;
55+
updated: string;
56+
}
57+
4658
export interface CollectionsGetOptions {
4759
collection_slug: string;
4860
}
@@ -74,6 +86,22 @@ declare module "replicate" {
7486
predictionId: string;
7587
}
7688

89+
export interface TrainingsCreateOptions {
90+
destination: string;
91+
version: string;
92+
input: any;
93+
webhook?: string;
94+
webhook_events_filter?: WebhookEventType[];
95+
}
96+
97+
export interface TrainingsGetOptions {
98+
trainingId: string;
99+
}
100+
101+
export interface TrainingsCancelOptions {
102+
trainingId: string;
103+
}
104+
77105
export class Replicate {
78106
constructor(options: ReplicateOptions);
79107

@@ -113,6 +141,12 @@ declare module "replicate" {
113141
get(options: PredictionsGetOptions): Promise<Prediction>;
114142
list(): Promise<Prediction[]>;
115143
};
144+
145+
trainings: {
146+
create(options: TrainingsCreateOptions): Promise<Training>;
147+
get(options: TrainingsGetOptions): Promise<Training>;
148+
cancel(options: TrainingsGetOptions): Promise<Training>;
149+
};
116150
}
117151

118152
export default Replicate;

index.js

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const axios = require('axios');
33
const collections = require('./lib/collections');
44
const models = require('./lib/models');
55
const predictions = require('./lib/predictions');
6+
const trainings = require('./lib/trainings');
67
const packageJSON = require('./package.json');
78

89
/**
@@ -63,6 +64,12 @@ class Replicate {
6364
get: predictions.get.bind(this),
6465
list: predictions.list.bind(this),
6566
};
67+
68+
this.trainings = {
69+
create: trainings.create.bind(this),
70+
get: trainings.get.bind(this),
71+
cancel: trainings.cancel.bind(this),
72+
};
6673
}
6774

6875
/**

index.test.js

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,59 @@ describe('Replicate client', () => {
131131
// Add more tests for error handling, edge cases, etc.
132132
});
133133

134+
describe('trainings.create', () => {
135+
test('Calls the correct API route with the correct payload', async () => {
136+
client.request = jest.fn();
137+
await client.trainings.create(
138+
'owner',
139+
'model',
140+
'632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532',
141+
{
142+
destination: 'new_owner/new_model',
143+
input: {
144+
text: '...'
145+
}
146+
}
147+
);
148+
149+
expect(client.request).toHaveBeenCalledWith('/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings', {
150+
method: 'POST',
151+
data: {
152+
destination: 'new_owner/new_model',
153+
input: {
154+
text: '...'
155+
},
156+
}
157+
});
158+
});
159+
160+
// Add more tests for error handling, edge cases, etc.
161+
});
162+
163+
describe('trainings.get', () => {
164+
test('Calls the correct API route with the correct payload', async () => {
165+
client.request = jest.fn();
166+
await client.trainings.get(123);
167+
expect(client.request).toHaveBeenCalledWith('/trainings/123', {
168+
method: 'GET',
169+
});
170+
});
171+
172+
// Add more tests for error handling, edge cases, etc.
173+
});
174+
175+
describe('trainings.cancel', () => {
176+
test('Calls the correct API route with the correct payload', async () => {
177+
client.request = jest.fn();
178+
await client.trainings.cancel(123);
179+
expect(client.request).toHaveBeenCalledWith('/trainings/123/cancel', {
180+
method: 'POST',
181+
});
182+
});
183+
184+
// Add more tests for error handling, edge cases, etc.
185+
});
186+
134187
describe('run', () => {
135188
test('Calls the correct API routes', async () => {
136189
client.request = jest.fn();

lib/trainings.js

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* Create a new training
3+
*
4+
* @param {string} model_owner - Required. The username of the user or organization who owns the model
5+
* @param {string} model_name - Required. The name of the model
6+
* @param {string} version_id - Required. The version ID
7+
* @param {object} options
8+
* @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}"
9+
* @param {object} options.input - Required. An object with the model inputs
10+
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates
11+
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
12+
* @returns {Promise<object>} Resolves with the data for the created training
13+
*/
14+
async function createTraining(model_owner, model_name, version_id, options) {
15+
const { ...data } = options;
16+
17+
const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, {
18+
method: 'POST',
19+
data,
20+
});
21+
22+
return training;
23+
}
24+
25+
/**
26+
* Fetch a training by ID
27+
*
28+
* @param {string} training_id - Required. The training ID
29+
* @returns {Promise<object>} Resolves with the data for the training
30+
*/
31+
async function getTraining(training_id) {
32+
return this.request(`/trainings/${training_id}`, {
33+
method: 'GET',
34+
});
35+
}
36+
37+
/**
38+
* Cancel a training by ID
39+
*
40+
* @param {string} training_id - Required. The training ID
41+
* @returns {Promise<object>} Resolves with the data for the training
42+
*/
43+
async function cancelTraining(training_id) {
44+
return this.request(`/trainings/${training_id}/cancel`, {
45+
method: 'POST',
46+
});
47+
}
48+
49+
module.exports = {
50+
create: createTraining,
51+
get: getTraining,
52+
cancel: cancelTraining,
53+
};

0 commit comments

Comments
 (0)