Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,21 @@ declare module 'replicate' {
get(prediction_id: string): Promise<Prediction>;
list(): Promise<Page<Prediction>>;
};

trainings: {
create(
model_owner: string,
model_name: string,
version_id: string,
options: {
destination: `${string}/${string}`;
input: object;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}
): Promise<Training>;
get(options: TrainingsGetOptions): Promise<Training>;
cancel(options: TrainingsGetOptions): Promise<Training>;
};
}
}
7 changes: 7 additions & 0 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const axios = require('axios');
const collections = require('./lib/collections');
const models = require('./lib/models');
const predictions = require('./lib/predictions');
const trainings = require('./lib/trainings');
const packageJSON = require('./package.json');

/**
Expand Down Expand Up @@ -63,6 +64,12 @@ class Replicate {
get: predictions.get.bind(this),
list: predictions.list.bind(this),
};

this.trainings = {
create: trainings.create.bind(this),
get: trainings.get.bind(this),
cancel: trainings.cancel.bind(this),
};
}

/**
Expand Down
162 changes: 138 additions & 24 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ describe('Replicate client', () => {

beforeEach(() => {
client = new Replicate({ auth: 'test-token' });
client['instance'] = jest.fn<typeof axios>();
client[ 'instance' ] = jest.fn<typeof axios>();
});

describe('constructor', () => {
Expand All @@ -36,7 +36,7 @@ describe('Replicate client', () => {

describe('collections.get', () => {
test('Calls the correct API route', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
name: 'Super resolution',
slug: 'super-resolution',
Expand All @@ -46,7 +46,7 @@ describe('Replicate client', () => {
},
});
const collection = await client.collections.get('super-resolution');
expect(client['instance']).toHaveBeenCalledWith(
expect(client[ 'instance' ]).toHaveBeenCalledWith(
'/collections/super-resolution',
{
method: 'GET',
Expand All @@ -60,7 +60,7 @@ describe('Replicate client', () => {

describe('models.get', () => {
test('Calls the correct API route', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
url: 'https://replicate.com/replicate/hello-world',
owner: 'replicate',
Expand All @@ -77,7 +77,7 @@ describe('Replicate client', () => {
},
});
await client.models.get('replicate', 'hello-world');
expect(client['instance']).toHaveBeenCalledWith(
expect(client[ 'instance' ]).toHaveBeenCalledWith(
'/models/replicate/hello-world',
{
method: 'GET',
Expand All @@ -90,7 +90,7 @@ describe('Replicate client', () => {

describe('predictions.create', () => {
test('Calls the correct API route with the correct payload', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
id: 'ufawqhfynnddngldkgtslldrkq',
version:
Expand Down Expand Up @@ -121,11 +121,11 @@ describe('Replicate client', () => {
text: 'Alice',
},
webhook: 'http://test.host/webhook',
webhook_events_filter: ['output', 'completed'],
webhook_events_filter: [ 'output', 'completed' ],
});
expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq');

expect(client['instance']).toHaveBeenCalledWith('/predictions', {
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
method: 'POST',
data: {
version:
Expand All @@ -134,7 +134,7 @@ describe('Replicate client', () => {
text: 'Alice',
},
webhook: 'http://test.host/webhook',
webhook_events_filter: ['output', 'completed'],
webhook_events_filter: [ 'output', 'completed' ],
},
});
});
Expand All @@ -144,7 +144,7 @@ describe('Replicate client', () => {

describe('predictions.get', () => {
test('Calls the correct API route with the correct payload', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
id: 'rrr4z55ocneqzikepnug6xezpe',
version:
Expand Down Expand Up @@ -178,7 +178,7 @@ describe('Replicate client', () => {
);
expect(prediction.id).toBe('rrr4z55ocneqzikepnug6xezpe');

expect(client['instance']).toHaveBeenCalledWith(
expect(client[ 'instance' ]).toHaveBeenCalledWith(
'/predictions/rrr4z55ocneqzikepnug6xezpe',
{
method: 'GET',
Expand All @@ -191,7 +191,7 @@ describe('Replicate client', () => {

describe('predictions.list', () => {
test('Calls the correct API route with the correct payload', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
previous: null,
Expand All @@ -217,23 +217,23 @@ describe('Replicate client', () => {

const predictions = await client.predictions.list();
expect(predictions.results.length).toBe(1);
expect(predictions.results[0].id).toBe('jpzd7hm5gfcapbfyt4mqytarku');
expect(predictions.results[ 0 ].id).toBe('jpzd7hm5gfcapbfyt4mqytarku');

expect(client['instance']).toHaveBeenCalledWith('/predictions', {
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
method: 'GET',
});
});

test('Paginates results', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
results: [{ id: 'ufawqhfynnddngldkgtslldrkq' }],
results: [ { id: 'ufawqhfynnddngldkgtslldrkq' } ],
next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
},
});
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
results: [{ id: 'rrr4z55ocneqzikepnug6xezpe' }],
results: [ { id: 'rrr4z55ocneqzikepnug6xezpe' } ],
next: null,
},
});
Expand All @@ -248,10 +248,10 @@ describe('Replicate client', () => {
{ id: 'rrr4z55ocneqzikepnug6xezpe' },
]);

expect(client['instance']).toHaveBeenCalledWith('/predictions', {
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
method: 'GET',
});
expect(client['instance']).toHaveBeenCalledWith(
expect(client[ 'instance' ]).toHaveBeenCalledWith(
'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
{
method: 'GET',
Expand All @@ -262,15 +262,129 @@ describe('Replicate client', () => {
// Add more tests for error handling, edge cases, etc.
});

describe('trainings.create', () => {
test('Calls the correct API route with the correct payload', async () => {
client[ 'instance' ].mockResolvedValueOnce({
data: {
"id": "zz4ibbonubfz7carwiefibzgga",
"version": "{version}",
"status": "starting",
"input": {
"text": "..."
},
"output": null,
"error": null,
"logs": null,
"started_at": null,
"created_at": "2023-03-28T21:47:58.566434Z",
"completed_at": null
}
});

const training = await client.trainings.create(
'owner',
'model',
'632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532',
{
destination: 'new_owner/new_model',
input: {
text: '...'
}
}
);
expect(training.id).toBe('zz4ibbonubfz7carwiefibzgga');

expect(client[ 'instance' ]).toHaveBeenCalledWith('/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings', {
method: 'POST',
data: {
destination: 'new_owner/new_model',
input: {
text: '...'
},
}
});
});

// Add more tests for error handling, edge cases, etc.
});

describe('trainings.get', () => {
test('Calls the correct API route with the correct payload', async () => {
client[ 'instance' ].mockResolvedValueOnce({
data: {
"id": "zz4ibbonubfz7carwiefibzgga",
"version": "{version}",
"status": "succeeded",
"input": {
"data": "...",
"param1": "..."
},
"output": {
"version": "..."
},
"error": null,
"logs": null,
"webhook_completed": null,
"started_at": null,
"created_at": "2023-03-28T21:47:58.566434Z",
"completed_at": null
}
});

const training = await client.trainings.get('zz4ibbonubfz7carwiefibzgga');
expect(training.status).toBe('succeeded');

expect(client[ 'instance' ]).toHaveBeenCalledWith('/trainings/zz4ibbonubfz7carwiefibzgga', {
method: 'GET',
});
});

// Add more tests for error handling, edge cases, etc.
});

describe('trainings.cancel', () => {
test('Calls the correct API route with the correct payload', async () => {
client[ 'instance' ].mockResolvedValueOnce({
data: {
"id": "zz4ibbonubfz7carwiefibzgga",
"version": "{version}",
"status": "canceled",
"input": {
"data": "...",
"param1": "..."
},
"output": {
"version": "..."
},
"error": null,
"logs": null,
"webhook_completed": null,
"started_at": null,
"created_at": "2023-03-28T21:47:58.566434Z",
"completed_at": null
}
});

const training = await client.trainings.cancel("zz4ibbonubfz7carwiefibzgga");
expect(training.status).toBe('canceled');

expect(client[ 'instance' ]).toHaveBeenCalledWith('/trainings/zz4ibbonubfz7carwiefibzgga/cancel', {
method: 'POST',
});
});

// Add more tests for error handling, edge cases, etc.
});

describe('run', () => {
test('Calls the correct API routes', async () => {
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
id: 'ufawqhfynnddngldkgtslldrkq',
status: 'processing',
},
});
client['instance'].mockResolvedValueOnce({
client[ 'instance' ].mockResolvedValueOnce({
data: {
id: 'ufawqhfynnddngldkgtslldrkq',
status: 'succeeded',
Expand All @@ -283,7 +397,7 @@ describe('Replicate client', () => {
input: { text: 'Hello, world!' },
}
);
expect(client['instance']).toHaveBeenCalledWith('/predictions', {
expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', {
method: 'POST',
data: {
version:
Expand All @@ -293,7 +407,7 @@ describe('Replicate client', () => {
},
},
});
expect(client['instance']).toHaveBeenCalledWith(
expect(client[ 'instance' ]).toHaveBeenCalledWith(
'/predictions/ufawqhfynnddngldkgtslldrkq',
{
method: 'GET',
Expand Down
53 changes: 53 additions & 0 deletions lib/trainings.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Create a new training
*
* @param {string} model_owner - Required. The username of the user or organization who owns the model
* @param {string} model_name - Required. The name of the model
* @param {string} version_id - Required. The version ID
* @param {object} options
* @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}"
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @returns {Promise<object>} Resolves with the data for the created training
*/
async function createTraining(model_owner, model_name, version_id, options) {
const { ...data } = options;

const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, {
method: 'POST',
data,
});

return training;
}

/**
* Fetch a training by ID
*
* @param {string} training_id - Required. The training ID
* @returns {Promise<object>} Resolves with the data for the training
*/
async function getTraining(training_id) {
return this.request(`/trainings/${training_id}`, {
method: 'GET',
});
}

/**
* Cancel a training by ID
*
* @param {string} training_id - Required. The training ID
* @returns {Promise<object>} Resolves with the data for the training
*/
async function cancelTraining(training_id) {
return this.request(`/trainings/${training_id}/cancel`, {
method: 'POST',
});
}

module.exports = {
create: createTraining,
get: getTraining,
cancel: cancelTraining,
};