Skip to content

Commit 8d6e815

Browse files
authored
[Inference] support all tasks for the auto policy (#1458)
For the `auto` policy, `hf-inference` doesn't support `text-to-video` so we should bypass the mapping see the missing snippet in https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B?language=python&inference_api=true&inference_provider=auto
1 parent b89cc94 commit 8d6e815

File tree

2 files changed

+41
-38
lines changed

2 files changed

+41
-38
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import type {
4747
import * as Replicate from "../providers/replicate";
4848
import * as Sambanova from "../providers/sambanova";
4949
import * as Together from "../providers/together";
50-
import type { InferenceProvider, InferenceTask } from "../types";
50+
import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types";
5151

5252
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
5353
"black-forest-labs": {
@@ -152,128 +152,132 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
152152
* Get provider helper instance by name and task
153153
*/
154154
export function getProviderHelper(
155-
provider: InferenceProvider,
155+
provider: InferenceProviderOrPolicy,
156156
task: "text-to-image"
157157
): TextToImageTaskHelper & TaskProviderHelper;
158158
export function getProviderHelper(
159-
provider: InferenceProvider,
159+
provider: InferenceProviderOrPolicy,
160160
task: "conversational"
161161
): ConversationalTaskHelper & TaskProviderHelper;
162162
export function getProviderHelper(
163-
provider: InferenceProvider,
163+
provider: InferenceProviderOrPolicy,
164164
task: "text-generation"
165165
): TextGenerationTaskHelper & TaskProviderHelper;
166166
export function getProviderHelper(
167-
provider: InferenceProvider,
167+
provider: InferenceProviderOrPolicy,
168168
task: "text-to-speech"
169169
): TextToSpeechTaskHelper & TaskProviderHelper;
170170
export function getProviderHelper(
171-
provider: InferenceProvider,
171+
provider: InferenceProviderOrPolicy,
172172
task: "text-to-audio"
173173
): TextToAudioTaskHelper & TaskProviderHelper;
174174
export function getProviderHelper(
175-
provider: InferenceProvider,
175+
provider: InferenceProviderOrPolicy,
176176
task: "automatic-speech-recognition"
177177
): AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper;
178178
export function getProviderHelper(
179-
provider: InferenceProvider,
179+
provider: InferenceProviderOrPolicy,
180180
task: "text-to-video"
181181
): TextToVideoTaskHelper & TaskProviderHelper;
182182
export function getProviderHelper(
183-
provider: InferenceProvider,
183+
provider: InferenceProviderOrPolicy,
184184
task: "text-classification"
185185
): TextClassificationTaskHelper & TaskProviderHelper;
186186
export function getProviderHelper(
187-
provider: InferenceProvider,
187+
provider: InferenceProviderOrPolicy,
188188
task: "question-answering"
189189
): QuestionAnsweringTaskHelper & TaskProviderHelper;
190190
export function getProviderHelper(
191-
provider: InferenceProvider,
191+
provider: InferenceProviderOrPolicy,
192192
task: "audio-classification"
193193
): AudioClassificationTaskHelper & TaskProviderHelper;
194194
export function getProviderHelper(
195-
provider: InferenceProvider,
195+
provider: InferenceProviderOrPolicy,
196196
task: "audio-to-audio"
197197
): AudioToAudioTaskHelper & TaskProviderHelper;
198198
export function getProviderHelper(
199-
provider: InferenceProvider,
199+
provider: InferenceProviderOrPolicy,
200200
task: "fill-mask"
201201
): FillMaskTaskHelper & TaskProviderHelper;
202202
export function getProviderHelper(
203-
provider: InferenceProvider,
203+
provider: InferenceProviderOrPolicy,
204204
task: "feature-extraction"
205205
): FeatureExtractionTaskHelper & TaskProviderHelper;
206206
export function getProviderHelper(
207-
provider: InferenceProvider,
207+
provider: InferenceProviderOrPolicy,
208208
task: "image-classification"
209209
): ImageClassificationTaskHelper & TaskProviderHelper;
210210
export function getProviderHelper(
211-
provider: InferenceProvider,
211+
provider: InferenceProviderOrPolicy,
212212
task: "image-segmentation"
213213
): ImageSegmentationTaskHelper & TaskProviderHelper;
214214
export function getProviderHelper(
215-
provider: InferenceProvider,
215+
provider: InferenceProviderOrPolicy,
216216
task: "document-question-answering"
217217
): DocumentQuestionAnsweringTaskHelper & TaskProviderHelper;
218218
export function getProviderHelper(
219-
provider: InferenceProvider,
219+
provider: InferenceProviderOrPolicy,
220220
task: "image-to-text"
221221
): ImageToTextTaskHelper & TaskProviderHelper;
222222
export function getProviderHelper(
223-
provider: InferenceProvider,
223+
provider: InferenceProviderOrPolicy,
224224
task: "object-detection"
225225
): ObjectDetectionTaskHelper & TaskProviderHelper;
226226
export function getProviderHelper(
227-
provider: InferenceProvider,
227+
provider: InferenceProviderOrPolicy,
228228
task: "zero-shot-image-classification"
229229
): ZeroShotImageClassificationTaskHelper & TaskProviderHelper;
230230
export function getProviderHelper(
231-
provider: InferenceProvider,
231+
provider: InferenceProviderOrPolicy,
232232
task: "zero-shot-classification"
233233
): ZeroShotClassificationTaskHelper & TaskProviderHelper;
234234
export function getProviderHelper(
235-
provider: InferenceProvider,
235+
provider: InferenceProviderOrPolicy,
236236
task: "image-to-image"
237237
): ImageToImageTaskHelper & TaskProviderHelper;
238238
export function getProviderHelper(
239-
provider: InferenceProvider,
239+
provider: InferenceProviderOrPolicy,
240240
task: "sentence-similarity"
241241
): SentenceSimilarityTaskHelper & TaskProviderHelper;
242242
export function getProviderHelper(
243-
provider: InferenceProvider,
243+
provider: InferenceProviderOrPolicy,
244244
task: "table-question-answering"
245245
): TableQuestionAnsweringTaskHelper & TaskProviderHelper;
246246
export function getProviderHelper(
247-
provider: InferenceProvider,
247+
provider: InferenceProviderOrPolicy,
248248
task: "tabular-classification"
249249
): TabularClassificationTaskHelper & TaskProviderHelper;
250250
export function getProviderHelper(
251-
provider: InferenceProvider,
251+
provider: InferenceProviderOrPolicy,
252252
task: "tabular-regression"
253253
): TabularRegressionTaskHelper & TaskProviderHelper;
254254
export function getProviderHelper(
255-
provider: InferenceProvider,
255+
provider: InferenceProviderOrPolicy,
256256
task: "token-classification"
257257
): TokenClassificationTaskHelper & TaskProviderHelper;
258258
export function getProviderHelper(
259-
provider: InferenceProvider,
259+
provider: InferenceProviderOrPolicy,
260260
task: "translation"
261261
): TranslationTaskHelper & TaskProviderHelper;
262262
export function getProviderHelper(
263-
provider: InferenceProvider,
263+
provider: InferenceProviderOrPolicy,
264264
task: "summarization"
265265
): SummarizationTaskHelper & TaskProviderHelper;
266266
export function getProviderHelper(
267-
provider: InferenceProvider,
267+
provider: InferenceProviderOrPolicy,
268268
task: "visual-question-answering"
269269
): VisualQuestionAnsweringTaskHelper & TaskProviderHelper;
270-
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper;
270+
export function getProviderHelper(
271+
provider: InferenceProviderOrPolicy,
272+
task: InferenceTask | undefined
273+
): TaskProviderHelper;
271274

272-
export function getProviderHelper(provider: InferenceProvider, task: InferenceTask | undefined): TaskProviderHelper {
273-
if (provider === "hf-inference") {
274-
if (!task) {
275-
return new HFInference.HFInferenceTask();
276-
}
275+
export function getProviderHelper(
276+
provider: InferenceProviderOrPolicy,
277+
task: InferenceTask | undefined
278+
): TaskProviderHelper {
279+
if ((provider === "hf-inference" && !task) || provider === "auto") {
280+
return new HFInference.HFInferenceTask();
277281
}
278282
if (!task) {
279283
throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
144144
}
145145
let providerHelper: ReturnType<typeof getProviderHelper>;
146146
try {
147-
/// For the "auto" provider policy we use hf-inference snippets
148-
providerHelper = getProviderHelper(provider === "auto" ? "hf-inference" : provider, task);
147+
providerHelper = getProviderHelper(provider, task);
149148
} catch (e) {
150149
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
151150
return [];

0 commit comments

Comments
 (0)