Skip to content

Commit c18c81e

Browse files
committed
Allow replicate.predictions.create to accept version or model options
1 parent 91c03f5 commit c18c81e

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

index.d.ts

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ declare module "replicate" {
7575
results: T[];
7676
}
7777

78+
export interface ServerSentEvent {
79+
event: string;
80+
data: string;
81+
id?: string;
82+
retry?: number;
83+
}
84+
7885
export default class Replicate {
7986
constructor(options?: {
8087
auth?: string;
@@ -103,6 +110,16 @@ declare module "replicate" {
103110
progress?: (prediction: Prediction) => void
104111
): Promise<object>;
105112

113+
stream(
114+
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
115+
options: {
116+
input: object;
117+
webhook?: string;
118+
webhook_events_filter?: WebhookEventType[];
119+
signal?: AbortSignal;
120+
}
121+
): AsyncGenerator<ServerSentEvent>;
122+
106123
request(
107124
route: string | URL,
108125
options: {
@@ -186,13 +203,15 @@ declare module "replicate" {
186203
};
187204

188205
predictions: {
189-
create(options: {
190-
version: string;
191-
input: object;
192-
stream?: boolean;
193-
webhook?: string;
194-
webhook_events_filter?: WebhookEventType[];
195-
}): Promise<Prediction>;
206+
create(
207+
options: {
208+
version: string;
209+
input: object;
210+
stream?: boolean;
211+
webhook?: string;
212+
webhook_events_filter?: WebhookEventType[];
213+
} & ({ version: string } | { model: string })
214+
): Promise<Prediction>;
196215
get(prediction_id: string): Promise<Prediction>;
197216
cancel(prediction_id: string): Promise<Prediction>;
198217
list(): Promise<Page<Prediction>>;

lib/predictions.js

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
* Create a new prediction
33
*
44
* @param {object} options
5-
* @param {string} options.version - Required. The model version
5+
* @param {string} options.model - The model.
6+
* @param {string} options.version - The model version.
67
* @param {object} options.input - Required. An object with the model inputs
78
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
89
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
910
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
1011
* @returns {Promise<object>} Resolves with the created prediction
1112
*/
1213
async function createPrediction(options) {
13-
const { stream, ...data } = options;
14+
const { model, version, stream, ...data } = options;
1415

1516
if (data.webhook) {
1617
try {
@@ -21,10 +22,20 @@ async function createPrediction(options) {
2122
}
2223
}
2324

24-
const response = await this.request("/predictions", {
25-
method: "POST",
26-
data: { ...data, stream },
27-
});
25+
let response;
26+
if (version) {
27+
response = await this.request("/predictions", {
28+
method: "POST",
29+
data: { ...data, stream, version },
30+
});
31+
} else if (model) {
32+
response = await this.request(`/models/${model}/predictions`, {
33+
method: "POST",
34+
data: { ...data, stream },
35+
});
36+
} else {
37+
throw new Error("Either model or version must be specified");
38+
}
2839

2940
return response.json();
3041
}

0 commit comments

Comments
 (0)