Skip to content

Commit 95cda80

Browse files
committed
Allow replicate.predictions.create to accept version or model options
1 parent cbde2c1 commit 95cda80

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

index.d.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,15 @@ declare module "replicate" {
203203
};
204204

205205
predictions: {
206-
create(options: {
207-
version: string;
208-
input: object;
209-
stream?: boolean;
210-
webhook?: string;
211-
webhook_events_filter?: WebhookEventType[];
212-
}): Promise<Prediction>;
206+
create(
207+
options: {
208+
version: string;
209+
input: object;
210+
stream?: boolean;
211+
webhook?: string;
212+
webhook_events_filter?: WebhookEventType[];
213+
} & ({ version: string } | { model: string })
214+
): Promise<Prediction>;
213215
get(prediction_id: string): Promise<Prediction>;
214216
cancel(prediction_id: string): Promise<Prediction>;
215217
list(): Promise<Page<Prediction>>;

lib/predictions.js

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
* Create a new prediction
33
*
44
* @param {object} options
5-
* @param {string} options.version - Required. The model version
5+
* @param {string} options.model - The model.
6+
* @param {string} options.version - The model version.
67
* @param {object} options.input - Required. An object with the model inputs
78
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
89
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
910
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
1011
* @returns {Promise<object>} Resolves with the created prediction
1112
*/
1213
async function createPrediction(options) {
13-
const { stream, ...data } = options;
14+
const { model, version, stream, ...data } = options;
1415

1516
if (data.webhook) {
1617
try {
@@ -21,10 +22,20 @@ async function createPrediction(options) {
2122
}
2223
}
2324

24-
const response = await this.request("/predictions", {
25-
method: "POST",
26-
data: { ...data, stream },
27-
});
25+
let response;
26+
if (version) {
27+
response = await this.request("/predictions", {
28+
method: "POST",
29+
data: { ...data, stream, version },
30+
});
31+
} else if (model) {
32+
response = await this.request(`/models/${model}/predictions`, {
33+
method: "POST",
34+
data: { ...data, stream },
35+
});
36+
} else {
37+
throw new Error("Either model or version must be specified");
38+
}
2839

2940
return response.json();
3041
}

0 commit comments

Comments
 (0)