diff --git a/packages/inference/README.md b/packages/inference/README.md
index ad4fcb879..a525a3ec2 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Wavespeed.ai](https://wavespeed.ai/)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
 
@@ -96,6 +97,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
+- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed-ai/models)
 
 ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index a96595b72..c458609e0 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -47,6 +47,7 @@ import type {
 import * as Replicate from "../providers/replicate";
 import * as Sambanova from "../providers/sambanova";
 import * as Together from "../providers/together";
+import * as WavespeedAI from "../providers/wavespeed-ai";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types";
 
 export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
@@ -146,6 +147,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
+	"wavespeed-ai": {
+		"text-to-image": new WavespeedAI.WavespeedAITextToImageTask(),
+		"text-to-video": new WavespeedAI.WavespeedAITextToVideoTask(),
+		"image-to-image": new WavespeedAI.WavespeedAIImageToImageTask(),
+	},
diff --git a/packages/inference/src/providers/wavespeed-ai.ts b/packages/inference/src/providers/wavespeed-ai.ts
new file mode 100644
--- /dev/null
+++ b/packages/inference/src/providers/wavespeed-ai.ts
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
+import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types";
+import { delay } from "../utils/delay";
+import { omit } from "../utils/omit";
+import { base64FromBytes } from "../utils/base64FromBytes";
+import {
+	TaskProviderHelper,
+	type ImageToImageTaskHelper,
+	type TextToImageTaskHelper,
+	type TextToVideoTaskHelper,
+} from "./providerHelper";
+
+const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";
+
+/**
+ * Common response wrapper for WaveSpeed AI API calls
+ */
+interface WaveSpeedAICommonResponse<T> {
+	code: number;
+	message: string;
+	data: T;
+}
+
+/**
+ * Response structure for task status and results
+ */
+interface WaveSpeedAITaskResponse {
+	id: string;
+	model: string;
+	outputs: string[];
+	urls: {
+		get: string;
+	};
+	has_nsfw_contents: boolean[];
+	status: "created" | "processing" | "completed" | "failed";
+	created_at: string;
+	error: string;
+	executionTime: number;
+	timings: {
+		inference: number;
+	};
+}
+
+/**
+ * Response structure for initial task submission
+ */
+interface WaveSpeedAISubmitResponse {
+	id: string;
+	urls: {
+		get: string;
+	};
+}
+
+type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>;
+
+abstract class WavespeedAITask extends TaskProviderHelper {
+	private accessToken: string | undefined;
+
+	constructor(url?: string) {
+		super("wavespeed-ai", url || WAVESPEEDAI_API_BASE_URL);
+	}
+
+	makeRoute(params: UrlParams): string {
+		return `/api/v2/${params.model}`;
+	}
+	preparePayload(params: BodyParams): Record<string, unknown> {
+		const payload: Record<string, unknown> = {
+			...omit(params.args, ["inputs", "parameters"]),
+			...(params.args.parameters as Record<string, unknown>),
+			prompt: params.args.inputs,
+		};
+		// Add LoRA support if adapter is specified in the mapping
+		if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
+			payload.loras = [
+				{
+					path: params.mapping.adapterWeightsPath,
+					scale: 1, // Default scale value
+				},
+			];
+		}
+		return payload;
+	}
+
+	override prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> {
+		this.accessToken = params.accessToken;
+		const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` };
+		if (!isBinary) {
+			headers["Content-Type"] = "application/json";
+		}
+		return headers;
+	}
+
+	override async getResponse(
+		response: WaveSpeedAIResponse<WaveSpeedAISubmitResponse>,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<Blob> {
+		if (!headers && this.accessToken) {
+			headers = { Authorization: `Bearer ${this.accessToken}` };
+		}
+		if (!headers) {
+			throw new InferenceOutputError("Headers are required for WaveSpeed AI API calls");
+		}
+
+		const resultUrl = response.data.urls.get;
+
+		// Poll for results until completion
+		while (true) {
+			const resultResponse = await fetch(resultUrl, { headers });
+
+			if (!resultResponse.ok) {
+				throw new InferenceOutputError(`Failed to get result: ${resultResponse.statusText}`);
+			}
+
+			const result: WaveSpeedAIResponse = await resultResponse.json();
+			if (result.code !== 200) {
+				throw new InferenceOutputError(`API request failed with code ${result.code}: ${result.message}`);
+			}
+
+			const taskResult = result.data;
+
+			switch (taskResult.status) {
+				case "completed": {
+					// Get the media data from the first output URL
+					if (!taskResult.outputs?.[0]) {
+						throw new InferenceOutputError("No output URL in completed response");
+					}
+					const mediaResponse = await fetch(taskResult.outputs[0]);
+					if (!mediaResponse.ok) {
+						throw new InferenceOutputError("Failed to fetch output data");
+					}
+					return await mediaResponse.blob();
+				}
+				case "failed": {
+					throw new InferenceOutputError(taskResult.error || "Task failed");
+				}
+				case "processing":
+				case "created":
+					// Wait before polling again
+					await delay(500);
+					continue;
+
+				default: {
+					throw new InferenceOutputError(`Unknown status: ${taskResult.status}`);
+				}
+			}
+		}
+	}
+}
+
+export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper {
+	constructor() {
+		super(WAVESPEEDAI_API_BASE_URL);
+	}
+}
+
+export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper {
+	constructor() {
+		super(WAVESPEEDAI_API_BASE_URL);
+	}
+}
+
+export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper {
+	constructor() {
+		super(WAVESPEEDAI_API_BASE_URL);
+	}
+
+	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
+		return {
+			...args,
+			inputs: args.parameters?.prompt,
+			image: base64FromBytes(
+				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
+			),
+		};
+	}
+}
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
index a6df1ba3e..00292854e 100644
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ -55,6 +55,7 @@ export const INFERENCE_PROVIDERS = [
 	"replicate",
 	"sambanova",
 	"together",
+	"wavespeed-ai",
 ] as const;
 
 export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index 3365f38cf..244de983d 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -2023,4 +2023,113 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"Wavespeed AI",
+		() => {
+			const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed-ai"] = {
+				"wavespeed-ai/flux-schnell": {
+					hfModelId: "wavespeed-ai/flux-schnell",
+					providerId: "wavespeed-ai/flux-schnell",
+					status: "live",
+					task: "text-to-image",
+				},
+				"wavespeed-ai/wan-2.1/t2v-480p": {
+					hfModelId: "wavespeed-ai/wan-2.1/t2v-480p",
+					providerId: "wavespeed-ai/wan-2.1/t2v-480p",
+					status: "live",
+					task: "text-to-video",
+				},
+				"wavespeed-ai/hidream-e1-full": {
+					hfModelId: "wavespeed-ai/hidream-e1-full",
+					providerId: "wavespeed-ai/hidream-e1-full",
+					status: "live",
+					task: "image-to-image",
+				},
+				"wavespeed-ai/flux-dev-lora": {
+					hfModelId: "wavespeed-ai/flux-dev-lora",
+					providerId: "wavespeed-ai/flux-dev-lora",
+					status: "live",
+					task: "text-to-image",
+					adapter: "lora",
+					adapterWeightsPath:
+						"https://d32s1zkpjdc4b1.cloudfront.net/predictions/599f3739f5354afc8a76a12042736bfd/1.safetensors",
+				},
+				"wavespeed-ai/flux-dev-lora-ultra-fast": {
+					hfModelId: "wavespeed-ai/flux-dev-lora-ultra-fast",
+					providerId: "wavespeed-ai/flux-dev-lora-ultra-fast",
+					status: "live",
+					task: "text-to-image",
+					adapter: "lora",
+					adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA",
+				},
+			};
+
+			it(`textToImage - wavespeed-ai/flux-schnell`, async () => {
+				const res = await client.textToImage({
+					model: "wavespeed-ai/flux-schnell",
+					provider: "wavespeed-ai",
+					inputs:
+						"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`textToImage - wavespeed-ai/flux-dev-lora`, async () => {
+				const res = await client.textToImage({
+					model: "wavespeed-ai/flux-dev-lora",
+					provider: "wavespeed-ai",
+					inputs:
+						"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`textToImage - wavespeed-ai/flux-dev-lora-ultra-fast`, async () => {
+				const res = await client.textToImage({
+					model: "wavespeed-ai/flux-dev-lora-ultra-fast",
+					provider: "wavespeed-ai",
+					inputs:
+						"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`textToVideo - wavespeed-ai/wan-2.1/t2v-480p`, async () => {
+				const res = await client.textToVideo({
+					model: "wavespeed-ai/wan-2.1/t2v-480p",
+					provider: "wavespeed-ai",
+					inputs:
+						"A cool street dancer, wearing a baggy hoodie and hip-hop pants, dancing in front of a graffiti wall, night neon background, quick camera cuts, urban trends.",
+					parameters: {
+						guidance_scale: 5,
+						num_inference_steps: 30,
+						seed: -1,
+					},
+					duration: 5,
+					enable_safety_checker: true,
+					flow_shift: 2.9,
+					size: "480*832",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`imageToImage - wavespeed-ai/hidream-e1-full`, async () => {
+				const res = await client.imageToImage({
+					model: "wavespeed-ai/hidream-e1-full",
+					provider: "wavespeed-ai",
+					inputs: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
+					parameters: {
+						prompt: "The leopard chases its prey",
+						guidance_scale: 5,
+						num_inference_steps: 30,
+						seed: -1,
+					},
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+		},
+		60000 * 5
+	);
 });
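
Usage note: a minimal sketch of calling the new provider through InferenceClient, following the README paragraph on the `provider` parameter and mirroring the tests above. It assumes HF_TOKEN holds a Hugging Face access token with inference-provider billing enabled; the model ID is taken from the test mapping.

import { InferenceClient } from "@huggingface/inference";

// Assumption: HF_TOKEN is set in the environment (Node.js, ESM with top-level await).
const client = new InferenceClient(process.env.HF_TOKEN);

// Route the request explicitly to WaveSpeed AI instead of the default "auto" provider policy.
const image = await client.textToImage({
	model: "wavespeed-ai/flux-schnell",
	provider: "wavespeed-ai",
	inputs: "A watercolor fox in a misty forest",
});
// `image` is a Blob: the provider helper has already polled the WaveSpeed task
// until completion and fetched the first output URL.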
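
Payload note: an illustrative sketch of the request body that WavespeedAITask.preparePayload builds when the provider mapping carries a LoRA adapter (e.g. wavespeed-ai/flux-dev-lora-ultra-fast above). Field names come from the code; the prompt value is only an example.

// Hypothetical body posted to `${WAVESPEEDAI_API_BASE_URL}/api/v2/wavespeed-ai/flux-dev-lora-ultra-fast`.
const examplePayload = {
	// `inputs` from the client call becomes `prompt`; any `parameters` are spread alongside it.
	prompt: "Cute boy with a hat, exploring nature",
	// The mapping's adapterWeightsPath becomes a single LoRA entry with the default scale of 1.
	loras: [{ path: "linoyts/yarn_art_Flux_LoRA", scale: 1 }],
};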