diff --git a/index.d.ts b/index.d.ts index eabcc9b..45a2430 100644 --- a/index.d.ts +++ b/index.d.ts @@ -93,9 +93,9 @@ declare module "replicate" { model: string; version: string; input: object; - output?: any; + output?: any; // TODO: this should be `unknown` source: "api" | "web"; - error?: any; + error?: unknown; logs?: string; metrics?: { predict_time?: number; @@ -156,7 +156,7 @@ declare module "replicate" { identifier: `${string}/${string}` | `${string}/${string}:${string}`, options: { input: object; - wait?: boolean | number | { mode?: "poll"; interval?: number }; + wait?: boolean | number | { interval?: number }; webhook?: string; webhook_events_filter?: WebhookEventType[]; signal?: AbortSignal; @@ -215,7 +215,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } ): Promise; }; @@ -304,7 +304,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } & ({ version: string } | { model: string }) ): Promise; get(prediction_id: string): Promise; diff --git a/index.js b/index.js index 712bc59..ac4b815 100644 --- a/index.js +++ b/index.js @@ -147,19 +147,20 @@ class Replicate { const { wait, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); + const isBlocking = typeof wait === "boolean" || typeof wait === "number"; let prediction; if (identifier.version) { prediction = await this.predictions.create({ ...data, version: identifier.version, - wait: wait, + wait: isBlocking ? wait : false, }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, - wait: wait, + wait: isBlocking ? wait : false, }); } else { throw new Error("Invalid model version identifier"); @@ -170,23 +171,26 @@ class Replicate { progress(prediction); } - prediction = await this.wait( - prediction, - wait || {}, - async (updatedPrediction) => { - // Call progress callback with the updated prediction object - if (progress) { - progress(updatedPrediction); + const isDone = isBlocking && prediction.status !== "starting"; + if (!isDone) { + prediction = await this.wait( + prediction, + isBlocking ? {} : wait, + async (updatedPrediction) => { + // Call progress callback with the updated prediction object + if (progress) { + progress(updatedPrediction); + } + + // We handle the cancel later in the function. + if (signal && signal.aborted) { + return true; // stop polling + } + + return false; // continue polling } - - // We handle the cancel later in the function. - if (signal && signal.aborted) { - return true; // stop polling - } - - return false; // continue polling - } - ); + ); + } if (signal && signal.aborted) { prediction = await this.predictions.cancel(prediction.id);