@@ -47,7 +47,7 @@ import type {
47
47
import * as Replicate from "../providers/replicate" ;
48
48
import * as Sambanova from "../providers/sambanova" ;
49
49
import * as Together from "../providers/together" ;
50
- import type { InferenceProvider , InferenceTask } from "../types" ;
50
+ import type { InferenceProvider , InferenceProviderOrPolicy , InferenceTask } from "../types" ;
51
51
52
52
export const PROVIDERS : Record < InferenceProvider , Partial < Record < InferenceTask , TaskProviderHelper > > > = {
53
53
"black-forest-labs" : {
@@ -152,128 +152,132 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
152
152
* Get provider helper instance by name and task
153
153
*/
154
154
export function getProviderHelper (
155
- provider : InferenceProvider ,
155
+ provider : InferenceProviderOrPolicy ,
156
156
task : "text-to-image"
157
157
) : TextToImageTaskHelper & TaskProviderHelper ;
158
158
export function getProviderHelper (
159
- provider : InferenceProvider ,
159
+ provider : InferenceProviderOrPolicy ,
160
160
task : "conversational"
161
161
) : ConversationalTaskHelper & TaskProviderHelper ;
162
162
export function getProviderHelper (
163
- provider : InferenceProvider ,
163
+ provider : InferenceProviderOrPolicy ,
164
164
task : "text-generation"
165
165
) : TextGenerationTaskHelper & TaskProviderHelper ;
166
166
export function getProviderHelper (
167
- provider : InferenceProvider ,
167
+ provider : InferenceProviderOrPolicy ,
168
168
task : "text-to-speech"
169
169
) : TextToSpeechTaskHelper & TaskProviderHelper ;
170
170
export function getProviderHelper (
171
- provider : InferenceProvider ,
171
+ provider : InferenceProviderOrPolicy ,
172
172
task : "text-to-audio"
173
173
) : TextToAudioTaskHelper & TaskProviderHelper ;
174
174
export function getProviderHelper (
175
- provider : InferenceProvider ,
175
+ provider : InferenceProviderOrPolicy ,
176
176
task : "automatic-speech-recognition"
177
177
) : AutomaticSpeechRecognitionTaskHelper & TaskProviderHelper ;
178
178
export function getProviderHelper (
179
- provider : InferenceProvider ,
179
+ provider : InferenceProviderOrPolicy ,
180
180
task : "text-to-video"
181
181
) : TextToVideoTaskHelper & TaskProviderHelper ;
182
182
export function getProviderHelper (
183
- provider : InferenceProvider ,
183
+ provider : InferenceProviderOrPolicy ,
184
184
task : "text-classification"
185
185
) : TextClassificationTaskHelper & TaskProviderHelper ;
186
186
export function getProviderHelper (
187
- provider : InferenceProvider ,
187
+ provider : InferenceProviderOrPolicy ,
188
188
task : "question-answering"
189
189
) : QuestionAnsweringTaskHelper & TaskProviderHelper ;
190
190
export function getProviderHelper (
191
- provider : InferenceProvider ,
191
+ provider : InferenceProviderOrPolicy ,
192
192
task : "audio-classification"
193
193
) : AudioClassificationTaskHelper & TaskProviderHelper ;
194
194
export function getProviderHelper (
195
- provider : InferenceProvider ,
195
+ provider : InferenceProviderOrPolicy ,
196
196
task : "audio-to-audio"
197
197
) : AudioToAudioTaskHelper & TaskProviderHelper ;
198
198
export function getProviderHelper (
199
- provider : InferenceProvider ,
199
+ provider : InferenceProviderOrPolicy ,
200
200
task : "fill-mask"
201
201
) : FillMaskTaskHelper & TaskProviderHelper ;
202
202
export function getProviderHelper (
203
- provider : InferenceProvider ,
203
+ provider : InferenceProviderOrPolicy ,
204
204
task : "feature-extraction"
205
205
) : FeatureExtractionTaskHelper & TaskProviderHelper ;
206
206
export function getProviderHelper (
207
- provider : InferenceProvider ,
207
+ provider : InferenceProviderOrPolicy ,
208
208
task : "image-classification"
209
209
) : ImageClassificationTaskHelper & TaskProviderHelper ;
210
210
export function getProviderHelper (
211
- provider : InferenceProvider ,
211
+ provider : InferenceProviderOrPolicy ,
212
212
task : "image-segmentation"
213
213
) : ImageSegmentationTaskHelper & TaskProviderHelper ;
214
214
export function getProviderHelper (
215
- provider : InferenceProvider ,
215
+ provider : InferenceProviderOrPolicy ,
216
216
task : "document-question-answering"
217
217
) : DocumentQuestionAnsweringTaskHelper & TaskProviderHelper ;
218
218
export function getProviderHelper (
219
- provider : InferenceProvider ,
219
+ provider : InferenceProviderOrPolicy ,
220
220
task : "image-to-text"
221
221
) : ImageToTextTaskHelper & TaskProviderHelper ;
222
222
export function getProviderHelper (
223
- provider : InferenceProvider ,
223
+ provider : InferenceProviderOrPolicy ,
224
224
task : "object-detection"
225
225
) : ObjectDetectionTaskHelper & TaskProviderHelper ;
226
226
export function getProviderHelper (
227
- provider : InferenceProvider ,
227
+ provider : InferenceProviderOrPolicy ,
228
228
task : "zero-shot-image-classification"
229
229
) : ZeroShotImageClassificationTaskHelper & TaskProviderHelper ;
230
230
export function getProviderHelper (
231
- provider : InferenceProvider ,
231
+ provider : InferenceProviderOrPolicy ,
232
232
task : "zero-shot-classification"
233
233
) : ZeroShotClassificationTaskHelper & TaskProviderHelper ;
234
234
export function getProviderHelper (
235
- provider : InferenceProvider ,
235
+ provider : InferenceProviderOrPolicy ,
236
236
task : "image-to-image"
237
237
) : ImageToImageTaskHelper & TaskProviderHelper ;
238
238
export function getProviderHelper (
239
- provider : InferenceProvider ,
239
+ provider : InferenceProviderOrPolicy ,
240
240
task : "sentence-similarity"
241
241
) : SentenceSimilarityTaskHelper & TaskProviderHelper ;
242
242
export function getProviderHelper (
243
- provider : InferenceProvider ,
243
+ provider : InferenceProviderOrPolicy ,
244
244
task : "table-question-answering"
245
245
) : TableQuestionAnsweringTaskHelper & TaskProviderHelper ;
246
246
export function getProviderHelper (
247
- provider : InferenceProvider ,
247
+ provider : InferenceProviderOrPolicy ,
248
248
task : "tabular-classification"
249
249
) : TabularClassificationTaskHelper & TaskProviderHelper ;
250
250
export function getProviderHelper (
251
- provider : InferenceProvider ,
251
+ provider : InferenceProviderOrPolicy ,
252
252
task : "tabular-regression"
253
253
) : TabularRegressionTaskHelper & TaskProviderHelper ;
254
254
export function getProviderHelper (
255
- provider : InferenceProvider ,
255
+ provider : InferenceProviderOrPolicy ,
256
256
task : "token-classification"
257
257
) : TokenClassificationTaskHelper & TaskProviderHelper ;
258
258
export function getProviderHelper (
259
- provider : InferenceProvider ,
259
+ provider : InferenceProviderOrPolicy ,
260
260
task : "translation"
261
261
) : TranslationTaskHelper & TaskProviderHelper ;
262
262
export function getProviderHelper (
263
- provider : InferenceProvider ,
263
+ provider : InferenceProviderOrPolicy ,
264
264
task : "summarization"
265
265
) : SummarizationTaskHelper & TaskProviderHelper ;
266
266
export function getProviderHelper (
267
- provider : InferenceProvider ,
267
+ provider : InferenceProviderOrPolicy ,
268
268
task : "visual-question-answering"
269
269
) : VisualQuestionAnsweringTaskHelper & TaskProviderHelper ;
270
- export function getProviderHelper ( provider : InferenceProvider , task : InferenceTask | undefined ) : TaskProviderHelper ;
270
+ export function getProviderHelper (
271
+ provider : InferenceProviderOrPolicy ,
272
+ task : InferenceTask | undefined
273
+ ) : TaskProviderHelper ;
271
274
272
- export function getProviderHelper ( provider : InferenceProvider , task : InferenceTask | undefined ) : TaskProviderHelper {
273
- if ( provider === "hf-inference" ) {
274
- if ( ! task ) {
275
- return new HFInference . HFInferenceTask ( ) ;
276
- }
275
+ export function getProviderHelper (
276
+ provider : InferenceProviderOrPolicy ,
277
+ task : InferenceTask | undefined
278
+ ) : TaskProviderHelper {
279
+ if ( ( provider === "hf-inference" && ! task ) || provider === "auto" ) {
280
+ return new HFInference . HFInferenceTask ( ) ;
277
281
}
278
282
if ( ! task ) {
279
283
throw new Error ( "you need to provide a task name when using an external provider, e.g. 'text-to-image'" ) ;
0 commit comments