[inference provider] Add wavespeed.ai as an inference provider #1424

Open — wants to merge 14 commits into base branch `main`.
2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -63,6 +63,7 @@ Currently, we support the following providers:
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [Groq](https://groq.com)
- [Wavespeed.ai](https://wavespeed.ai/)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
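The `"auto"` selection described above can be sketched as follows (a hypothetical `resolveAutoProvider` helper, not the library's actual internals; provider names are illustrative):

```typescript
type Provider = string;

// Walk the user's preferred order (from hf.co/settings/inference-providers)
// and return the first provider that actually serves the requested model.
function resolveAutoProvider(
	providersForModel: Provider[],
	userPreferenceOrder: Provider[]
): Provider | undefined {
	return userPreferenceOrder.find((p) => providersForModel.includes(p));
}

// Example: the model is served by "wavespeed-ai" and "together";
// the user prefers "together" first, so it wins.
const chosen = resolveAutoProvider(
	["wavespeed-ai", "together"],
	["together", "replicate", "wavespeed-ai"]
);
// chosen === "together"
```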

@@ -96,6 +97,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed-ai/models)

❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
6 changes: 6 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -47,6 +47,7 @@ import type {
import * as Replicate from "../providers/replicate";
import * as Sambanova from "../providers/sambanova";
import * as Together from "../providers/together";
import * as WavespeedAI from "../providers/wavespeed-ai";
import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
@@ -146,6 +147,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
conversational: new Together.TogetherConversationalTask(),
"text-generation": new Together.TogetherTextGenerationTask(),
},
"wavespeed-ai": {
"text-to-image": new WavespeedAI.WavespeedAITextToImageTask(),
"text-to-video": new WavespeedAI.WavespeedAITextToVideoTask(),
"image-to-image": new WavespeedAI.WavespeedAIImageToImageTask(),
},
};

/**
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -36,4 +36,5 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
replicate: {},
sambanova: {},
together: {},
"wavespeed-ai": {},
};
178 changes: 178 additions & 0 deletions packages/inference/src/providers/wavespeed-ai.ts
@@ -0,0 +1,178 @@
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { ImageToImageArgs } from "../tasks";
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types";
import { delay } from "../utils/delay";
import { omit } from "../utils/omit";
import { base64FromBytes } from "../utils/base64FromBytes";
import {
TaskProviderHelper,
TextToImageTaskHelper,
TextToVideoTaskHelper,
ImageToImageTaskHelper,
} from "./providerHelper";

const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";

/**
* Common response structure for all WaveSpeed AI API responses
*/
interface WaveSpeedAICommonResponse<T> {
code: number;
message: string;
data: T;
}

/**
* Response structure for task status and results
*/
interface WaveSpeedAITaskResponse {
id: string;
model: string;
outputs: string[];
urls: {
get: string;
};
has_nsfw_contents: boolean[];
status: "created" | "processing" | "completed" | "failed";
created_at: string;
error: string;
executionTime: number;
timings: {
inference: number;
};
}

/**
* Response structure for initial task submission
*/
interface WaveSpeedAISubmitResponse {
id: string;
urls: {
get: string;
};
}

type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>;
Reviewer (Contributor): I'm not sure this type alias is needed — can we remove it?

Suggested change (delete the alias):
type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>;
WaveSpeedAICommonResponse can then be renamed to WaveSpeedAIResponse.

Author: The alias is needed and is used in two places; whether it will be used again in the future is uncertain. It follows the DRY (Don't Repeat Yourself) principle, provides better type safety through its default generic parameter, and makes the code more readable and maintainable.
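For illustration, the two call-site patterns the author describes can be sketched with simplified, hypothetical types (not the actual file):

```typescript
// Simplified stand-ins for the interfaces in the diff.
interface CommonResponse<T> {
	code: number;
	message: string;
	data: T;
}
interface TaskResult {
	id: string;
	status: "created" | "processing" | "completed" | "failed";
}
interface SubmitResult {
	id: string;
	urls: { get: string };
}

// The alias under discussion, with a default generic parameter.
type ApiResponse<T = TaskResult> = CommonResponse<T>;

// Call site 1: submission response, explicit type argument.
const submitted: ApiResponse<SubmitResult> = {
	code: 200,
	message: "ok",
	data: { id: "1", urls: { get: "https://example.invalid/result/1" } },
};

// Call site 2: polling response, relying on the default type argument.
const polled: ApiResponse = {
	code: 200,
	message: "ok",
	data: { id: "1", status: "completed" },
};
```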


abstract class WavespeedAITask extends TaskProviderHelper {
private accessToken: string | undefined;

constructor(url?: string) {
super("wavespeed-ai", url || WAVESPEEDAI_API_BASE_URL);
}

makeRoute(params: UrlParams): string {
return `/api/v2/${params.model}`;
}
preparePayload(params: BodyParams): Record<string, unknown> {
const payload: Record<string, unknown> = {
...omit(params.args, ["inputs", "parameters"]),
...(params.args.parameters as Record<string, unknown>),
prompt: params.args.inputs,
};
// Add LoRA support if adapter is specified in the mapping
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) {
payload.loras = [
{
path: params.mapping.adapterWeightsPath,
scale: 1, // Default scale value
},
];
}
return payload;
}

override prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> {
this.accessToken = params.accessToken;
const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` };
if (!isBinary) {
headers["Content-Type"] = "application/json";
}
return headers;
}

override async getResponse(
response: WaveSpeedAIResponse<WaveSpeedAISubmitResponse>,
url?: string,
headers?: Record<string, string>
): Promise<Blob> {
if (!headers && this.accessToken) {
headers = { Authorization: `Bearer ${this.accessToken}` };
}
if (!headers) {
throw new InferenceOutputError("Headers are required for WaveSpeed AI API calls");
}

const resultUrl = response.data.urls.get;

// Poll for results until completion
while (true) {
const resultResponse = await fetch(resultUrl, { headers });

if (!resultResponse.ok) {
throw new InferenceOutputError(`Failed to get result: ${resultResponse.statusText}`);
}

const result: WaveSpeedAIResponse = await resultResponse.json();
if (result.code !== 200) {
throw new InferenceOutputError(`API request failed with code ${result.code}: ${result.message}`);
}

const taskResult = result.data;

switch (taskResult.status) {
case "completed": {
// Get the media data from the first output URL
if (!taskResult.outputs?.[0]) {
throw new InferenceOutputError("No output URL in completed response");
}
const mediaResponse = await fetch(taskResult.outputs[0]);
if (!mediaResponse.ok) {
throw new InferenceOutputError("Failed to fetch output data");
}
return await mediaResponse.blob();
}
case "failed": {
throw new InferenceOutputError(taskResult.error || "Task failed");
}
case "processing":
case "created":
// Wait before polling again
await delay(500);
continue;

default: {
throw new InferenceOutputError(`Unknown status: ${taskResult.status}`);
}
}
}
}
}
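The poll-until-terminal pattern in `getResponse` above can be sketched generically (a hypothetical `pollUntilDone` helper; the status names and 500 ms interval mirror the diff, and `fetchStatus` stands in for the real HTTP call):

```typescript
// Poll a status endpoint until the task reaches a terminal state.
async function pollUntilDone<T>(
	fetchStatus: () => Promise<{ status: string; result?: T; error?: string }>,
	intervalMs = 500
): Promise<T> {
	while (true) {
		const s = await fetchStatus();
		switch (s.status) {
			case "completed":
				if (s.result === undefined) throw new Error("No output in completed response");
				return s.result;
			case "failed":
				throw new Error(s.error ?? "Task failed");
			case "processing":
			case "created":
				// Wait before polling again.
				await new Promise((r) => setTimeout(r, intervalMs));
				continue;
			default:
				throw new Error(`Unknown status: ${s.status}`);
		}
	}
}
```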

export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}
}

export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}
}

export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper {
constructor() {
super(WAVESPEEDAI_API_BASE_URL);
}

async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
return {
...args,
inputs: args.parameters?.prompt,
image: base64FromBytes(
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
),
};
}
}
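The Blob/ArrayBuffer-to-base64 conversion in `preparePayloadAsync` can be sketched as follows (Node-flavoured; the diff's `base64FromBytes` helper is assumed to behave like Buffer's standard base64 encoding):

```typescript
// Normalise the image input to raw bytes, then base64-encode it.
async function imageToBase64(input: ArrayBuffer | Blob): Promise<string> {
	const bytes = new Uint8Array(
		input instanceof ArrayBuffer ? input : await input.arrayBuffer()
	);
	return Buffer.from(bytes).toString("base64");
}
```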
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -55,6 +55,7 @@ export const INFERENCE_PROVIDERS = [
"replicate",
"sambanova",
"together",
"wavespeed-ai",
] as const;

export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
109 changes: 109 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -2023,4 +2023,113 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT
);
describe.concurrent(
"Wavespeed AI",
() => {
const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? "dummy");

HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed-ai"] = {
"wavespeed-ai/flux-schnell": {
hfModelId: "wavespeed-ai/flux-schnell",
providerId: "wavespeed-ai/flux-schnell",
status: "live",
task: "text-to-image",
},
"wavespeed-ai/wan-2.1/t2v-480p": {
hfModelId: "wavespeed-ai/wan-2.1/t2v-480p",
providerId: "wavespeed-ai/wan-2.1/t2v-480p",
status: "live",
task: "text-to-video",
},
"wavespeed-ai/hidream-e1-full": {
hfModelId: "wavespeed-ai/hidream-e1-full",
providerId: "wavespeed-ai/hidream-e1-full",
status: "live",
task: "image-to-image",
},
"wavespeed-ai/flux-dev-lora": {
hfModelId: "wavespeed-ai/flux-dev-lora",
providerId: "wavespeed-ai/flux-dev-lora",
status: "live",
task: "text-to-image",
adapter: "lora",
adapterWeightsPath:
"https://d32s1zkpjdc4b1.cloudfront.net/predictions/599f3739f5354afc8a76a12042736bfd/1.safetensors",
},
"wavespeed-ai/flux-dev-lora-ultra-fast": {
hfModelId: "wavespeed-ai/flux-dev-lora-ultra-fast",
providerId: "wavespeed-ai/flux-dev-lora-ultra-fast",
status: "live",
task: "text-to-image",
adapter: "lora",
adapterWeightsPath: "linoyts/yarn_art_Flux_LoRA",
},
};

it(`textToImage - wavespeed-ai/flux-schnell`, async () => {
const res = await client.textToImage({
model: "wavespeed-ai/flux-schnell",
provider: "wavespeed-ai",
inputs:
"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
});
expect(res).toBeInstanceOf(Blob);
});

it(`textToImage - wavespeed-ai/flux-dev-lora`, async () => {
const res = await client.textToImage({
model: "wavespeed-ai/flux-dev-lora",
provider: "wavespeed-ai",
inputs:
"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
});
expect(res).toBeInstanceOf(Blob);
});

it(`textToImage - wavespeed-ai/flux-dev-lora-ultra-fast`, async () => {
const res = await client.textToImage({
model: "wavespeed-ai/flux-dev-lora-ultra-fast",
provider: "wavespeed-ai",
inputs:
"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
});
expect(res).toBeInstanceOf(Blob);
});

it(`textToVideo - wavespeed-ai/wan-2.1/t2v-480p`, async () => {
const res = await client.textToVideo({
model: "wavespeed-ai/wan-2.1/t2v-480p",
provider: "wavespeed-ai",
inputs:
"A cool street dancer, wearing a baggy hoodie and hip-hop pants, dancing in front of a graffiti wall, night neon background, quick camera cuts, urban trends.",
parameters: {
guidance_scale: 5,
num_inference_steps: 30,
seed: -1,
},
duration: 5,
enable_safety_checker: true,
flow_shift: 2.9,
size: "480*832",
});
expect(res).toBeInstanceOf(Blob);
});

it(`imageToImage - wavespeed-ai/hidream-e1-full`, async () => {
const res = await client.imageToImage({
model: "wavespeed-ai/hidream-e1-full",
provider: "wavespeed-ai",
inputs: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
parameters: {
prompt: "The leopard chases its prey",
guidance_scale: 5,
num_inference_steps: 30,
seed: -1,
},
});
expect(res).toBeInstanceOf(Blob);
});
},
60000 * 5
);
});