diff --git a/index.d.ts b/index.d.ts index 69e651b..8dc998a 100644 --- a/index.d.ts +++ b/index.d.ts @@ -280,4 +280,10 @@ declare module "replicate" { }, secret: string ): boolean; + + export function parseProgressFromLogs(logs: Prediction | string): { + percentage: number; + current: number; + total: number; + } | null; } diff --git a/index.js b/index.js index 83b9888..24376fe 100644 --- a/index.js +++ b/index.js @@ -1,7 +1,11 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); const { Stream } = require("./lib/stream"); -const { withAutomaticRetries, validateWebhook } = require("./lib/util"); +const { + withAutomaticRetries, + validateWebhook, + parseProgressFromLogs, +} = require("./lib/util"); const accounts = require("./lib/accounts"); const collections = require("./lib/collections"); @@ -375,3 +379,4 @@ class Replicate { module.exports = Replicate; module.exports.validateWebhook = validateWebhook; +module.exports.parseProgressFromLogs = parseProgressFromLogs; diff --git a/index.test.ts b/index.test.ts index f00a7e6..97abc6f 100644 --- a/index.test.ts +++ b/index.test.ts @@ -4,6 +4,7 @@ import Replicate, { Model, Prediction, validateWebhook, + parseProgressFromLogs, } from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; @@ -888,29 +889,55 @@ describe("Replicate client", () => { }); describe("run", () => { - test("Calls the correct API routes for a version", async () => { - const firstPollingRequest = true; - + test("Calls the correct API routes", async () => { nock(BASE_URL) .post("/predictions") .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", status: "starting", + logs: null, }) .get("/predictions/ufawqhfynnddngldkgtslldrkq") - .twice() .reply(200, { id: "ufawqhfynnddngldkgtslldrkq", status: "processing", + logs: [ + "Using seed: 12345", + "0%| | 0/5 [00:00 { input: { text: "Hello, world!" }, wait: { interval: 1 }, }, - progress + (prediction) => { + const progress = parseProgressFromLogs(prediction); + callback(prediction, progress); + } ); expect(output).toBe("Goodbye!"); - expect(progress).toHaveBeenNthCalledWith(1, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "starting", - }); + expect(callback).toHaveBeenNthCalledWith( + 1, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }, + null + ); - expect(progress).toHaveBeenNthCalledWith(2, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "processing", - }); + expect(callback).toHaveBeenNthCalledWith( + 2, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: expect.any(String), + }, + { + percentage: 0.4, + current: 2, + total: 5, + } + ); - expect(progress).toHaveBeenNthCalledWith(3, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "processing", - }); + expect(callback).toHaveBeenNthCalledWith( + 3, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: expect.any(String), + }, + { + percentage: 0.8, + current: 4, + total: 5, + } + ); - expect(progress).toHaveBeenNthCalledWith(4, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "succeeded", - output: "Goodbye!", - }); + expect(callback).toHaveBeenNthCalledWith( + 4, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + logs: expect.any(String), + output: "Goodbye!", + }, + { + percentage: 1.0, + current: 5, + total: 5, + } + ); - expect(progress).toHaveBeenCalledTimes(4); + expect(callback).toHaveBeenCalledTimes(4); }); test("Calls the correct API routes for a model", async () => { diff --git a/lib/util.js b/lib/util.js index 48d7563..949bafa 100644 --- a/lib/util.js +++ b/lib/util.js @@ -246,4 +246,55 @@ function isPlainObject(value) { ); } -module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries }; +/** + * Parse progress from prediction logs. + * + * This function supports log statements in the following format, + * which are generated by https://github.com/tqdm/tqdm and similar libraries: + * + * ``` + * 76%|████████████████████████████ | 7568/10000 [00:33<00:10, 229.00it/s] + * ``` + * + * @example + * const progress = parseProgressFromLogs("76%|████████████████████████████ | 7568/10000 [00:33<00:10, 229.00it/s]"); + * console.log(progress); + * // { + * // percentage: 0.76, + * // current: 7568, + * // total: 10000, + * // } + * + * @param {object|string} input - A prediction object or string. + * @returns {(object|null)} - An object with the percentage, current, and total, or null if no progress can be parsed. + */ +function parseProgressFromLogs(input) { + const logs = typeof input === "object" && input.logs ? input.logs : input; + if (!logs || typeof logs !== "string") { + return null; + } + + const pattern = /^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/; + const lines = logs.split("\n").reverse(); + + for (const line of lines) { + const matches = line.match(pattern); + + if (matches && matches.length === 4) { + return { + percentage: parseInt(matches[1], 10) / 100, + current: parseInt(matches[2], 10), + total: parseInt(matches[3], 10), + }; + } + } + + return null; +} + +module.exports = { + transformFileInputs, + validateWebhook, + withAutomaticRetries, + parseProgressFromLogs, +};