Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Install with `npm install replicate`

Set your API token as an environment variable called `REPLICATE_API_TOKEN`.

### Making preedictions

To run a prediction and return its output:

```js
Expand All @@ -24,11 +26,11 @@ const prediction = await replicate
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
)
.predict({
prompt: "painting of a cat by andy warhol",
prompt: "an astronaut riding on a horse",
});

console.log(prediction.output);
// "https://replicate.delivery/pbxt/oeJLu7D1Y7UWESpzerfINqgwZgONSCubSjSw0msf8i4AP2BCB/out-0.png"
// "https://replicate.delivery/pbxt/nSREat5H54rxGJo1kk2xLLG2fpr0NBE0HBD5L0jszLoy8oSIA/out-0.png"
```

If you want to do something like updating progress while the prediction is
Expand All @@ -43,7 +45,7 @@ await replicate
)
.predict(
{
prompt: "painting of a cat by andy warhol",
prompt: "an astronaut riding on a horse",
},
{
onUpdate: (prediction) => {
Expand All @@ -64,7 +66,7 @@ const prediction = await replicate
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
)
.createPrediction({
prompt: "painting of a cat by andy warhol",
prompt: "an astronaut riding on a horse",
});

console.log(prediction.status); // "starting"
Expand All @@ -73,6 +75,32 @@ console.log(prediction.status); // "starting"
From there, you can fetch the current status of the prediction using
`await prediction.load()` or `await replicate.prediction(prediction.id).load()`.

#### Webhooks

You can also provide webhook configuration to have Replicate send POST requests
to your service when certain events occur:

```js
import replicate from "replicate";

await replicate
.model(
"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
)
.createPrediction(
{
prompt: "an astronaut riding on a horse",
},
{
// See https://replicate.com/docs/reference/http#create-prediction--webhook
webhook: "https://your.host/webhook",

// See https://replicate.com/docs/reference/http#create-prediction--webhook_events_filter
webhookEventsFilter: ["output", "completed"],
}
);
```

## Contributing

While we'd love to accept contributions to this library, please open an issue
Expand Down
4 changes: 3 additions & 1 deletion lib/Model.js
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ export default class Model extends ReplicateObject {
return prediction;
}

async createPrediction(input) {
async createPrediction(input, { webhook, webhookEventsFilter } = {}) {
// This is here and not on `Prediction` because conceptually, a prediction
// from a model "belongs" to the model. It's an odd feature of the API that
// the prediction creation isn't an action on the model (or that it doesn't
Expand All @@ -131,6 +131,8 @@ export default class Model extends ReplicateObject {
const predictionData = await this.client.request("POST /v1/predictions", {
version: this.version,
input,
webhook,
webhook_events_filter: webhookEventsFilter,
});

return new Prediction(predictionData, this);
Expand Down
40 changes: 40 additions & 0 deletions lib/Model.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -276,4 +276,44 @@ describe("createPrediction()", () => {
input: { text: "test text" },
});
});

it("supports webhook URL", async () => {
jest.spyOn(client, "request").mockResolvedValue({
id: "testprediction",
status: PredictionStatus.SUCCEEDED,
});

await model.createPrediction(
{ text: "test text" },
{ webhook: "http://test.host/webhook" }
);

expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", {
version: "testversion",
input: { text: "test text" },
webhook: "http://test.host/webhook",
});
});

it("supports webhook events filter", async () => {
jest.spyOn(client, "request").mockResolvedValue({
id: "testprediction",
status: PredictionStatus.SUCCEEDED,
});

await model.createPrediction(
{ text: "test text" },
{
webhook: "http://test.host/webhook",
webhookEventsFilter: ["output", "completed"],
}
);

expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", {
version: "testversion",
input: { text: "test text" },
webhook: "http://test.host/webhook",
webhook_events_filter: ["output", "completed"],
});
});
});