Skip to content

Commit 3289513

Browse files
committed
Wrap required arguments up together
For now we only have `input`, but this allows us to change the API in future and matches the API more directly.
1 parent 5eccd13 commit 3289513

File tree

3 files changed

+25
-13
lines changed

3 files changed

+25
-13
lines changed

README.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ const prediction = await replicate
2626
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
2727
)
2828
.predict({
29-
prompt: "an astronaut riding on a horse",
29+
input: {
30+
prompt: "an astronaut riding on a horse",
31+
},
3032
});
3133

3234
console.log(prediction.output);
@@ -45,7 +47,9 @@ await replicate
4547
)
4648
.predict(
4749
{
48-
prompt: "an astronaut riding on a horse",
50+
input: {
51+
prompt: "an astronaut riding on a horse",
52+
},
4953
},
5054
{
5155
onUpdate: (prediction) => {
@@ -66,7 +70,9 @@ const prediction = await replicate
6670
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
6771
)
6872
.createPrediction({
69-
prompt: "an astronaut riding on a horse",
73+
input: {
74+
prompt: "an astronaut riding on a horse",
75+
},
7076
});
7177

7278
console.log(prediction.status); // "starting"
@@ -89,7 +95,9 @@ await replicate
8995
)
9096
.createPrediction(
9197
{
92-
prompt: "an astronaut riding on a horse",
98+
input: {
99+
prompt: "an astronaut riding on a horse",
100+
},
93101
},
94102
{
95103
// See https://replicate.com/docs/reference/http#create-prediction--webhook

lib/Model.js

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export default class Model extends ReplicateObject {
4747
}
4848

4949
async predict(
50-
input,
50+
{ input },
5151
{
5252
onUpdate = noop,
5353
onTemporaryError = noop,
@@ -122,12 +122,16 @@ export default class Model extends ReplicateObject {
122122
return prediction;
123123
}
124124

125-
async createPrediction(input, { webhook, webhookEventsFilter } = {}) {
125+
async createPrediction({ input }, { webhook, webhookEventsFilter } = {}) {
126126
// This is here and not on `Prediction` because conceptually, a prediction
127127
// from a model "belongs" to the model. It's an odd feature of the API that
128128
// the prediction creation isn't an action on the model (or that it doesn't
129129
// actually use the model information, only the version), but we don't need
130130
// to expose that to users of this library.
131+
if (!input) {
132+
throw new ReplicateError("input is required");
133+
}
134+
131135
const predictionData = await this.client.request("POST /v1/predictions", {
132136
version: this.version,
133137
input,

lib/Model.test.js

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ describe("predict()", () => {
9494
);
9595

9696
await model.predict(
97-
{ text: "test text" },
97+
{ input: { text: "test text" } },
9898
{},
9999
{ defaultPollingInterval: 0 }
100100
);
@@ -136,7 +136,7 @@ describe("predict()", () => {
136136
.mockImplementation((action) => requestMockReturnValues[action]);
137137

138138
await model.predict(
139-
{ text: "test text" },
139+
{ input: { text: "test text" } },
140140
{},
141141
{ defaultPollingInterval: 0 }
142142
);
@@ -190,7 +190,7 @@ describe("predict()", () => {
190190
});
191191

192192
const prediction = await model.predict(
193-
{ text: "test text" },
193+
{ input: { text: "test text" } },
194194
{},
195195
{ defaultPollingInterval: 0 }
196196
);
@@ -251,7 +251,7 @@ describe("predict()", () => {
251251
const backoffFn = jest.fn(() => 0);
252252

253253
const prediction = await model.predict(
254-
{ text: "test text" },
254+
{ input: { text: "test text" } },
255255
{},
256256
{ defaultPollingInterval: 0, backoffFn }
257257
);
@@ -269,7 +269,7 @@ describe("createPrediction()", () => {
269269
status: PredictionStatus.SUCCEEDED,
270270
});
271271

272-
await model.createPrediction({ text: "test text" });
272+
await model.createPrediction({ input: { text: "test text" } });
273273

274274
expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", {
275275
version: "testversion",
@@ -284,7 +284,7 @@ describe("createPrediction()", () => {
284284
});
285285

286286
await model.createPrediction(
287-
{ text: "test text" },
287+
{ input: { text: "test text" } },
288288
{ webhook: "http://test.host/webhook" }
289289
);
290290

@@ -302,7 +302,7 @@ describe("createPrediction()", () => {
302302
});
303303

304304
await model.createPrediction(
305-
{ text: "test text" },
305+
{ input: { text: "test text" } },
306306
{
307307
webhook: "http://test.host/webhook",
308308
webhookEventsFilter: ["output", "completed"],

0 commit comments

Comments
 (0)