diff --git a/index.d.ts b/index.d.ts index 924104a..02a2f8d 100644 --- a/index.d.ts +++ b/index.d.ts @@ -130,6 +130,7 @@ declare module 'replicate' { ): Promise; get(training_id: string): Promise; cancel(training_id: string): Promise; + list(): Promise>; }; } } diff --git a/index.js b/index.js index 55e62dd..aae4639 100644 --- a/index.js +++ b/index.js @@ -61,6 +61,7 @@ class Replicate { create: trainings.create.bind(this), get: trainings.get.bind(this), cancel: trainings.cancel.bind(this), + list: trainings.list.bind(this), }; } diff --git a/index.test.ts b/index.test.ts index 7da943d..89bcc8d 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,4 +1,3 @@ - import { expect, jest, test } from '@jest/globals'; import Replicate, { Prediction } from 'replicate'; import nock from 'nock'; @@ -38,15 +37,13 @@ describe('Replicate client', () => { describe('collections.get', () => { test('Calls the correct API route', async () => { - nock(BASE_URL) - .get('/collections/super-resolution') - .reply(200, { - name: 'Super resolution', - slug: 'super-resolution', - description: - 'Upscaling models that create high-quality images from low-quality images.', - models: [], - }); + nock(BASE_URL).get('/collections/super-resolution').reply(200, { + name: 'Super resolution', + slug: 'super-resolution', + description: + 'Upscaling models that create high-quality images from low-quality images.', + models: [], + }); const collection = await client.collections.get('super-resolution'); expect(collection.name).toBe('Super resolution'); @@ -56,29 +53,26 @@ describe('Replicate client', () => { describe('models.get', () => { test('Calls the correct API route', async () => { - nock(BASE_URL) - .get('/models/replicate/hello-world') - .reply(200, { - url: 'https://replicate.com/replicate/hello-world', - owner: 'replicate', - name: 'hello-world', - description: 'A tiny model that says hello', - visibility: 'public', - github_url: 'https://github.com/replicate/cog-examples', - paper_url: null, - license_url: null, - run_count: 12345, - cover_image_url: '', - default_example: {}, - latest_version: {}, - }); + nock(BASE_URL).get('/models/replicate/hello-world').reply(200, { + url: 'https://replicate.com/replicate/hello-world', + owner: 'replicate', + name: 'hello-world', + description: 'A tiny model that says hello', + visibility: 'public', + github_url: 'https://github.com/replicate/cog-examples', + paper_url: null, + license_url: null, + run_count: 12345, + cover_image_url: '', + default_example: {}, + latest_version: {}, + }); await client.models.get('replicate', 'hello-world'); }); // Add more tests for error handling, edge cases, etc. }); - describe('predictions.create', () => { test('Calls the correct API route with the correct payload', async () => { nock(BASE_URL) @@ -194,7 +188,9 @@ describe('Replicate client', () => { results: [ { id: 'ufawqhfynnddngldkgtslldrkq' } ], next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', }) - .get('/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') + .get( + '/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw' + ) .reply(200, { results: [ { id: 'rrr4z55ocneqzikepnug6xezpe' } ], next: null, @@ -211,130 +207,191 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + }); - describe('trainings.create', () => { - test('Calls the correct API route with the correct payload', async () => { - nock(BASE_URL) - .post('/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings') - .reply(200, { - id: 'zz4ibbonubfz7carwiefibzgga', - version: '{version}', - status: 'starting', - input: { - text: '...', - }, - output: null, - error: null, - logs: null, - started_at: null, - created_at: '2023-03-28T21:47:58.566434Z', - completed_at: null, - }); - - - const training = await client.trainings.create( - 'owner', - 'model', - '632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532', - { - destination: 'new_owner/new_model', - input: { - text: '...', - }, - } - ); - expect(training.id).toBe('zz4ibbonubfz7carwiefibzgga'); - }); + describe('trainings.create', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post( + '/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings' + ) + .reply(200, { + id: 'zz4ibbonubfz7carwiefibzgga', + version: '{version}', + status: 'starting', + input: { + text: '...', + }, + output: null, + error: null, + logs: null, + started_at: null, + created_at: '2023-03-28T21:47:58.566434Z', + completed_at: null, + }); - // Add more tests for error handling, edge cases, etc. + const training = await client.trainings.create( + 'owner', + 'model', + '632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532', + { + destination: 'new_owner/new_model', + input: { + text: '...', + }, + } + ); + expect(training.id).toBe('zz4ibbonubfz7carwiefibzgga'); }); - describe('trainings.get', () => { - test('Calls the correct API route with the correct payload', async () => { - nock(BASE_URL) - .get('/trainings/zz4ibbonubfz7carwiefibzgga') - .reply(200, { - id: 'zz4ibbonubfz7carwiefibzgga', - version: '{version}', - status: 'succeeded', - input: { - data: '...', - param1: '...', - }, - output: { - version: '...', - }, - error: null, - logs: null, - webhook_completed: null, - started_at: null, - created_at: '2023-03-28T21:47:58.566434Z', - completed_at: null, - }); - - const training = await client.trainings.get('zz4ibbonubfz7carwiefibzgga'); - expect(training.status).toBe('succeeded'); - }); + // Add more tests for error handling, edge cases, etc. + }); - // Add more tests for error handling, edge cases, etc. + describe('trainings.get', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .get('/trainings/zz4ibbonubfz7carwiefibzgga') + .reply(200, { + id: 'zz4ibbonubfz7carwiefibzgga', + version: '{version}', + status: 'succeeded', + input: { + data: '...', + param1: '...', + }, + output: { + version: '...', + }, + error: null, + logs: null, + webhook_completed: null, + started_at: null, + created_at: '2023-03-28T21:47:58.566434Z', + completed_at: null, + }); + + const training = await client.trainings.get('zz4ibbonubfz7carwiefibzgga'); + expect(training.status).toBe('succeeded'); }); - describe('trainings.cancel', () => { - test('Calls the correct API route with the correct payload', async () => { - nock(BASE_URL) - .post('/trainings/zz4ibbonubfz7carwiefibzgga/cancel') - .reply(200, { - id: 'zz4ibbonubfz7carwiefibzgga', - version: '{version}', - status: 'canceled', - input: { - data: '...', - param1: '...', - }, - output: { - version: '...', + // Add more tests for error handling, edge cases, etc. + }); + + describe('trainings.cancel', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post('/trainings/zz4ibbonubfz7carwiefibzgga/cancel') + .reply(200, { + id: 'zz4ibbonubfz7carwiefibzgga', + version: '{version}', + status: 'canceled', + input: { + data: '...', + param1: '...', + }, + output: { + version: '...', + }, + error: null, + logs: null, + webhook_completed: null, + started_at: null, + created_at: '2023-03-28T21:47:58.566434Z', + completed_at: null, + }); + + const training = await client.trainings.cancel( + 'zz4ibbonubfz7carwiefibzgga' + ); + expect(training.status).toBe('canceled'); + }); + + // Add more tests for error handling, edge cases, etc. + }); + + describe('trainings.list', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .get('/trainings') + .reply(200, { + next: 'https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', + previous: null, + results: [ + { + id: 'jpzd7hm5gfcapbfyt4mqytarku', + version: + 'b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05', + urls: { + get: 'https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku', + cancel: + 'https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel', + }, + created_at: '2022-04-26T20:00:40.658234Z', + started_at: '2022-04-26T20:00:84.583803Z', + completed_at: '2022-04-26T20:02:27.648305Z', + source: 'web', + status: 'succeeded', }, - error: null, - logs: null, - webhook_completed: null, - started_at: null, - created_at: '2023-03-28T21:47:58.566434Z', - completed_at: null, - }); + ], + }); + const trainings = await client.trainings.list(); + expect(trainings.results.length).toBe(1); + expect(trainings.results[ 0 ].id).toBe('jpzd7hm5gfcapbfyt4mqytarku'); + }); - const training = await client.trainings.cancel('zz4ibbonubfz7carwiefibzgga'); - expect(training.status).toBe('canceled'); - }); + test('Paginates results', async () => { + nock(BASE_URL) + .get('/trainings') + .reply(200, { + results: [ { id: 'ufawqhfynnddngldkgtslldrkq' } ], + next: 'https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', + }) + .get( + '/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw' + ) + .reply(200, { + results: [ { id: 'rrr4z55ocneqzikepnug6xezpe' } ], + next: null, + }); + + const results: Prediction[] = []; + for await (const batch of client.paginate(client.trainings.list)) { + results.push(...batch); + } + expect(results).toEqual([ + { id: 'ufawqhfynnddngldkgtslldrkq' }, + { id: 'rrr4z55ocneqzikepnug6xezpe' }, + ]); // Add more tests for error handling, edge cases, etc. }); + }); - describe('run', () => { - test('Calls the correct API routes', async () => { - nock(BASE_URL) - .post('/predictions') - .reply(200, { - id: 'ufawqhfynnddngldkgtslldrkq', - status: 'processing', - }) - .get('/predictions/ufawqhfynnddngldkgtslldrkq') - .reply(200, { - id: 'ufawqhfynnddngldkgtslldrkq', - status: 'succeeded', - output: 'foobar', - }); - - const output = await client.run( - 'owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', - { - input: { text: 'Hello, world!' }, - } - ); - expect(output).toBe('foobar'); - }); - }); + describe('run', () => { + test('Calls the correct API routes', async () => { + nock(BASE_URL) + .post('/predictions') + .reply(200, { + id: 'ufawqhfynnddngldkgtslldrkq', + status: 'processing', + }) + .get('/predictions/ufawqhfynnddngldkgtslldrkq') + .reply(200, { + id: 'ufawqhfynnddngldkgtslldrkq', + status: 'succeeded', + output: 'foobar', + }); - // Continue with tests for other methods + const output = await client.run( + 'owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + { + input: { text: 'Hello, world!' }, + } + ); + expect(output).toBe('foobar'); + }); }); + + // Continue with tests for other methods }); diff --git a/lib/trainings.js b/lib/trainings.js index c251294..8f8c5d1 100644 --- a/lib/trainings.js +++ b/lib/trainings.js @@ -46,8 +46,20 @@ async function cancelTraining(training_id) { }); } +/** + * List all trainings + * + * @returns {Promise} - Resolves with a page of trainings + */ +async function listTrainings() { + return this.request('/trainings', { + method: 'GET', + }); +} + module.exports = { create: createTraining, get: getTraining, cancel: cancelTraining, + list: listTrainings, };