Skip to content

Commit f5ddd9f

Browse files
committed
Implement streaming responses for run and predictions.create endpoints
1 parent 3cea681 commit f5ddd9f

File tree

4 files changed

+103
-23
lines changed

4 files changed

+103
-23
lines changed

index.d.ts

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,21 @@ declare module 'replicate' {
5858

5959
export type Training = Prediction;
6060

61+
/**
 * Options accepted when running a model.
 */
interface RunOptions {
  /** Required. An object with the model inputs. */
  input: any;
  /**
   * Polling behaviour while waiting for the prediction.
   *
   * Accepts a plain boolean as well as a settings object: `run` itself
   * forwards `wait: true` to `predictions.create`, so an object-only type
   * rejects a value the implementation uses.
   * `maxAttempts` is the maximum number of polling attempts.
   */
  wait?: boolean | {
    interval?: number;
    maxAttempts?: number;
  };
  /** An HTTPS URL for receiving a webhook when the prediction has new output. */
  webhook?: string;
  /** Webhook events that trigger webhook requests (`start`|`output`|`logs`|`completed`). */
  webhook_events_filter?: string[];
  /** Whether to stream the prediction output. Defaults to false. */
  stream?: boolean;
}
71+
72+
/**
 * Options for creating a prediction: everything in `RunOptions`, plus the
 * model version to run.
 */
interface CreatePredictionOptions extends RunOptions { version: string }
75+
6176
export default class Replicate {
6277
constructor(options: {
6378
auth: string;
@@ -71,15 +86,10 @@ declare module 'replicate' {
7186
baseUrl?: string;
7287
fetch: Function;
7388

74-
run(
75-
identifier: `${string}/${string}:${string}`,
76-
options: {
77-
input: object;
78-
wait?: boolean | { interval?: number; maxAttempts?: number };
79-
webhook?: string;
80-
webhook_events_filter?: WebhookEventType[];
81-
}
82-
): Promise<object>;
89+
90+
run(identifier: string, options: RunOptions & { stream: true }): Promise<AsyncIterable<any>>;
91+
run(identifier: string, options: RunOptions): Promise<object>;
92+
8393
request(route: string, parameters: any): Promise<any>;
8494
paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[ T ]>;
8595
wait(
@@ -108,12 +118,8 @@ declare module 'replicate' {
108118
};
109119

110120
predictions: {
111-
create(options: {
112-
version: string;
113-
input: object;
114-
webhook?: string;
115-
webhook_events_filter?: WebhookEventType[];
116-
}): Promise<Prediction>;
121+
create(options: CreatePredictionOptions & { stream: true }): Promise<AsyncIterable<any>>;
122+
create(options: CreatePredictionOptions): Promise<Prediction>;
117123
get(prediction_id: string): Promise<Prediction>;
118124
cancel(prediction_id: string): Promise<Prediction>;
119125
list(): Promise<Page<Prediction>>;

index.js

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const models = require('./lib/models');
33
const predictions = require('./lib/predictions');
44
const trainings = require('./lib/trainings');
55
const packageJSON = require('./package.json');
6+
const { handleSSE } = require('./lib/util');
67

78
/**
89
* Replicate API client library
@@ -74,6 +75,7 @@ class Replicate {
7475
/**
7576
* Run a model and wait for its output.
7677
*
78+
* @typedef AsyncIterable
7779
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
7880
* @param {object} options
7981
* @param {object} options.input - Required. An object with the model inputs
@@ -82,8 +84,9 @@ class Replicate {
8284
* @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit
8385
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
8486
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
87+
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
8588
* @throws {Error} If the prediction failed
86-
* @returns {Promise<object>} - Resolves with the output of running the model
89+
* @returns {Promise<object>|Promise<AsyncIterable<any>>} - Resolves with the output of running the model, or a stream of output objects if `options.stream` is `true`
8790
*/
8891
async run(identifier, options) {
8992
// Define a pattern for owner and model names that allows
@@ -105,6 +108,14 @@ class Replicate {
105108
}
106109

107110
const { version } = match.groups;
111+
112+
if (options.stream) {
113+
return this.predictions.create({
114+
...options,
115+
version,
116+
});
117+
}
118+
108119
const prediction = await this.predictions.create({
109120
wait: true,
110121
...options,
@@ -125,6 +136,7 @@ class Replicate {
125136
* @param {object} parameters - Request parameters
126137
* @param {string} [parameters.method] - HTTP method. Defaults to GET
127138
* @param {object} [parameters.params] - Query parameters
139+
* @param {object|Headers} [parameters.headers] - HTTP headers
128140
* @param {object} [parameters.data] - Body parameters
129141
* @returns {Promise<object>} - Resolves with the API response data
130142
* @throws {Error} If the request failed
@@ -143,11 +155,15 @@ class Replicate {
143155
url.searchParams.append(key, value);
144156
});
145157

146-
const headers = {
147-
Authorization: `Token ${auth}`,
148-
'Content-Type': 'application/json',
149-
'User-Agent': userAgent,
150-
};
158+
const headers = new Headers();
159+
headers.append('Authorization', `Token ${auth}`);
160+
headers.append('Content-Type', 'application/json');
161+
headers.append('User-Agent', userAgent);
162+
if (parameters.headers) {
163+
parameters.headers.forEach((value, key) => {
164+
headers.append(key, value);
165+
});
166+
}
151167

152168
const response = await this.fetch(url, {
153169
method,
@@ -159,6 +175,11 @@ class Replicate {
159175
throw new Error(`API request failed: ${response.statusText}`);
160176
}
161177

178+
const contentType = response.headers.get('content-type');
179+
if (contentType && contentType.startsWith('text/event-stream')) {
180+
return handleSSE(response);
181+
}
182+
162183
return response.json();
163184
}
164185

lib/predictions.js

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/**
22
* Create a new prediction
33
*
4+
* @typedef AsyncIterable
45
* @param {object} options
56
* @param {string} options.version - Required. The model version
67
* @param {object} options.input - Required. An object with the model inputs
@@ -9,10 +10,15 @@
910
* @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit
1011
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
1112
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
12-
* @returns {Promise<object>} Resolves with the created prediction data
13+
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
14+
* @returns {Promise<object>|Promise<AsyncIterable<any>>} Resolves with the created prediction, or a stream of prediction output if `options.stream` is `true`
1315
*/
1416
async function createPrediction(options) {
15-
const { wait, ...data } = options;
17+
const { wait, stream, ...data } = options;
18+
19+
if (stream && wait) {
20+
throw new Error('Incompatible options: stream and wait');
21+
}
1622

1723
if (data.webhook) {
1824
try {
@@ -23,9 +29,15 @@ async function createPrediction(options) {
2329
}
2430
}
2531

32+
const headers = new Headers();
33+
if (stream) {
34+
headers.append('Accept', 'text/event-stream');
35+
}
36+
2637
const prediction = this.request('/predictions', {
2738
method: 'POST',
2839
data,
40+
headers,
2941
});
3042

3143
if (wait) {

lib/util.js

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/**
2+
* An async iterable that yields Server-Sent Event messages from a response.
3+
* @param {Response} response - SSE response
4+
* @yields {object} - SSE message
5+
*/
6+
async function* handleSSE(response) {
7+
const reader = response.body.getReader();
8+
const decoder = new TextDecoder('utf-8');
9+
let data = '';
10+
11+
for (; ;) {
12+
const { done, chunk } = await reader.read(); // eslint-disable-line no-await-in-loop
13+
if (done) {
14+
break;
15+
}
16+
17+
data += decoder.decode(chunk, { stream: true });
18+
19+
let eventEndIndex;
20+
// eslint-disable-next-line no-cond-assign
21+
while ((eventEndIndex = data.indexOf('\n\n')) >= 0) {
22+
const event = data.slice(0, eventEndIndex);
23+
data = data.slice(eventEndIndex + 2);
24+
25+
const message = {};
26+
event.split('\n').forEach((line) => {
27+
const [field, ...rest] = line.split(':');
28+
const value = rest.join(':').trim();
29+
if (field) {
30+
message[field] = value;
31+
}
32+
});
33+
34+
yield message;
35+
}
36+
}
37+
}
38+
39+
module.exports = {
40+
handleSSE
41+
};

0 commit comments

Comments
 (0)