diff --git a/index.d.ts b/index.d.ts
index 05edbe6..5620f3b 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -188,28 +188,19 @@ declare module "replicate" {
           version_id: string
         ): Promise<ModelVersion>;
       };
-      predictions: {
-        create(
-          model_owner: string,
-          model_name: string,
-          options: {
-            input: object;
-            stream?: boolean;
-            webhook?: string;
-            webhook_events_filter?: WebhookEventType[];
-          }
-        ): Promise<Prediction>;
-      };
     };
 
     predictions: {
-      create(options: {
-        version: string;
-        input: object;
-        stream?: boolean;
-        webhook?: string;
-        webhook_events_filter?: WebhookEventType[];
-      }): Promise<Prediction>;
+      create(
+        options: {
+          model?: string;
+          version?: string;
+          input: object;
+          stream?: boolean;
+          webhook?: string;
+          webhook_events_filter?: WebhookEventType[];
+        } & ({ version: string } | { model: string })
+      ): Promise<Prediction>;
       get(prediction_id: string): Promise<Prediction>;
       cancel(prediction_id: string): Promise<Prediction>;
       list(): Promise<Page<Prediction>>;
diff --git a/index.js b/index.js
index 5a6c532..a3ed7f8 100644
--- a/index.js
+++ b/index.js
@@ -70,9 +70,6 @@ class Replicate {
         list: models.versions.list.bind(this),
         get: models.versions.get.bind(this),
       },
-      predictions: {
-        create: models.predictions.create.bind(this),
-      },
     };
 
     this.predictions = {
@@ -117,12 +114,13 @@ class Replicate {
         ...data,
         version: identifier.version,
       });
+    } else if (identifier.owner && identifier.name) {
+      prediction = await this.predictions.create({
+        ...data,
+        model: `${identifier.owner}/${identifier.name}`,
+      });
     } else {
-      prediction = await this.models.predictions.create(
-        identifier.owner,
-        identifier.name,
-        data
-      );
+      throw new Error("Invalid model version identifier");
     }
 
     // Call progress callback with the initial prediction object
@@ -260,12 +258,14 @@ class Replicate {
         version: identifier.version,
         stream: true,
       });
+    } else if (identifier.owner && identifier.name) {
+      prediction = await this.predictions.create({
+        ...data,
+        model: `${identifier.owner}/${identifier.name}`,
+        stream: true,
+      });
     } else {
-      prediction = await this.models.predictions.create(
-        identifier.owner,
-        identifier.name,
-        { ...data, stream: true }
-      );
+      throw new Error("Invalid model version identifier");
     }
 
     if (prediction.urls && prediction.urls.stream) {
diff --git a/index.test.ts b/index.test.ts
index 98bff4a..5b5a1dd 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -700,7 +700,7 @@ describe("Replicate client", () => {
     // Add more tests for error handling, edge cases, etc.
   });
 
-  describe("models.predictions.create", () => {
+  describe("predictions.create with model", () => {
     test("Calls the correct API route with the correct payload", async () => {
       nock(BASE_URL)
         .post("/models/meta/llama-2-70b-chat/predictions")
@@ -721,17 +721,14 @@
           get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
         },
       });
-      const prediction = await client.models.predictions.create(
-        "meta",
-        "llama-2-70b-chat",
-        {
-          input: {
-            prompt: "Please write a haiku about llamas.",
-          },
-          webhook: "http://test.host/webhook",
-          webhook_events_filter: ["output", "completed"],
-        }
-      );
+      const prediction = await client.predictions.create({
+        model: "meta/llama-2-70b-chat",
+        input: {
+          prompt: "Please write a haiku about llamas.",
+        },
+        webhook: "http://test.host/webhook",
+        webhook_events_filter: ["output", "completed"],
+      });
       expect(prediction.id).toBe("heat2o3bzn3ahtr6bjfftvbaci");
     });
     // Add more tests for error handling, edge cases, etc.
diff --git a/lib/models.js b/lib/models.js
index cb32e9e..c6a02fc 100644
--- a/lib/models.js
+++ b/lib/models.js
@@ -89,36 +89,9 @@ async function createModel(model_owner, model_name, options) {
   return response.json();
 }
 
-/**
- * Create a new prediction
- *
- * @param {string} model_owner - Required. The name of the user or organization that owns the model
- * @param {string} model_name - Required. The name of the model
- * @param {object} options
- * @param {object} options.input - Required. An object with the model inputs
- * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
- * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
- * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
- * @returns {Promise} Resolves with the created prediction
- */
-async function createPrediction(model_owner, model_name, options) {
-  const { stream, ...data } = options;
-
-  const response = await this.request(
-    `/models/${model_owner}/${model_name}/predictions`,
-    {
-      method: "POST",
-      data: { ...data, stream },
-    }
-  );
-
-  return response.json();
-}
-
 module.exports = {
   get: getModel,
   list: listModels,
   create: createModel,
   versions: { list: listModelVersions, get: getModelVersion },
-  predictions: { create: createPrediction },
 };
diff --git a/lib/predictions.js b/lib/predictions.js
index b1f7fe5..294e8d9 100644
--- a/lib/predictions.js
+++ b/lib/predictions.js
@@ -2,7 +2,8 @@
  * Create a new prediction
  *
  * @param {object} options
- * @param {string} options.version - Required. The model version
+ * @param {string} options.model - The model.
+ * @param {string} options.version - The model version.
  * @param {object} options.input - Required. An object with the model inputs
  * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
  * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
@@ -10,7 +11,7 @@
  * @returns {Promise} Resolves with the created prediction
  */
 async function createPrediction(options) {
-  const { stream, ...data } = options;
+  const { model, version, stream, ...data } = options;
 
   if (data.webhook) {
     try {
@@ -21,10 +22,20 @@
     }
   }
 
-  const response = await this.request("/predictions", {
-    method: "POST",
-    data: { ...data, stream },
-  });
+  let response;
+  if (version) {
+    response = await this.request("/predictions", {
+      method: "POST",
+      data: { ...data, stream, version },
+    });
+  } else if (model) {
+    response = await this.request(`/models/${model}/predictions`, {
+      method: "POST",
+      data: { ...data, stream },
+    });
+  } else {
+    throw new Error("Either model or version must be specified");
+  }
 
   return response.json();
 }
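
For reference, a minimal usage sketch of the consolidated predictions.create call after this change. The client construction and the REPLICATE_API_TOKEN environment variable are assumptions not shown in this diff; the model name, inputs, and routing behavior mirror the test and library code above, and the version value is a placeholder.

import Replicate from "replicate";

// Assumption: API token read from the conventional REPLICATE_API_TOKEN env var.
const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

// Official model: pass `model` as "owner/name"; the client routes this to
// POST /models/{owner}/{name}/predictions.
const byModel = await replicate.predictions.create({
  model: "meta/llama-2-70b-chat",
  input: { prompt: "Please write a haiku about llamas." },
});

// Versioned model: pass `version` instead; this still goes to POST /predictions.
const byVersion = await replicate.predictions.create({
  version: "<model-version-id>", // placeholder, not a real version hash
  input: { prompt: "Please write a haiku about llamas." },
});

// Passing neither `model` nor `version` now throws
// "Either model or version must be specified".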