Skip to content

Commit ab2d57c

Browse files
committed
Implement Stream interface with prediction property
1 parent 386a7e6 commit ab2d57c

File tree

9 files changed

+174
-75
lines changed

9 files changed

+174
-75
lines changed

index.d.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@ declare module 'replicate' {
88
results: T[];
99
}
1010

11+
interface Stream {
12+
response: Response;
13+
reader: ReadableStream;
14+
decoder: TextDecoder;
15+
output: string;
16+
prediction?: Promise<Prediction>;
17+
next(): Promise<{ value: string, done: boolean } | { done: true }>;
18+
[ Symbol.asyncIterator ](): this;
19+
}
20+
1121
export interface Collection {
1222
name: string;
1323
slug: string;
@@ -89,7 +99,7 @@ declare module 'replicate' {
8999
fetch: Function;
90100

91101

92-
run(identifier: RunIdentifier, options: RunOptions & { stream: true }): Promise<AsyncIterable<any>>;
102+
run(identifier: RunIdentifier, options: RunOptions & { stream: true }): Promise<Stream>;
93103
run(identifier: RunIdentifier, options: RunOptions): Promise<object>;
94104

95105
request(route: string, parameters: any): Promise<any>;
@@ -120,7 +130,7 @@ declare module 'replicate' {
120130
};
121131

122132
predictions: {
123-
create(options: CreatePredictionOptions & { stream: true }): Promise<AsyncIterable<any>>;
133+
create(options: CreatePredictionOptions & { stream: true }): Promise<Stream>;
124134
create(options: CreatePredictionOptions): Promise<Prediction>;
125135
get(prediction_id: string): Promise<Prediction>;
126136
cancel(prediction_id: string): Promise<Prediction>;

index.js

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ const models = require('./lib/models');
33
const predictions = require('./lib/predictions');
44
const trainings = require('./lib/trainings');
55
const packageJSON = require('./package.json');
6-
const { handleSSE } = require('./lib/util');
76

87
/**
98
* Replicate API client library
@@ -138,16 +137,27 @@ class Replicate {
138137
* @param {object} [parameters.params] - Query parameters
139138
* @param {object|Headers} [parameters.headers] - HTTP headers
140139
* @param {object} [parameters.data] - Body parameters
141-
* @returns {Promise<object>} - Resolves with the API response data
140+
* @returns {Promise<Response>} - Resolves with the response object
142141
* @throws {Error} If the request failed
143142
*/
144143
async request(route, parameters) {
145144
const { auth, baseUrl, userAgent } = this;
146145

147-
const url = new URL(
148-
route.startsWith('/') ? route.slice(1) : route,
149-
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
150-
);
146+
let url;
147+
if (route instanceof URL) {
148+
if (route.origin !== baseUrl) {
149+
throw new Error(
150+
`Invalid URL: ${route.origin} does not match ${baseUrl}`
151+
);
152+
}
153+
154+
url = route;
155+
} else {
156+
url = new URL(
157+
route.startsWith('/') ? route.slice(1) : route,
158+
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
159+
);
160+
}
151161

152162
const { method = 'GET', params = {}, data } = parameters;
153163

@@ -175,12 +185,7 @@ class Replicate {
175185
throw new Error(`API request failed: ${response.statusText}`);
176186
}
177187

178-
const contentType = response.headers.get('content-type');
179-
if (contentType && contentType.startsWith('text/event-stream')) {
180-
return handleSSE(response);
181-
}
182-
183-
return response.json();
188+
return response;
184189
}
185190

186191
/**
@@ -198,7 +203,7 @@ class Replicate {
198203
const response = await endpoint();
199204
yield response.results;
200205
if (response.next) {
201-
const nextPage = () => this.request(response.next, { method: 'GET' });
206+
const nextPage = () => this.request(response.next, { method: 'GET' }).then((r) => r.json());
202207
yield* this.paginate(nextPage);
203208
}
204209
}

index.test.ts

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,38 @@ describe('Replicate client', () => {
518518
[
519519
rest.post('https://api.replicate.com/v1/predictions/', (req, res, ctx) => {
520520
expect(req.headers.get('Accept')).toBe('text/event-stream');
521-
return res(ctx.status(201), ctx.text('data: Once upon a time\n\n'));
521+
return res(
522+
ctx.status(201),
523+
ctx.set('Content-Type', 'text/event-stream'),
524+
ctx.set('Link', '<https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq>; rel="self"'),
525+
ctx.body('data: Once upon a time\n\n'),
526+
);
527+
}),
528+
rest.get('https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq', (req, res, ctx) => {
529+
return res(
530+
ctx.status(200),
531+
ctx.json({
532+
id: 'ufawqhfynnddngldkgtslldrkq',
533+
version:
534+
'5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
535+
urls: {
536+
get: 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq',
537+
cancel:
538+
'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel',
539+
},
540+
created_at: '2022-04-26T22:13:06.224088Z',
541+
started_at: null,
542+
completed_at: null,
543+
status: 'starting',
544+
input: {
545+
text: 'Alice',
546+
},
547+
output: null,
548+
error: null,
549+
logs: null,
550+
metrics: {},
551+
})
552+
);
522553
}),
523554
],
524555
async () => {
@@ -530,6 +561,10 @@ describe('Replicate client', () => {
530561
}
531562
);
532563

564+
const prediction = await stream.prediction;
565+
expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq');
566+
expect(prediction.status).toBe('starting');
567+
533568
for await (const event of await stream) {
534569
expect(event).toBe('Once upon a time');
535570
}

lib/collections.js

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
* @returns {Promise<object>} - Resolves with the collection data
66
*/
77
async function getCollection(collection_slug) {
8-
return this.request(`/collections/${collection_slug}`, {
8+
const response = await this.request(`/collections/${collection_slug}`, {
99
method: 'GET',
1010
});
11+
12+
return response.json();
1113
}
1214

1315
/**
@@ -16,9 +18,11 @@ async function getCollection(collection_slug) {
1618
* @returns {Promise<object>} - Resolves with the collections data
1719
*/
1820
async function listCollections() {
19-
return this.request('/collections', {
21+
const response = await this.request('/collections', {
2022
method: 'GET',
2123
});
24+
25+
return response.json();
2226
}
2327

2428
module.exports = { get: getCollection, list: listCollections };

lib/models.js

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
* @returns {Promise<object>} Resolves with the model data
77
*/
88
async function getModel(model_owner, model_name) {
9-
return this.request(`/models/${model_owner}/${model_name}`, {
9+
const response = await this.request(`/models/${model_owner}/${model_name}`, {
1010
method: 'GET',
1111
});
12+
13+
return response.json();
1214
}
1315

1416
/**
@@ -19,9 +21,11 @@ async function getModel(model_owner, model_name) {
1921
* @returns {Promise<object>} Resolves with the list of model versions
2022
*/
2123
async function listModelVersions(model_owner, model_name) {
22-
return this.request(`/models/${model_owner}/${model_name}/versions`, {
24+
const response = await this.request(`/models/${model_owner}/${model_name}/versions`, {
2325
method: 'GET',
2426
});
27+
28+
return response.json();
2529
}
2630

2731
/**
@@ -33,12 +37,14 @@ async function listModelVersions(model_owner, model_name) {
3337
* @returns {Promise<object>} Resolves with the model version data
3438
*/
3539
async function getModelVersion(model_owner, model_name, version_id) {
36-
return this.request(
40+
const response = await this.request(
3741
`/models/${model_owner}/${model_name}/versions/${version_id}`,
3842
{
3943
method: 'GET',
4044
}
4145
);
46+
47+
return response.json();
4248
}
4349

4450
module.exports = {

lib/predictions.js

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
const { Stream } = require('./stream');
2+
13
/**
24
* Create a new prediction
35
*
4-
* @typedef AsyncIterable
56
* @param {object} options
67
* @param {string} options.version - Required. The model version
78
* @param {object} options.input - Required. An object with the model inputs
@@ -11,7 +12,7 @@
1112
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
1213
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
1314
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
14-
* @returns {Promise<object>|AsyncIterable<any>} Resolves with the created prediction, or a stream of prediction output if `options.stream` is `true`
15+
* @returns {Promise<object>|Stream} Resolves with the created prediction, or a stream of prediction output if `options.stream` is `true`
1516
*/
1617
async function createPrediction(options) {
1718
const { wait, stream, ...data } = options;
@@ -34,23 +35,49 @@ async function createPrediction(options) {
3435
headers.append('Accept', 'text/event-stream');
3536
}
3637

37-
const prediction = this.request('/predictions', {
38+
const response = await this.request('/predictions', {
3839
method: 'POST',
3940
data,
4041
headers,
4142
});
4243

44+
if (stream) {
45+
const contentType = response.headers.get('content-type');
46+
if (contentType && contentType.startsWith('text/event-stream')) {
47+
const sse = Stream(response);
48+
49+
// Get the prediction ID from the Link header
50+
const link = response.headers.get('Link');
51+
const match = link.match(/<(.*)>; rel="self"/);
52+
if (match && match[1]) {
53+
const ref = response.headers.get('Link').match(/<(.*)>; rel="self"/)[1];
54+
try {
55+
const url = new URL(ref);
56+
const prediction_id = url.pathname.split('/').pop();
57+
58+
if (prediction_id) {
59+
Object.defineProperty(sse, 'prediction', async () => (this.getPrediction(prediction_id)));
60+
}
61+
} catch (err) {
62+
throw new Error('Unable to get prediction ID from response');
63+
}
64+
}
65+
66+
return sse;
67+
}
68+
}
69+
4370
if (wait) {
4471
const { maxAttempts, interval } = wait;
4572

4673
// eslint-disable-next-line no-promise-executor-return
4774
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
4875
await sleep(interval || 250);
4976

50-
return this.wait(await prediction, { maxAttempts, interval });
77+
return this.wait(await response.json(), { maxAttempts, interval });
5178
}
5279

53-
return prediction;
80+
return response.json();
5481
}
5582

5683
/**
@@ -60,9 +87,11 @@ async function createPrediction(options) {
6087
* @returns {Promise<object>} Resolves with the prediction data
6188
*/
6289
async function getPrediction(prediction_id) {
63-
return this.request(`/predictions/${prediction_id}`, {
90+
const response = await this.request(`/predictions/${prediction_id}`, {
6491
method: 'GET',
6592
});
93+
94+
return response.json();
6695
}
6796

6897
/**
@@ -72,9 +101,11 @@ async function getPrediction(prediction_id) {
72101
* @returns {Promise<object>} Resolves with the data for the training
73102
*/
74103
async function cancelPrediction(prediction_id) {
75-
return this.request(`/predictions/${prediction_id}/cancel`, {
104+
const response = await this.request(`/predictions/${prediction_id}/cancel`, {
76105
method: 'POST',
77106
});
107+
108+
return response.json();
78109
}
79110

80111
/**
@@ -83,9 +114,11 @@ async function cancelPrediction(prediction_id) {
83114
* @returns {Promise<object>} - Resolves with a page of predictions
84115
*/
85116
async function listPredictions() {
86-
return this.request('/predictions', {
117+
const response = await this.request('/predictions', {
87118
method: 'GET',
88119
});
120+
121+
return response.json();
89122
}
90123

91124
module.exports = {

lib/stream.js

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/**
2+
* An async iterable that yields Server-Sent Event messages from a response.
3+
*/
4+
class Stream {
5+
/**
6+
* Create a new Server-Sent Event stream.
7+
*
8+
* @param {Response} response - The server response
9+
*/
10+
constructor(response) {
11+
this.response = response;
12+
this.reader = response.body.getReader();
13+
this.decoder = new TextDecoder('utf-8');
14+
this.output = '';
15+
}
16+
17+
async next() {
18+
let eventEndIndex;
19+
while ((eventEndIndex = this.output.indexOf('\n\n')) < 0) { // eslint-disable-line no-cond-assign
20+
const { done, chunk } = await this.reader.read(); // eslint-disable-line no-await-in-loop
21+
if (done) {
22+
return { done: true };
23+
}
24+
this.output += this.decoder.decode(chunk, { stream: true });
25+
}
26+
27+
const event = this.output.slice(0, eventEndIndex);
28+
this.output = this.output.slice(eventEndIndex + 2);
29+
30+
const message = this.parseEvent(event);
31+
return { value: message, done: false };
32+
}
33+
34+
[Symbol.asyncIterator]() {
35+
return this;
36+
}
37+
}
38+
39+
module.exports = {
40+
Stream
41+
};

0 commit comments

Comments
 (0)