Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,14 @@ const response = await replicate.predictions.create(options);
| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------- |
| `options.version` | string | **Required**. The model version |
| `options.input` | object | **Required**. An object with the model's inputs |
| `options.stream` | boolean | Requests a URL for streaming output |
| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output |
| `options.webhook_events_filter` | string[] | You can change which events trigger webhook requests by specifying webhook events (`start` \| `output` \| `logs` \| `completed`) |

```jsonc
{
"id": "ufawqhfynnddngldkgtslldrkq",
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"urls": {
"get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
"cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
},
"status": "succeeded",
"input": {
"text": "Alice"
Expand All @@ -272,10 +269,56 @@ const response = await replicate.predictions.create(options);
"metrics": {},
"created_at": "2022-04-26T22:13:06.224088Z",
"started_at": null,
"completed_at": null
"completed_at": null,
"urls": {
"get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
"cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
"stream": "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" // Present only if `options.stream` is `true`
}
}
```

#### Streaming

Specify the `stream` option when creating a prediction
to request a URL to receive streaming output using
[server-sent events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events).

If the requested model version supports streaming,
then the returned prediction will have a `stream` entry in its `urls` property
with a URL that you can use to construct an
[`EventSource`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource).

```js
if (prediction && prediction.urls && prediction.urls.stream) {
const source = new EventSource(prediction.urls.stream, { withCredentials: true });

source.addEventListener("output", (e) => {
console.log("output", e.data);
});

source.addEventListener("error", (e) => {
console.error("error", JSON.parse(e.data));
});

source.addEventListener("done", (e) => {
source.close();
console.log("done", JSON.parse(e.data));
});
}
```

A prediction's event stream consists of the following event types:

| event | format | description |
| -------- | ---------- | ---------------------------------------------- |
| `output` | plain text | Emitted when the prediction returns new output |
| `error` | JSON | Emitted when the prediction returns an error |
| `done` | JSON | Emitted when the prediction finishes |

A `done` event is emitted when a prediction finishes successfully,
is cancelled, or produces an error.

### `replicate.predictions.get`

