diff --git a/packages/inference/README.md b/packages/inference/README.md
index ad4fcb879..21e46625b 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -53,7 +53,7 @@ Currently, we support the following providers:
 - [HF Inference](https://huggingface.co/docs/inference-providers/providers/hf-inference)
 - [Hyperbolic](https://hyperbolic.xyz)
 - [Nebius](https://studio.nebius.ai)
-- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
+- [Novita](https://novita.ai)
 - [Nscale](https://nscale.com)
 - [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
 - [Replicate](https://replicate.com)
@@ -96,6 +96,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
+- [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
 
 ❗**Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type. This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 98f1d075c..9afcb8980 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -120,6 +120,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<WidgetType, TaskProviderHelper>>> = {
 	novita: {
 		conversational: new Novita.NovitaConversationalTask(),
 		"text-generation": new Novita.NovitaTextGenerationTask(),
+		"text-to-video": new Novita.NovitaTextToVideoTask(),
 	},
diff --git a/packages/inference/src/providers/novita.ts b/packages/inference/src/providers/novita.ts
--- a/packages/inference/src/providers/novita.ts
+++ b/packages/inference/src/providers/novita.ts
-	override preparePayload(params: BodyParams): Record<string, unknown> {
+	override preparePayload(params: BodyParams<TextToVideoArgs>): Record<string, unknown> {
+		const { num_inference_steps, ...restParameters } = params.args.parameters ?? {};
 		return {
 			...omit(params.args, ["inputs", "parameters"]),
-			...(params.args.parameters as Record<string, unknown>),
+			...restParameters,
+			steps: num_inference_steps,
 			prompt: params.args.inputs,
 		};
 	}
-	override async getResponse(response: NovitaOutput): Promise<Blob> {
-		const isValidOutput =
-			typeof response === "object" &&
-			!!response &&
-			"video" in response &&
-			typeof response.video === "object" &&
-			!!response.video &&
-			"video_url" in response.video &&
-			typeof response.video.video_url === "string" &&
-			isUrl(response.video.video_url);
-		if (!isValidOutput) {
-			throw new InferenceOutputError("Expected { video: { video_url: string } }");
+	override async getResponse(
+		response: NovitaAsyncAPIOutput,
+		url?: string,
+		headers?: Record<string, string>
+	): Promise<Blob> {
+		if (!url || !headers) {
+			throw new InferenceOutputError("URL and headers are required for text-to-video task");
 		}
+		const taskId = response.task_id;
+		if (!taskId) {
+			throw new InferenceOutputError("No task ID found in the response");
+		}
+
+		const parsedUrl = new URL(url);
+		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
+			parsedUrl.host === "router.huggingface.co" ? "/novita" : ""
+		}`;
+		const resultUrl = `${baseUrl}/v3/async/task-result?task_id=${taskId}`;
+
+		let status = "";
+		let taskResult: unknown;
-		const urlResponse = await fetch(response.video.video_url);
-		return await urlResponse.blob();
+		while (status !== "TASK_STATUS_SUCCEED" && status !== "TASK_STATUS_FAILED") {
+			await delay(500);
+			const resultResponse = await fetch(resultUrl, { headers });
+			if (!resultResponse.ok) {
+				throw new InferenceOutputError("Failed to fetch task result");
+			}
+			try {
+				taskResult = await resultResponse.json();
+				if (
+					taskResult &&
+					typeof taskResult === "object" &&
+					"task" in taskResult &&
+					taskResult.task &&
+					typeof taskResult.task === "object" &&
+					"status" in taskResult.task &&
+					typeof taskResult.task.status === "string"
+				) {
+					status = taskResult.task.status;
+				} else {
+					throw new InferenceOutputError("Failed to get task status");
+				}
+			} catch (error) {
+				throw new InferenceOutputError("Failed to parse task result");
+			}
+		}
+
+		if (status === "TASK_STATUS_FAILED") {
+			throw new InferenceOutputError("Task failed");
+		}
+
+		if (
+			typeof taskResult === "object" &&
+			!!taskResult &&
+			"videos" in taskResult &&
+			typeof taskResult.videos === "object" &&
+			!!taskResult.videos &&
+			Array.isArray(taskResult.videos) &&
+			taskResult.videos.length > 0 &&
+			"video_url" in taskResult.videos[0] &&
+			typeof taskResult.videos[0].video_url === "string" &&
+			isUrl(taskResult.videos[0].video_url)
+		) {
+			const urlResponse = await fetch(taskResult.videos[0].video_url);
+			return await urlResponse.blob();
+		} else {
+			throw new InferenceOutputError("Expected { videos: [{ video_url: string }] }");
+		}
 	}
 }
diff --git a/packages/inference/src/tasks/cv/textToVideo.ts b/packages/inference/src/tasks/cv/textToVideo.ts
index 29c61b172..507eda811 100644
--- a/packages/inference/src/tasks/cv/textToVideo.ts
+++ b/packages/inference/src/tasks/cv/textToVideo.ts
@@ -3,7 +3,7 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../../lib/getProviderHelper.js";
 import { makeRequestOptions } from "../../lib/makeRequestOptions.js";
 import type { FalAiQueueOutput } from "../../providers/fal-ai.js";
-import type { NovitaOutput } from "../../providers/novita.js";
+import type { NovitaAsyncAPIOutput } from "../../providers/novita.js";
 import type { ReplicateOutput } from "../../providers/replicate.js";
 import type { BaseArgs, Options } from "../../types.js";
 import { innerRequest } from "../../utils/request.js";
@@ -15,7 +15,7 @@ export type TextToVideoOutput = Blob;
 export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
 	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
 	const providerHelper = getProviderHelper(provider, "text-to-video");
-	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
+	const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaAsyncAPIOutput>(
 		args,
 		providerHelper,
 		{