diff --git a/biome.json b/biome.json index ecb665f..094cf0e 100644 --- a/biome.json +++ b/biome.json @@ -1,7 +1,11 @@ { "$schema": "https://biomejs.dev/schemas/1.0.0/schema.json", "files": { - "ignore": [".wrangler", "vendor/*"] + "ignore": [ + ".wrangler", + "node_modules", + "vendor/*" + ] }, "formatter": { "indentStyle": "space", diff --git a/index.d.ts b/index.d.ts index 31a2325..1ef9e89 100644 --- a/index.d.ts +++ b/index.d.ts @@ -39,6 +39,21 @@ declare module "replicate" { }; } + export interface FileObject { + id: string; + name: string; + content_type: string; + size: number; + etag: string; + checksum: string; + metadata: Record; + created_at: string; + expires_at: string | null; + urls: { + get: string; + }; + } + export interface Hardware { sku: string; name: string; @@ -93,6 +108,8 @@ declare module "replicate" { export type Training = Prediction; + export type FileEncodingStrategy = "default" | "upload" | "data-uri"; + export interface Page { previous?: string; next?: string; @@ -119,12 +136,14 @@ declare module "replicate" { input: Request | string, init?: RequestInit ) => Promise; + fileEncodingStrategy?: FileEncodingStrategy; }); auth: string; userAgent?: string; baseUrl?: string; fetch: (input: Request | string, init?: RequestInit) => Promise; + fileEncodingStrategy: FileEncodingStrategy; run( identifier: `${string}/${string}` | `${string}/${string}:${string}`, diff --git a/index.js b/index.js index f4c0e2c..3ef0d3d 100644 --- a/index.js +++ b/index.js @@ -46,6 +46,7 @@ class Replicate { * @param {string} options.userAgent - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` + * @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use */ constructor(options = {}) { this.auth = @@ -55,6 +56,7 @@ class Replicate { options.userAgent || `replicate-javascript/${packageJSON.version}`; this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; + this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default"; this.accounts = { current: accounts.current.bind(this), @@ -230,10 +232,17 @@ class Replicate { } } + let body = undefined; + if (data instanceof FormData) { + body = data; + } else if (data) { + body = JSON.stringify(data); + } + const init = { method, headers, - body: data ? JSON.stringify(data) : undefined, + body, }; const shouldRetry = diff --git a/index.test.ts b/index.test.ts index 53737e0..7502969 100644 --- a/index.test.ts +++ b/index.test.ts @@ -222,13 +222,13 @@ describe("Replicate client", () => { expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq"); }); - test.each([ + const fileTestCases = [ // Skip test case if File type is not available ...(typeof File !== "undefined" ? [ { type: "file", - value: new File(["hello world"], "hello.txt", { + value: new File(["hello world"], "file_hello.txt", { type: "text/plain", }), expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=", @@ -245,11 +245,67 @@ describe("Replicate client", () => { value: Buffer.from("hello world"), expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=", }, - ])( + ]; + + test.each(fileTestCases)( + "converts a $type input into a Replicate file URL", + async ({ value: data, type }) => { + const mockedFetch = jest.spyOn(client, "fetch"); + + nock(BASE_URL) + .post("/files") + .matchHeader("Content-Type", "multipart/form-data") + .reply(201, { + urls: { + get: "https://replicate.com/api/files/123", + }, + }) + .post( + "/predictions", + (body) => body.input.data === "https://replicate.com/api/files/123" + ) + .reply(201, (_uri: string, body: Record) => { + return body; + }); + + const prediction = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + prompt: "Tell me a story", + data, + }, + stream: true, + }); + + expect(client.fetch).toHaveBeenCalledWith( + new URL("https://api.replicate.com/v1/files"), + { + method: "POST", + body: expect.any(FormData), + headers: expect.objectContaining({ + "Content-Type": "multipart/form-data", + }), + } + ); + const form = mockedFetch.mock.calls[0][1]?.body as FormData; + // @ts-ignore + expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); + + expect(prediction.input).toEqual({ + prompt: "Tell me a story", + data: "https://replicate.com/api/files/123", + }); + } + ); + + test.each(fileTestCases)( "converts a $type input into a base64 encoded string", async ({ value: data, expected }) => { let actual: Record | undefined; nock(BASE_URL) + .post("/files") + .reply(503, "Service Unavailable") .post("/predictions") .reply(201, (_uri: string, body: Record) => { actual = body; diff --git a/lib/deployments.js b/lib/deployments.js index 4f6f3c6..27a2f6a 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -30,7 +30,11 @@ async function createPrediction(deployment_owner, deployment_name, options) { method: "POST", data: { ...data, - input: await transformFileInputs(input), + input: await transformFileInputs( + this, + input, + this.fileEncodingStrategy + ), stream, }, } diff --git a/lib/files.js b/lib/files.js new file mode 100644 index 0000000..f6620e9 --- /dev/null +++ b/lib/files.js @@ -0,0 +1,86 @@ +/** + * Create a file + * + * @param {object} file - Required. The file object. + * @param {object} metadata - Optional. User-provided metadata associated with the file. + * @returns {Promise} - Resolves with the file data + */ +async function createFile(file, metadata = {}) { + const form = new FormData(); + + let filename; + let blob; + if (file instanceof Blob) { + filename = file.name || `blob_${Date.now()}`; + blob = file; + } else if (Buffer.isBuffer(file)) { + filename = `buffer_${Date.now()}`; + blob = new Blob(file, { type: "application/octet-stream" }); + } else { + throw new Error("Invalid file argument, must be a Blob, File or Buffer"); + } + + form.append("content", blob, filename); + form.append( + "metadata", + new Blob([JSON.stringify(metadata)], { type: "application/json" }) + ); + + const response = await this.request("/files", { + method: "POST", + data: form, + headers: { + "Content-Type": "multipart/form-data", + }, + }); + + return response.json(); +} + +/** + * List all files + * + * @returns {Promise} - Resolves with the files data + */ +async function listFiles() { + const response = await this.request("/files", { + method: "GET", + }); + + return response.json(); +} + +/** + * Get a file + * + * @param {string} file_id - Required. The ID of the file. + * @returns {Promise} - Resolves with the file data + */ +async function getFile(file_id) { + const response = await this.request(`/files/${file_id}`, { + method: "GET", + }); + + return response.json(); +} + +/** + * Delete a file + * + * @param {string} file_id - Required. The ID of the file. + * @returns {Promise} - Resolves with the deletion confirmation + */ +async function deleteFile(file_id) { + const response = await this.request(`/files/${file_id}`, { + method: "DELETE", + }); + + return response.json(); +} + +module.exports = { + create: createFile, + list: listFiles, + get: getFile, + delete: deleteFile, +}; diff --git a/lib/predictions.js b/lib/predictions.js index 5b0370e..c290d40 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -30,7 +30,11 @@ async function createPrediction(options) { method: "POST", data: { ...data, - input: await transformFileInputs(input), + input: await transformFileInputs( + this, + input, + this.fileEncodingStrategy + ), version, stream, }, @@ -40,7 +44,11 @@ async function createPrediction(options) { method: "POST", data: { ...data, - input: await transformFileInputs(input), + input: await transformFileInputs( + this, + input, + this.fileEncodingStrategy + ), stream, }, }); diff --git a/lib/util.js b/lib/util.js index 68b1d9d..e164899 100644 --- a/lib/util.js +++ b/lib/util.js @@ -1,4 +1,5 @@ const ApiError = require("./error"); +const { create: createFile } = require("./files"); /** * @see {@link validateWebhook} @@ -209,12 +210,58 @@ async function withAutomaticRetries(request, options = {}) { } attempts += 1; } - /* eslint-enable no-await-in-loop */ } while (attempts < maxRetries); return request(); } +/** + * Walks the inputs and, for any File or Blob, tries to upload it to Replicate + * and replaces the input with the URL of the uploaded file. + * + * @param {Replicate} client - The client used to upload the file + * @param {object} inputs - The inputs to transform + * @param {"default" | "upload" | "data-uri"} strategy - Whether to upload files to Replicate, encode as dataURIs or try both. + * @returns {object} - The transformed inputs + * @throws {ApiError} If the request to upload the file fails + */ +async function transformFileInputs(client, inputs, strategy) { + switch (strategy) { + case "data-uri": + return await transformFileInputsToBase64EncodedDataURIs(client, inputs); + case "upload": + return await transformFileInputsToReplicateFileURLs(client, inputs); + case "default": + try { + return await transformFileInputsToReplicateFileURLs(client, inputs); + } catch (error) { + return await transformFileInputsToBase64EncodedDataURIs(inputs); + } + default: + throw new Error(`Unexpected file upload strategy: ${strategy}`); + } +} + +/** + * Walks the inputs and, for any File or Blob, tries to upload it to Replicate + * and replaces the input with the URL of the uploaded file. + * + * @param {Replicate} client - The client used to upload the file + * @param {object} inputs - The inputs to transform + * @returns {object} - The transformed inputs + * @throws {ApiError} If the request to upload the file fails + */ +async function transformFileInputsToReplicateFileURLs(client, inputs) { + return await transform(inputs, async (value) => { + if (value instanceof Blob || value instanceof Buffer) { + const file = await createFile.call(client, value); + return file.urls.get; + } + + return value; + }); +} + const MAX_DATA_URI_SIZE = 10_000_000; /** @@ -225,9 +272,9 @@ const MAX_DATA_URI_SIZE = 10_000_000; * @returns {object} - The transformed inputs * @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE */ -async function transformFileInputs(inputs) { +async function transformFileInputsToBase64EncodedDataURIs(inputs) { let totalBytes = 0; - const result = await transform(inputs, async (value) => { + return await transform(inputs, async (value) => { let buffer; let mime; @@ -258,8 +305,6 @@ async function transformFileInputs(inputs) { return `data:${mime};base64,${data}`; }); - - return result; } // Walk a JavaScript object and transform the leaf values.