From 328951380453a14fc3574c9cea5e3c71ca548ec6 Mon Sep 17 00:00:00 2001 From: F Date: Tue, 7 Mar 2023 17:02:14 +0000 Subject: [PATCH] Wrap required arguments up together For now we only have `input`, but this allows us to change the API in future and matches the API more directly. --- README.md | 16 ++++++++++++---- lib/Model.js | 8 ++++++-- lib/Model.test.js | 14 +++++++------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 5fe325f..156ca0d 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,9 @@ const prediction = await replicate "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" ) .predict({ - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }); console.log(prediction.output); @@ -45,7 +47,9 @@ await replicate ) .predict( { - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }, { onUpdate: (prediction) => { @@ -66,7 +70,9 @@ const prediction = await replicate "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" ) .createPrediction({ - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }); console.log(prediction.status); // "starting" @@ -89,7 +95,9 @@ await replicate ) .createPrediction( { - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }, { // See https://replicate.com/docs/reference/http#create-prediction--webhook diff --git a/lib/Model.js b/lib/Model.js index ed27db3..1c0bcec 100644 --- a/lib/Model.js +++ b/lib/Model.js @@ -47,7 +47,7 @@ export default class Model extends ReplicateObject { } async predict( - input, + { input }, { onUpdate = noop, onTemporaryError = noop, @@ -122,12 +122,16 @@ export default class Model extends ReplicateObject { return prediction; } - async createPrediction(input, { webhook, webhookEventsFilter } = {}) { + async createPrediction({ input }, { webhook, webhookEventsFilter } = {}) { // This is here and not on `Prediction` because conceptually, a prediction // from a model "belongs" to the model. It's an odd feature of the API that // the prediction creation isn't an action on the model (or that it doesn't // actually use the model information, only the version), but we don't need // to expose that to users of this library. + if (!input) { + throw new ReplicateError("input is required"); + } + const predictionData = await this.client.request("POST /v1/predictions", { version: this.version, input, diff --git a/lib/Model.test.js b/lib/Model.test.js index a4a7486..c416b48 100644 --- a/lib/Model.test.js +++ b/lib/Model.test.js @@ -94,7 +94,7 @@ describe("predict()", () => { ); await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0 } ); @@ -136,7 +136,7 @@ describe("predict()", () => { .mockImplementation((action) => requestMockReturnValues[action]); await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0 } ); @@ -190,7 +190,7 @@ describe("predict()", () => { }); const prediction = await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0 } ); @@ -251,7 +251,7 @@ describe("predict()", () => { const backoffFn = jest.fn(() => 0); const prediction = await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0, backoffFn } ); @@ -269,7 +269,7 @@ describe("createPrediction()", () => { status: PredictionStatus.SUCCEEDED, }); - await model.createPrediction({ text: "test text" }); + await model.createPrediction({ input: { text: "test text" } }); expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", { version: "testversion", @@ -284,7 +284,7 @@ describe("createPrediction()", () => { }); await model.createPrediction( - { text: "test text" }, + { input: { text: "test text" } }, { webhook: "http://test.host/webhook" } ); @@ -302,7 +302,7 @@ describe("createPrediction()", () => { }); await model.createPrediction( - { text: "test text" }, + { input: { text: "test text" } }, { webhook: "http://test.host/webhook", webhookEventsFilter: ["output", "completed"],