```js
Expand Down
20 changes: 17 additions & 3 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ declare module 'replicate' {
created_at: string;
updated_at: string;
completed_at?: string;
urls: {
get: string;
cancel: string;
stream?: string;
};
}

export type Training = Prediction;
Expand Down Expand Up @@ -80,18 +85,26 @@ declare module 'replicate' {
identifier: `${string}/${string}:${string}`,
options: {
input: object;
wait?: boolean | { interval?: number; maxAttempts?: number };
wait?: { interval?: number; max_attempts?: number };
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}
): Promise<object>;
request(route: string, parameters: any): Promise<any>;

request(route: string | URL, options: {
method?: string;
headers?: object | Headers;
params?: object;
data?: object;
}): Promise<Response>;

paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[ T ]>;

wait(
prediction: Prediction,
options: {
interval?: number;
maxAttempts?: number;
max_attempts?: number;
}
): Promise<Prediction>;

Expand All @@ -116,6 +129,7 @@ declare module 'replicate' {
create(options: {
version: string;
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}): Promise<Prediction>;
Expand Down
74 changes: 44 additions & 30 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,17 @@ class Replicate {
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false
* @param {object} [options.wait] - Options for waiting for the prediction to finish
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250
* @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit
* @param {number} [options.wait.max_attempts] - Maximum number of polling attempts. Defaults to no limit
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @throws {Error} If the prediction failed
* @returns {Promise<object>} - Resolves with the output of running the model
*/
async run(identifier, options) {
const { wait, ...data } = options;

// Define a pattern for owner and model names that allows
// letters, digits, and certain special characters.
// Example: "user123", "abc__123", "user.name"
Expand All @@ -108,12 +110,14 @@ class Replicate {
}

const { version } = match.groups;
const prediction = await this.predictions.create({
wait: true,
...options,

let prediction = await this.predictions.create({
...data,
version,
});

prediction = await this.wait(prediction, wait || {});

if (prediction.status === 'failed') {
throw new Error(`Prediction failed: ${prediction.error}`);
}
Expand All @@ -125,43 +129,53 @@ class Replicate {
* Make a request to the Replicate API.
*
* @param {string} route - REST API endpoint path
* @param {object} parameters - Request parameters
* @param {string} [parameters.method] - HTTP method. Defaults to GET
* @param {object} [parameters.params] - Query parameters
* @param {object} [parameters.data] - Body parameters
* @returns {Promise<object>} - Resolves with the API response data
* @param {object} options - Request parameters
* @param {string} [options.method] - HTTP method. Defaults to GET
* @param {object} [options.params] - Query parameters
* @param {object|Headers} [options.headers] - HTTP headers
* @param {object} [options.data] - Body parameters
* @returns {Promise<Response>} - Resolves with the response object
* @throws {ApiError} If the request failed
*/
async request(route, parameters) {
async request(route, options) {
const { auth, baseUrl, userAgent } = this;

const url = new URL(
route.startsWith('/') ? route.slice(1) : route,
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
);
let url;
if (route instanceof URL) {
url = route;
} else {
url = new URL(
route.startsWith('/') ? route.slice(1) : route,
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
);
}

const { method = 'GET', params = {}, data } = parameters;
const { method = 'GET', params = {}, data } = options;

Object.entries(params).forEach(([key, value]) => {
url.searchParams.append(key, value);
});

const headers = {
Authorization: `Token ${auth}`,
'Content-Type': 'application/json',
'User-Agent': userAgent,
};
const headers = new Headers();
headers.append('Authorization', `Token ${auth}`);
headers.append('Content-Type', 'application/json');
headers.append('User-Agent', userAgent);
if (options.headers) {
options.headers.forEach((value, key) => {
headers.append(key, value);
});
}

const options = {
const init = {
method,
headers,
body: data ? JSON.stringify(data) : undefined,
};

const response = await this.fetch(url, options);
const response = await this.fetch(url, init);

if (!response.ok) {
const request = new Request(url, options);
const request = new Request(url, init);
const responseText = await response.text();
throw new ApiError(
`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`,
Expand All @@ -170,7 +184,7 @@ class Replicate {
);
}

return response.json();
return response;
}

/**
Expand All @@ -188,7 +202,7 @@ class Replicate {
const response = await endpoint();
yield response.results;
if (response.next) {
const nextPage = () => this.request(response.next, { method: 'GET' });
const nextPage = () => this.request(response.next, { method: 'GET' }).then((r) => r.json());
yield* this.paginate(nextPage);
}
}
Expand All @@ -204,7 +218,7 @@ class Replicate {
* @param {object} prediction - Prediction object
* @param {object} options - Options
* @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250
* @param {number} [options.maxAttempts] - Maximum number of polling attempts. Defaults to no limit
* @param {number} [options.max_attempts] - Maximum number of polling attempts. Defaults to no limit
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
* @throws {Error} If the prediction failed
* @returns {Promise<object>} Resolves with the completed prediction object
Expand All @@ -230,17 +244,17 @@ class Replicate {

let attempts = 0;
const interval = options.interval || 250;
const maxAttempts = options.maxAttempts || null;
const max_attempts = options.max_attempts || null;

while (
updatedPrediction.status !== 'succeeded' &&
updatedPrediction.status !== 'failed' &&
updatedPrediction.status !== 'canceled'
) {
attempts += 1;
if (maxAttempts && attempts > maxAttempts) {
if (max_attempts && attempts > max_attempts) {
throw new Error(
`Prediction ${id} did not finish after ${maxAttempts} attempts`
`Prediction ${id} did not finish after ${max_attempts} attempts`
);
}

Expand Down
20 changes: 19 additions & 1 deletion index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,24 @@ describe('Replicate client', () => {
expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq');
});

test('Passes stream parameter to API endpoint', async () => {
nock(BASE_URL)
.post('/predictions')
.reply(201, (_uri, body) => {
expect(body[ 'stream' ]).toBe(true);
return body
})

await client.predictions.create({
version:
'5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
input: {
prompt: 'Tell me a story',
},
stream: true
});
});

test('Throws an error if webhook URL is invalid', async () => {
await expect(async () => {
await client.predictions.create({
Expand Down Expand Up @@ -506,7 +524,7 @@ describe('Replicate client', () => {
status: 'processing',
})
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
.reply(200, {
.reply(201, {
id: 'ufawqhfynnddngldkgtslldrkq',
status: 'succeeded',
output: 'foobar',
Expand Down
8 changes: 6 additions & 2 deletions lib/collections.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
* @returns {Promise<object>} - Resolves with the collection data
*/
async function getCollection(collection_slug) {
return this.request(`/collections/${collection_slug}`, {
const response = await this.request(`/collections/${collection_slug}`, {
method: 'GET',
});

return response.json();
}

/**
Expand All @@ -16,9 +18,11 @@ async function getCollection(collection_slug) {
* @returns {Promise<object>} - Resolves with the collections data
*/
async function listCollections() {
return this.request('/collections', {
const response = await this.request('/collections', {
method: 'GET',
});

return response.json();
}

module.exports = { get: getCollection, list: listCollections };
3 changes: 3 additions & 0 deletions lib/error.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
/**
* A representation of an API error.
*/
class ApiError extends Error {
/**
* Creates a representation of an API error.
Expand Down
12 changes: 9 additions & 3 deletions lib/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
* @returns {Promise<object>} Resolves with the model data
*/
async function getModel(model_owner, model_name) {
return this.request(`/models/${model_owner}/${model_name}`, {
const response = await this.request(`/models/${model_owner}/${model_name}`, {
method: 'GET',
});

return response.json();
}

/**
Expand All @@ -19,9 +21,11 @@ async function getModel(model_owner, model_name) {
* @returns {Promise<object>} Resolves with the list of model versions
*/
async function listModelVersions(model_owner, model_name) {
return this.request(`/models/${model_owner}/${model_name}/versions`, {
const response = await this.request(`/models/${model_owner}/${model_name}/versions`, {
method: 'GET',
});

return response.json();
}

/**
Expand All @@ -33,12 +37,14 @@ async function listModelVersions(model_owner, model_name) {
* @returns {Promise<object>} Resolves with the model version data
*/
async function getModelVersion(model_owner, model_name, version_id) {
return this.request(
const response = await this.request(
`/models/${model_owner}/${model_name}/versions/${version_id}`,
{
method: 'GET',
}
);

return response.json();
}

module.exports = {
Expand Down
Loading