Skip to content

Commit a5374ad

Browse files
committed
Merge branch 'main' into stream
2 parents 6359a74 + 6343c52 commit a5374ad

File tree

5 files changed

+81
-9
lines changed

5 files changed

+81
-9
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,14 @@ const response = await replicate.predictions.list();
403403

404404
### `replicate.trainings.create`
405405

406+
Use the training API to fine-tune language models
407+
to make them better at a particular task.
408+
To see what **language models** currently support fine-tuning,
409+
check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models).
410+
411+
If you're looking to fine-tune **image models**,
412+
check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model).
413+
406414
```js
407415
const response = await replicate.trainings.create(model_owner, model_name, version_id, options);
408416
```
@@ -434,6 +442,10 @@ const response = await replicate.trainings.create(model_owner, model_name, versi
434442
}
435443
```
436444

445+
> **Warning**
446+
> If you try to fine-tune a model that doesn't support training,
447+
> you'll get a `400 Bad Request` response from the server.
448+
437449
### `replicate.trainings.get`
438450

439451
```js

index.d.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ declare module 'replicate' {
22
type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled';
33
type WebhookEventType = 'start' | 'output' | 'logs' | 'completed';
44

5-
interface Page<T> {
6-
previous?: string;
7-
next?: string;
8-
results: T[];
5+
export interface ApiError extends Error {
6+
request: Request;
7+
response: Response;
98
}
109

1110
export interface Collection {
@@ -63,6 +62,12 @@ declare module 'replicate' {
6362

6463
export type Training = Prediction;
6564

65+
interface Page<T> {
66+
previous?: string;
67+
next?: string;
68+
results: T[];
69+
}
70+
6671
export default class Replicate {
6772
constructor(options: {
6873
auth: string;

index.js

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
const ApiError = require('./lib/error');
2+
13
const collections = require('./lib/collections');
24
const models = require('./lib/models');
35
const predictions = require('./lib/predictions');
46
const trainings = require('./lib/trainings');
7+
58
const packageJSON = require('./package.json');
69

710
/**
@@ -132,7 +135,7 @@ class Replicate {
132135
* @param {object|Headers} [options.headers] - HTTP headers
133136
* @param {object} [options.data] - Body parameters
134137
* @returns {Promise<Response>} - Resolves with the response object
135-
* @throws {Error} If the request failed
138+
* @throws {ApiError} If the request failed
136139
*/
137140
async request(route, options) {
138141
const { auth, baseUrl, userAgent } = this;
@@ -163,14 +166,22 @@ class Replicate {
163166
});
164167
}
165168

166-
const response = await this.fetch(url, {
169+
const init = {
167170
method,
168171
headers,
169172
body: data ? JSON.stringify(data) : undefined,
170-
});
173+
};
174+
175+
const response = await this.fetch(url, init);
171176

172177
if (!response.ok) {
173-
throw new Error(`API request failed: ${response.statusText}`);
178+
const request = new Request(url, init);
179+
const responseText = await response.text();
180+
throw new ApiError(
181+
`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`,
182+
request,
183+
response,
184+
);
174185
}
175186

176187
return response;
@@ -187,7 +198,7 @@ class Replicate {
187198
* @param {Function} endpoint - Function that returns a promise for the next page of results
188199
* @yields {object[]} Each page of results
189200
*/
190-
async *paginate(endpoint) {
201+
async * paginate(endpoint) {
191202
const response = await endpoint();
192203
yield response.results;
193204
if (response.next) {

index.test.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,29 @@ describe('Replicate client', () => {
181181
});
182182
}).rejects.toThrow('Invalid webhook URL');
183183
});
184+
185+
test('Throws an error with details failing response is JSON', async () => {
186+
nock(BASE_URL)
187+
.post('/predictions')
188+
.reply(400, {
189+
status: 400,
190+
detail: "Invalid input",
191+
}, { "Content-Type": "application/json" })
192+
193+
try {
194+
expect.assertions(2);
195+
196+
await client.predictions.create({
197+
version: '5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
198+
input: {
199+
text: null,
200+
},
201+
});
202+
} catch (error) {
203+
expect(error.response.status).toBe(400);
204+
expect(error.message).toContain("Invalid input")
205+
}
206+
})
184207
// Add more tests for error handling, edge cases, etc.
185208
});
186209

lib/error.js

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
* A representation of an API error.
3+
*/
4+
class ApiError extends Error {
5+
/**
6+
* Creates a representation of an API error.
7+
*
8+
* @param {string} message - Error message
9+
* @param {Request} request - HTTP request
10+
* @param {Response} response - HTTP response
11+
* @returns {ApiError} - An instance of ApiError
12+
*/
13+
constructor(message, request, response) {
14+
super(message);
15+
this.name = 'ApiError';
16+
this.request = request;
17+
this.response = response;
18+
}
19+
}
20+
21+
module.exports = ApiError;

0 commit comments

Comments
 (0)