Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ const prediction = await replicate
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
)
.predict({
prompt: "an astronaut riding on a horse",
input: {
prompt: "an astronaut riding on a horse",
},
});

console.log(prediction.output);
Expand All @@ -45,7 +47,9 @@ await replicate
)
.predict(
{
prompt: "an astronaut riding on a horse",
input: {
prompt: "an astronaut riding on a horse",
},
},
{
onUpdate: (prediction) => {
Expand All @@ -66,7 +70,9 @@ const prediction = await replicate
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
)
.createPrediction({
prompt: "an astronaut riding on a horse",
input: {
prompt: "an astronaut riding on a horse",
},
});

console.log(prediction.status); // "starting"
Expand All @@ -89,7 +95,9 @@ await replicate
)
.createPrediction(
{
prompt: "an astronaut riding on a horse",
input: {
prompt: "an astronaut riding on a horse",
},
},
{
// See https://replicate.com/docs/reference/http#create-prediction--webhook
Expand Down
8 changes: 6 additions & 2 deletions lib/Model.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export default class Model extends ReplicateObject {
}

async predict(
input,
{ input },
{
onUpdate = noop,
onTemporaryError = noop,
Expand Down Expand Up @@ -122,12 +122,16 @@ export default class Model extends ReplicateObject {
return prediction;
}

async createPrediction(input, { webhook, webhookEventsFilter } = {}) {
async createPrediction({ input }, { webhook, webhookEventsFilter } = {}) {
// This is here and not on `Prediction` because conceptually, a prediction
// from a model "belongs" to the model. It's an odd feature of the API that
// the prediction creation isn't an action on the model (or that it doesn't
// actually use the model information, only the version), but we don't need
// to expose that to users of this library.
if (!input) {
throw new ReplicateError("input is required");
}

const predictionData = await this.client.request("POST /v1/predictions", {
version: this.version,
input,
Expand Down
14 changes: 7 additions & 7 deletions lib/Model.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ describe("predict()", () => {
);

await model.predict(
{ text: "test text" },
{ input: { text: "test text" } },
{},
{ defaultPollingInterval: 0 }
);
Expand Down Expand Up @@ -128,7 +128,7 @@ describe("predict()", () => {
.mockImplementation((action) => requestMockReturnValues[action]);

await model.predict(
{ text: "test text" },
{ input: { text: "test text" } },
{},
{ defaultPollingInterval: 0 }
);
Expand Down Expand Up @@ -182,7 +182,7 @@ describe("predict()", () => {
});

const prediction = await model.predict(
{ text: "test text" },
{ input: { text: "test text" } },
{},
{ defaultPollingInterval: 0 }
);
Expand Down Expand Up @@ -237,7 +237,7 @@ describe("predict()", () => {
const backoffFn = jest.fn(() => 0);

const prediction = await model.predict(
{ text: "test text" },
{ input: { text: "test text" } },
{},
{ defaultPollingInterval: 0, backoffFn }
);
Expand All @@ -255,7 +255,7 @@ describe("createPrediction()", () => {
status: PredictionStatus.SUCCEEDED,
});

await model.createPrediction({ text: "test text" });
await model.createPrediction({ input: { text: "test text" } });

expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", {
version: "testversion",
Expand All @@ -270,7 +270,7 @@ describe("createPrediction()", () => {
});

await model.createPrediction(
{ text: "test text" },
{ input: { text: "test text" } },
{ webhook: "http://test.host/webhook" }
);

Expand All @@ -288,7 +288,7 @@ describe("createPrediction()", () => {
});

await model.createPrediction(
{ text: "test text" },
{ input: { text: "test text" } },
{
webhook: "http://test.host/webhook",
webhookEventsFilter: ["output", "completed"],
Expand Down