Skip to content

Commit 9ccdf48

Browse files
committed
Update request method to return response instead of JSON
1 parent 398cca5 commit 9ccdf48

File tree

6 files changed

+77
-38
lines changed

6 files changed

+77
-38
lines changed

index.d.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,14 @@ declare module 'replicate' {
8585
webhook_events_filter?: WebhookEventType[];
8686
}
8787
): Promise<object>;
88-
request(route: string, parameters: any): Promise<any>;
88+
89+
request(route: string | URL, options: {
90+
method?: string;
91+
headers?: object | Headers;
92+
params?: object;
93+
data?: object;
94+
}): Promise<Response>;
95+
8996
paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[ T ]>;
9097
wait(
9198
prediction: Prediction,

index.js

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -125,43 +125,53 @@ class Replicate {
125125
* Make a request to the Replicate API.
126126
*
127127
* @param {string} route - REST API endpoint path
128-
* @param {object} parameters - Request parameters
129-
* @param {string} [parameters.method] - HTTP method. Defaults to GET
130-
* @param {object} [parameters.params] - Query parameters
131-
* @param {object} [parameters.data] - Body parameters
132-
* @returns {Promise<object>} - Resolves with the API response data
128+
* @param {object} options - Request parameters
129+
* @param {string} [options.method] - HTTP method. Defaults to GET
130+
* @param {object} [options.params] - Query parameters
131+
* @param {object|Headers} [options.headers] - HTTP headers
132+
* @param {object} [options.data] - Body parameters
133+
* @returns {Promise<Response>} - Resolves with the response object
133134
* @throws {ApiError} If the request failed
134135
*/
135-
async request(route, parameters) {
136+
async request(route, options) {
136137
const { auth, baseUrl, userAgent } = this;
137138

138-
const url = new URL(
139-
route.startsWith('/') ? route.slice(1) : route,
140-
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
141-
);
139+
let url;
140+
if (route instanceof URL) {
141+
url = route;
142+
} else {
143+
url = new URL(
144+
route.startsWith('/') ? route.slice(1) : route,
145+
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
146+
);
147+
}
142148

143-
const { method = 'GET', params = {}, data } = parameters;
149+
const { method = 'GET', params = {}, data } = options;
144150

145151
Object.entries(params).forEach(([key, value]) => {
146152
url.searchParams.append(key, value);
147153
});
148154

149-
const headers = {
150-
Authorization: `Token ${auth}`,
151-
'Content-Type': 'application/json',
152-
'User-Agent': userAgent,
153-
};
155+
const headers = new Headers();
156+
headers.append('Authorization', `Token ${auth}`);
157+
headers.append('Content-Type', 'application/json');
158+
headers.append('User-Agent', userAgent);
159+
if (options.headers) {
160+
options.headers.forEach((value, key) => {
161+
headers.append(key, value);
162+
});
163+
}
154164

155-
const options = {
165+
const init = {
156166
method,
157167
headers,
158168
body: data ? JSON.stringify(data) : undefined,
159169
};
160170

161-
const response = await this.fetch(url, options);
171+
const response = await this.fetch(url, init);
162172

163173
if (!response.ok) {
164-
const request = new Request(url, options);
174+
const request = new Request(url, init);
165175
const responseText = await response.text();
166176
throw new ApiError(
167177
`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`,
@@ -170,7 +180,7 @@ class Replicate {
170180
);
171181
}
172182

173-
return response.json();
183+
return response;
174184
}
175185

176186
/**
@@ -188,7 +198,7 @@ class Replicate {
188198
const response = await endpoint();
189199
yield response.results;
190200
if (response.next) {
191-
const nextPage = () => this.request(response.next, { method: 'GET' });
201+
const nextPage = () => this.request(response.next, { method: 'GET' }).then((r) => r.json());
192202
yield* this.paginate(nextPage);
193203
}
194204
}

lib/collections.js

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
* @returns {Promise<object>} - Resolves with the collection data
66
*/
77
async function getCollection(collection_slug) {
8-
return this.request(`/collections/${collection_slug}`, {
8+
const response = await this.request(`/collections/${collection_slug}`, {
99
method: 'GET',
1010
});
11+
12+
return response.json();
1113
}
1214

1315
/**
@@ -16,9 +18,11 @@ async function getCollection(collection_slug) {
1618
* @returns {Promise<object>} - Resolves with the collections data
1719
*/
1820
async function listCollections() {
19-
return this.request('/collections', {
21+
const response = await this.request('/collections', {
2022
method: 'GET',
2123
});
24+
25+
return response.json();
2226
}
2327

2428
module.exports = { get: getCollection, list: listCollections };

lib/models.js

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
* @returns {Promise<object>} Resolves with the model data
77
*/
88
async function getModel(model_owner, model_name) {
9-
return this.request(`/models/${model_owner}/${model_name}`, {
9+
const response = await this.request(`/models/${model_owner}/${model_name}`, {
1010
method: 'GET',
1111
});
12+
13+
return response.json();
1214
}
1315

1416
/**
@@ -19,9 +21,11 @@ async function getModel(model_owner, model_name) {
1921
* @returns {Promise<object>} Resolves with the list of model versions
2022
*/
2123
async function listModelVersions(model_owner, model_name) {
22-
return this.request(`/models/${model_owner}/${model_name}/versions`, {
24+
const response = await this.request(`/models/${model_owner}/${model_name}/versions`, {
2325
method: 'GET',
2426
});
27+
28+
return response.json();
2529
}
2630

2731
/**
@@ -33,12 +37,14 @@ async function listModelVersions(model_owner, model_name) {
3337
* @returns {Promise<object>} Resolves with the model version data
3438
*/
3539
async function getModelVersion(model_owner, model_name, version_id) {
36-
return this.request(
40+
const response = await this.request(
3741
`/models/${model_owner}/${model_name}/versions/${version_id}`,
3842
{
3943
method: 'GET',
4044
}
4145
);
46+
47+
return response.json();
4248
}
4349

4450
module.exports = {

lib/predictions.js

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async function createPrediction(options) {
2323
}
2424
}
2525

26-
const prediction = this.request('/predictions', {
26+
const response = await this.request('/predictions', {
2727
method: 'POST',
2828
data,
2929
});
@@ -35,10 +35,10 @@ async function createPrediction(options) {
3535
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
3636
await sleep(interval || 250);
3737

38-
return this.wait(await prediction, { maxAttempts, interval });
38+
return this.wait(await response, { maxAttempts, interval });
3939
}
4040

41-
return prediction;
41+
return response.json();
4242
}
4343

4444
/**
@@ -48,9 +48,11 @@ async function createPrediction(options) {
4848
* @returns {Promise<object>} Resolves with the prediction data
4949
*/
5050
async function getPrediction(prediction_id) {
51-
return this.request(`/predictions/${prediction_id}`, {
51+
const response = await this.request(`/predictions/${prediction_id}`, {
5252
method: 'GET',
5353
});
54+
55+
return response.json();
5456
}
5557

5658
/**
@@ -60,9 +62,11 @@ async function getPrediction(prediction_id) {
6062
* @returns {Promise<object>} Resolves with the data for the training
6163
*/
6264
async function cancelPrediction(prediction_id) {
63-
return this.request(`/predictions/${prediction_id}/cancel`, {
65+
const response = await this.request(`/predictions/${prediction_id}/cancel`, {
6466
method: 'POST',
6567
});
68+
69+
return response.json();
6670
}
6771

6872
/**
@@ -71,9 +75,11 @@ async function cancelPrediction(prediction_id) {
7175
* @returns {Promise<object>} - Resolves with a page of predictions
7276
*/
7377
async function listPredictions() {
74-
return this.request('/predictions', {
78+
const response = await this.request('/predictions', {
7579
method: 'GET',
7680
});
81+
82+
return response.json();
7783
}
7884

7985
module.exports = {

lib/trainings.js

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ async function createTraining(model_owner, model_name, version_id, options) {
2323
}
2424
}
2525

26-
const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, {
26+
const response = await this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, {
2727
method: 'POST',
2828
data,
2929
});
3030

31-
return training;
31+
return response.json();
3232
}
3333

3434
/**
@@ -38,9 +38,11 @@ async function createTraining(model_owner, model_name, version_id, options) {
3838
* @returns {Promise<object>} Resolves with the data for the training
3939
*/
4040
async function getTraining(training_id) {
41-
return this.request(`/trainings/${training_id}`, {
41+
const response = await this.request(`/trainings/${training_id}`, {
4242
method: 'GET',
4343
});
44+
45+
return response.json();
4446
}
4547

4648
/**
@@ -50,9 +52,11 @@ async function getTraining(training_id) {
5052
* @returns {Promise<object>} Resolves with the data for the training
5153
*/
5254
async function cancelTraining(training_id) {
53-
return this.request(`/trainings/${training_id}/cancel`, {
55+
const response = await this.request(`/trainings/${training_id}/cancel`, {
5456
method: 'POST',
5557
});
58+
59+
return response.json();
5660
}
5761

5862
/**
@@ -61,9 +65,11 @@ async function cancelTraining(training_id) {
6165
* @returns {Promise<object>} - Resolves with a page of trainings
6266
*/
6367
async function listTrainings() {
64-
return this.request('/trainings', {
68+
const response = await this.request('/trainings', {
6569
method: 'GET',
6670
});
71+
72+
return response.json();
6773
}
6874

6975
module.exports = {

0 commit comments

Comments
 (0)