diff --git a/index.d.ts b/index.d.ts
index bf9cd5d..c415bf0 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -164,6 +164,18 @@ declare module 'replicate' {
         version_id: string
       ): Promise<ModelVersion>;
     };
+    predictions: {
+      create(
+        model_owner: string,
+        model_name: string,
+        options: {
+          input: object;
+          stream?: boolean;
+          webhook?: string;
+          webhook_events_filter?: WebhookEventType[];
+        }
+      ): Promise<Prediction>;
+    };
   };
 
   predictions: {
diff --git a/index.js b/index.js
index acb07eb..0552ea5 100644
--- a/index.js
+++ b/index.js
@@ -68,6 +68,9 @@ class Replicate {
         list: models.versions.list.bind(this),
         get: models.versions.get.bind(this),
       },
+      predictions: {
+        create: models.predictions.create.bind(this),
+      },
     };
 
     this.predictions = {
diff --git a/index.test.ts b/index.test.ts
index d033902..ec9e523 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -668,6 +668,38 @@ describe('Replicate client', () => {
     // Add more tests for error handling, edge cases, etc.
   });
 
+  describe('models.predictions.create', () => {
+    test('Calls the correct API route with the correct payload', async () => {
+      nock(BASE_URL)
+        .post('/models/meta/llama-2-70b-chat/predictions')
+        .reply(200, {
+          id: 'heat2o3bzn3ahtr6bjfftvbaci',
+          model: 'replicate/lifeboat-70b',
+          version: 'd-c6559c5791b50af57b69f4a73f8e021c',
+          input: {
+            prompt: 'Please write a haiku about llamas.',
+          },
+          logs: '',
+          error: null,
+          status: 'starting',
+          created_at: '2023-11-27T13:35:45.99397566Z',
+          urls: {
+            cancel: 'https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel',
+            get: 'https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci',
+          },
+        });
+      const prediction = await client.models.predictions.create('meta', 'llama-2-70b-chat', {
+        input: {
+          prompt: 'Please write a haiku about llamas.',
+        },
+        webhook: 'http://test.host/webhook',
+        webhook_events_filter: ['output', 'completed'],
+      });
+      expect(prediction.id).toBe('heat2o3bzn3ahtr6bjfftvbaci');
+    });
+    // Add more tests for error handling, edge cases, etc.
+  });
+
   describe('hardware.list', () => {
     test('Calls the correct API route', async () => {
       nock(BASE_URL)
diff --git a/lib/models.js b/lib/models.js
index 3c4e5b1..12419fc 100644
--- a/lib/models.js
+++ b/lib/models.js
@@ -83,9 +83,33 @@ async function createModel(model_owner, model_name, options) {
   return response.json();
 }
 
+/**
+ * Create a new prediction
+ *
+ * @param {string} model_owner - Required. The name of the user or organization that owns the model
+ * @param {string} model_name - Required. The name of the model
+ * @param {object} options
+ * @param {object} options.input - Required. An object with the model inputs
+ * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
+ * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
+ * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
+ * @returns {Promise<object>} Resolves with the created prediction
+ */
+async function createPrediction(model_owner, model_name, options) {
+  const { stream, ...data } = options;
+
+  const response = await this.request(`/models/${model_owner}/${model_name}/predictions`, {
+    method: 'POST',
+    data: { ...data, stream },
+  });
+
+  return response.json();
+}
+
 module.exports = {
   get: getModel,
   list: listModels,
   create: createModel,
   versions: { list: listModelVersions, get: getModelVersion },
+  predictions: { create: createPrediction },
 };
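
For reference, a minimal usage sketch of the new method (not part of the diff; the model name, prompt, and webhook URL are illustrative, and it assumes a `REPLICATE_API_TOKEN` environment variable and an ESM context where top-level `await` is available):

import Replicate from 'replicate';

const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

// POST /models/{model_owner}/{model_name}/predictions
// Creates a prediction against a model directly, without a version hash.
const prediction = await replicate.models.predictions.create('meta', 'llama-2-70b-chat', {
  input: { prompt: 'Please write a haiku about llamas.' },
  webhook: 'https://example.com/replicate-webhook', // hypothetical receiver
  webhook_events_filter: ['output', 'completed'],
});

console.log(prediction.id, prediction.status);

Note that `createPrediction` destructures `stream` out of `options` and then merges it back into the request body, so passing `stream: true` only changes what the API returns for the created prediction; the method itself still resolves with the prediction object rather than a stream.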