@@ -44,11 +44,11 @@ const snippetImportInferenceClient = (accessToken: string, provider: SnippetInfe
 from huggingface_hub import InferenceClient
 
 client = InferenceClient(
-    provider="${provider}",
-    api_key="${accessToken || "{API_TOKEN}"}"
+    provider="${provider}",
+    api_key="${accessToken || "{API_TOKEN}"}",
 )`;
 
-export const snippetConversational = (
+const snippetConversational = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider,
@@ -89,7 +89,7 @@ stream = client.chat.completions.create(
     model="${model.id}",
     messages=messages,
     ${configStr}
-    stream=True
+    stream=True,
 )
 
 for chunk in stream:
@@ -159,7 +159,7 @@ print(completion.choices[0].message)`,
 	}
 };
 
-export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
@@ -176,12 +176,11 @@ output = query({
 	];
 };
 
-export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
-			content: `\
-def query(data):
+			content: `def query(data):
     with open(data["image_path"], "rb") as f:
         img = f.read()
     payload={
@@ -199,7 +198,7 @@ output = query({
 	];
 };
 
-export const snippetBasic = (
+const snippetBasic = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider
@@ -213,9 +212,8 @@ export const snippetBasic = (
 ${snippetImportInferenceClient(accessToken, provider)}
 
 result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
-    model="${model.id}",
     inputs=${getModelInputSnippet(model)},
-    provider="${provider}",
+    model="${model.id}",
 )
 
 print(result)
@@ -237,7 +235,7 @@ output = query({
 	];
 };
 
-export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
@@ -253,7 +251,7 @@ output = query(${getModelInputSnippet(model)})`,
 	];
 };
 
-export const snippetTextToImage = (
+const snippetTextToImage = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider,
@@ -268,7 +266,7 @@ ${snippetImportInferenceClient(accessToken, provider)}
 # output is a PIL.Image object
 image = client.text_to_image(
     ${getModelInputSnippet(model)},
-    model="${model.id}"
+    model="${model.id}",
 )`,
 		},
 		...(provider === "fal-ai"
@@ -312,7 +310,7 @@ image = Image.open(io.BytesIO(image_bytes))`,
 	];
 };
 
-export const snippetTextToVideo = (
+const snippetTextToVideo = (
 	model: ModelDataMinimal,
 	accessToken: string,
 	provider: SnippetInferenceProvider
@@ -326,14 +324,14 @@ ${snippetImportInferenceClient(accessToken, provider)}
 
 video = client.text_to_video(
     ${getModelInputSnippet(model)},
-    model="${model.id}"
+    model="${model.id}",
 )`,
 				},
 		  ]
 		: [];
 };
 
-export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
 	return [
 		{
 			client: "requests",
@@ -349,7 +347,7 @@ response = query({
 	];
 };
 
-export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
 	// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs are diverged
 	// with the latest update to inference-api (IA).
 	// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
@@ -374,8 +372,7 @@ Audio(audio_bytes)`,
 	return [
 		{
 			client: "requests",
-			content: `\
-def query(payload):
+			content: `def query(payload):
     response = requests.post(API_URL, headers=headers, json=payload)
     return response.json()
 
@@ -390,26 +387,97 @@ Audio(audio, rate=sampling_rate)`,
 	}
 };
 
-export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
+const snippetAutomaticSpeechRecognition = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+output = client.automatic_speech_recognition(${getModelInputSnippet(model)}, model="${model.id}")`,
+		},
+		snippetFile(model)[0],
+	];
+};
+
+const snippetDocumentQuestionAnswering = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	const inputsAsStr = getModelInputSnippet(model) as string;
+	const inputsAsObj = JSON.parse(inputsAsStr);
+
 	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+output = client.document_question_answering(
+    "${inputsAsObj.image}",
+    question="${inputsAsObj.question}",
+    model="${model.id}",
+)`,
+		},
 		{
 			client: "requests",
-			content: `\
-def query(payload):
+			content: `def query(payload):
     with open(payload["image"], "rb") as f:
         img = f.read()
-    payload["image"] = base64.b64encode(img).decode("utf-8")
+    payload["image"] = base64.b64encode(img).decode("utf-8")
     response = requests.post(API_URL, headers=headers, json=payload)
     return response.json()
 
 output = query({
-    "inputs": ${getModelInputSnippet(model)},
+    "inputs": ${inputsAsStr},
 })`,
 		},
 	];
 };
 
-export const pythonSnippets: Partial<
+const snippetImageToImage = (
+	model: ModelDataMinimal,
+	accessToken: string,
+	provider: SnippetInferenceProvider
+): InferenceSnippet[] => {
+	const inputsAsStr = getModelInputSnippet(model) as string;
+	const inputsAsObj = JSON.parse(inputsAsStr);
+
+	return [
+		{
+			client: "huggingface_hub",
+			content: `${snippetImportInferenceClient(accessToken, provider)}
+# output is a PIL.Image object
+image = client.image_to_image(
+    "${inputsAsObj.image}",
+    prompt="${inputsAsObj.prompt}",
+    model="${model.id}",
+)`,
+		},
+		{
+			client: "requests",
+			content: `def query(payload):
+    with open(payload["inputs"], "rb") as f:
+        img = f.read()
+    payload["inputs"] = base64.b64encode(img).decode("utf-8")
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.content
+
+image_bytes = query({
+    "inputs": "${inputsAsObj.image}",
+    "parameters": {"prompt": "${inputsAsObj.prompt}"},
+})
+
+# You can access the image with PIL.Image for example
+import io
+from PIL import Image
+image = Image.open(io.BytesIO(image_bytes))`,
+		},
+	];
+};
+
+const pythonSnippets: Partial<
 	Record<
 		PipelineType,
 		(
@@ -435,7 +503,7 @@ export const pythonSnippets: Partial<
 	"image-text-to-text": snippetConversational,
 	"fill-mask": snippetBasic,
 	"sentence-similarity": snippetBasic,
-	"automatic-speech-recognition": snippetFile,
+	"automatic-speech-recognition": snippetAutomaticSpeechRecognition,
 	"text-to-image": snippetTextToImage,
 	"text-to-video": snippetTextToVideo,
 	"text-to-speech": snippetTextToAudio,
@@ -449,6 +517,7 @@ export const pythonSnippets: Partial<
 	"image-segmentation": snippetFile,
 	"document-question-answering": snippetDocumentQuestionAnswering,
 	"image-to-text": snippetFile,
+	"image-to-image": snippetImageToImage,
 	"zero-shot-image-classification": snippetZeroShotImageClassification,
 };
 
@@ -471,17 +540,24 @@ export function getPythonInferenceSnippet(
 	return snippets.map((snippet) => {
 		return {
 			...snippet,
-			content:
-				snippet.client === "requests"
-					? `\
-import requests
-
-API_URL = "${openAIbaseUrl(provider)}"
-headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
-
-${snippet.content}`
-					: snippet.content,
+			content: addImportsToSnippet(snippet.content, model, accessToken),
 		};
 	});
 }
 }
+
+const addImportsToSnippet = (snippet: string, model: ModelDataMinimal, accessToken: string): string => {
+	if (snippet.includes("requests")) {
+		snippet = `import requests
+
+API_URL = "https://router.huggingface.co/hf-inference/models/${model.id}"
+headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
+
+${snippet}`;
+	}
+	if (snippet.includes("base64")) {
+		snippet = `import base64
+${snippet}`;
+	}
+	return snippet;
+};
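
For illustration only, outside the diff itself: after this change, a `requests`-type snippet for an image-to-image model would render roughly as the Python below once `addImportsToSnippet` has wrapped it. Note that `import base64` ends up on top because that branch prepends after the `requests` block has already been added. The model id, token, image file name, and prompt here are hypothetical placeholders, not values produced by the diff.

import base64
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/<model-id>"  # hypothetical model id
headers = {"Authorization": "Bearer hf_xxx"}  # hypothetical token

def query(payload):
    # read the input image from disk and inline it as base64 in the JSON payload
    with open(payload["inputs"], "rb") as f:
        img = f.read()
    payload["inputs"] = base64.b64encode(img).decode("utf-8")
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

image_bytes = query({
    "inputs": "cat.png",  # hypothetical input image
    "parameters": {"prompt": "Turn the cat into a tiger."},  # hypothetical prompt
})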