-
Notifications
You must be signed in to change notification settings - Fork 386
[inference provider] Add wavespeed.ai as an inference provider #1424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
arabot777
wants to merge
14
commits into
huggingface:main
Choose a base branch
from
arabot777:feat/wavespeedai
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
a4d8504
add wavespeed.ai as an inference provider
arabot777 686931e
delete debug log
arabot777 0e71b88
Merge branch 'main' into feat/wavespeedai
arabot777 4461225
Merge branch 'main' into feat/wavespeedai
arabot777 e0bf580
Merge branch 'main' into feat/wavespeedai
arabot777 07af35f
Merge branch 'main' into feat/wavespeedai
arabot777 fa3afa4
Merge branch 'main' into feat/wavespeedai
arabot777 214ff99
support lora
arabot777 47c64c6
code review
arabot777 7270c5c
Merge branch 'main' into feat/wavespeedai
arabot777 ba35791
code review
arabot777 ca35eab
Merge branch 'main' into feat/wavespeedai
arabot777 80d4640
delete unused import
arabot777 77be0c6
Merge branch 'main' into feat/wavespeedai
arabot777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import { InferenceOutputError } from "../lib/InferenceOutputError"; | ||
import type { ImageToImageArgs } from "../tasks"; | ||
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types"; | ||
import { delay } from "../utils/delay"; | ||
import { omit } from "../utils/omit"; | ||
import { base64FromBytes } from "../utils/base64FromBytes"; | ||
import { | ||
TaskProviderHelper, | ||
TextToImageTaskHelper, | ||
TextToVideoTaskHelper, | ||
ImageToImageTaskHelper, | ||
} from "./providerHelper"; | ||
|
||
const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai"; | ||
|
||
/** | ||
* Common response structure for all WaveSpeed AI API responses | ||
*/ | ||
interface WaveSpeedAICommonResponse<T> { | ||
code: number; | ||
message: string; | ||
data: T; | ||
} | ||
|
||
/** | ||
* Response structure for task status and results | ||
*/ | ||
interface WaveSpeedAITaskResponse { | ||
id: string; | ||
model: string; | ||
outputs: string[]; | ||
urls: { | ||
get: string; | ||
}; | ||
has_nsfw_contents: boolean[]; | ||
status: "created" | "processing" | "completed" | "failed"; | ||
created_at: string; | ||
error: string; | ||
executionTime: number; | ||
timings: { | ||
inference: number; | ||
}; | ||
} | ||
|
||
/** | ||
* Response structure for initial task submission | ||
*/ | ||
interface WaveSpeedAISubmitResponse { | ||
id: string; | ||
urls: { | ||
get: string; | ||
}; | ||
} | ||
|
||
type WaveSpeedAIResponse<T = WaveSpeedAITaskResponse> = WaveSpeedAICommonResponse<T>; | ||
|
||
abstract class WavespeedAITask extends TaskProviderHelper { | ||
private accessToken: string | undefined; | ||
|
||
constructor(url?: string) { | ||
super("wavespeed-ai", url || WAVESPEEDAI_API_BASE_URL); | ||
} | ||
|
||
makeRoute(params: UrlParams): string { | ||
return `/api/v2/${params.model}`; | ||
} | ||
preparePayload(params: BodyParams): Record<string, unknown> { | ||
const payload: Record<string, unknown> = { | ||
...omit(params.args, ["inputs", "parameters"]), | ||
...(params.args.parameters as Record<string, unknown>), | ||
prompt: params.args.inputs, | ||
}; | ||
// Add LoRA support if adapter is specified in the mapping | ||
if (params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath) { | ||
payload.loras = [ | ||
{ | ||
path: params.mapping.adapterWeightsPath, | ||
scale: 1, // Default scale value | ||
}, | ||
]; | ||
} | ||
return payload; | ||
} | ||
|
||
override prepareHeaders(params: HeaderParams, isBinary: boolean): Record<string, string> { | ||
this.accessToken = params.accessToken; | ||
const headers: Record<string, string> = { Authorization: `Bearer ${params.accessToken}` }; | ||
if (!isBinary) { | ||
headers["Content-Type"] = "application/json"; | ||
} | ||
return headers; | ||
} | ||
|
||
override async getResponse( | ||
response: WaveSpeedAIResponse<WaveSpeedAISubmitResponse>, | ||
url?: string, | ||
headers?: Record<string, string> | ||
): Promise<Blob> { | ||
if (!headers && this.accessToken) { | ||
headers = { Authorization: `Bearer ${this.accessToken}` }; | ||
} | ||
if (!headers) { | ||
throw new InferenceOutputError("Headers are required for WaveSpeed AI API calls"); | ||
} | ||
|
||
const resultUrl = response.data.urls.get; | ||
|
||
// Poll for results until completion | ||
while (true) { | ||
const resultResponse = await fetch(resultUrl, { headers }); | ||
|
||
if (!resultResponse.ok) { | ||
throw new InferenceOutputError(`Failed to get result: ${resultResponse.statusText}`); | ||
} | ||
|
||
const result: WaveSpeedAIResponse = await resultResponse.json(); | ||
if (result.code !== 200) { | ||
throw new InferenceOutputError(`API request failed with code ${result.code}: ${result.message}`); | ||
} | ||
|
||
const taskResult = result.data; | ||
|
||
switch (taskResult.status) { | ||
case "completed": { | ||
// Get the media data from the first output URL | ||
if (!taskResult.outputs?.[0]) { | ||
throw new InferenceOutputError("No output URL in completed response"); | ||
} | ||
const mediaResponse = await fetch(taskResult.outputs[0]); | ||
if (!mediaResponse.ok) { | ||
throw new InferenceOutputError("Failed to fetch output data"); | ||
} | ||
return await mediaResponse.blob(); | ||
} | ||
case "failed": { | ||
throw new InferenceOutputError(taskResult.error || "Task failed"); | ||
} | ||
case "processing": | ||
case "created": | ||
// Wait before polling again | ||
await delay(500); | ||
continue; | ||
|
||
default: { | ||
throw new InferenceOutputError(`Unknown status: ${taskResult.status}`); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper { | ||
constructor() { | ||
super(WAVESPEEDAI_API_BASE_URL); | ||
} | ||
} | ||
|
||
export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper { | ||
constructor() { | ||
super(WAVESPEEDAI_API_BASE_URL); | ||
} | ||
} | ||
|
||
export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper { | ||
constructor() { | ||
super(WAVESPEEDAI_API_BASE_URL); | ||
} | ||
|
||
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> { | ||
return { | ||
...args, | ||
inputs: args.parameters?.prompt, | ||
image: base64FromBytes( | ||
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer()) | ||
), | ||
}; | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this type alias is needed, can we remove it?
WaveSpeedAICommonResponse
can be renamed toWaveSpeedAIResponse
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This type is needed and will be used in two places. It's uncertain whether it will be used again in the future.
It follows the DRY (Don't Repeat Yourself) principle
It provides better type safety (through default generic parameters)
It makes the code more readable and maintainable