diff --git a/README.md b/README.md index 55e443c..36fbcb0 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,41 @@ const response = await replicate.models.get(model_owner, model_name); } ``` +### `replicate.models.list` + +Get a paginated list of all public models. + +```js +const response = await replicate.models.list(); +``` + +```jsonc +{ + "next": null, + "previous": null, + "results": [ + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": { + /* ... */ + }, + "latest_version": { + /* ... */ + } + } + ] +} +``` + ### `replicate.models.versions.list` Get a list of all published versions of a model, including input and output schemas for each version. diff --git a/index.d.ts b/index.d.ts index 32f279a..601e15b 100644 --- a/index.d.ts +++ b/index.d.ts @@ -117,6 +117,7 @@ declare module 'replicate' { models: { get(model_owner: string, model_name: string): Promise; + list(): Promise>; versions: { list(model_owner: string, model_name: string): Promise; get( diff --git a/index.js b/index.js index c6a2cc2..4f74985 100644 --- a/index.js +++ b/index.js @@ -51,6 +51,7 @@ class Replicate { this.models = { get: models.get.bind(this), + list: models.list.bind(this), versions: { list: models.versions.list.bind(this), get: models.versions.get.bind(this), diff --git a/index.test.ts b/index.test.ts index c35b41f..ab4e9d6 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,5 +1,5 @@ import { expect, jest, test } from '@jest/globals'; -import Replicate, { ApiError, Prediction } from 'replicate'; +import Replicate, { ApiError, Model, Prediction } from 'replicate'; import nock from 'nock'; import fetch from 'cross-fetch'; @@ -131,6 +131,30 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('models.list', () => { + test('Paginates results', async () => { + nock(BASE_URL) + .get('/models') + .reply(200, { + results: [{ url: 'https://replicate.com/some-user/model-1' }], + next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', + }) + .get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') + .reply(200, { + results: [{ url: 'https://replicate.com/some-user/model-2' }], + next: null, + }); + + const results: Model[] = []; + for await (const batch of client.paginate(client.models.list)) { + results.push(...batch); + } + expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]); + + // Add more tests for error handling, edge cases, etc. + }); + }); + describe('predictions.create', () => { test('Calls the correct API route with the correct payload', async () => { nock(BASE_URL) diff --git a/lib/models.js b/lib/models.js index 373ed23..be05750 100644 --- a/lib/models.js +++ b/lib/models.js @@ -37,17 +37,28 @@ async function listModelVersions(model_owner, model_name) { * @returns {Promise} Resolves with the model version data */ async function getModelVersion(model_owner, model_name, version_id) { - const response = await this.request( - `/models/${model_owner}/${model_name}/versions/${version_id}`, - { - method: 'GET', - } - ); + const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}`, { + method: 'GET', + }); + + return response.json(); +} + +/** + * List all public models + * + * @returns {Promise} Resolves with the model version data + */ +async function listModels() { + const response = await this.request('/models', { + method: 'GET', + }); return response.json(); } module.exports = { get: getModel, + list: listModels, versions: { list: listModelVersions, get: getModelVersion }, };