diff --git a/index.d.ts b/index.d.ts index 5671f62..cb3b26e 100644 --- a/index.d.ts +++ b/index.d.ts @@ -12,7 +12,7 @@ declare module 'replicate' { name: string; slug: string; description: string; - models: Model[]; + models?: Model[]; } export interface Model { @@ -90,6 +90,7 @@ declare module 'replicate' { ): Promise; collections: { + list(): Promise>; get(collection_slug: string): Promise; }; diff --git a/index.js b/index.js index 5f033a2..390038d 100644 --- a/index.js +++ b/index.js @@ -40,6 +40,7 @@ class Replicate { this.fetch = options.fetch || globalThis.fetch; this.collections = { + list: collections.list.bind(this), get: collections.get.bind(this), }; diff --git a/index.test.ts b/index.test.ts index 39633dd..eab4ffe 100644 --- a/index.test.ts +++ b/index.test.ts @@ -35,13 +35,39 @@ describe('Replicate client', () => { }); }); + describe('collections.list', () => { + test('Calls the correct API route', async () => { + nock(BASE_URL) + .get('/collections') + .reply(200, { + results: [ + { + name: 'Super resolution', + slug: 'super-resolution', + description: 'Upscaling models that create high-quality images from low-quality images.', + }, + { + name: 'Image classification', + slug: 'image-classification', + description: 'Models that classify images.', + }, + ], + next: null, + previous: null, + }); + + const collections = await client.collections.list(); + expect(collections.results.length).toBe(2); + }); + // Add more tests for error handling, edge cases, etc. + }); + describe('collections.get', () => { test('Calls the correct API route', async () => { nock(BASE_URL).get('/collections/super-resolution').reply(200, { name: 'Super resolution', slug: 'super-resolution', - description: - 'Upscaling models that create high-quality images from low-quality images.', + description: 'Upscaling models that create high-quality images from low-quality images.', models: [], }); diff --git a/lib/collections.js b/lib/collections.js index 668262e..195bf8e 100644 --- a/lib/collections.js +++ b/lib/collections.js @@ -10,4 +10,14 @@ async function getCollection(collection_slug) { }); } -module.exports = { get: getCollection }; +/** + * Fetch a list of model collections + * @returns {Promise} - Resolves with the collections data + */ +async function listCollections() { + return this.request('/collections', { + method: 'GET', + }); +} + +module.exports = { get: getCollection, list: listCollections };