@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
44import { getModelInputSnippet } from "./inputs.js" ;
55import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
66
7+ const snippetImportInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
8+ `from huggingface_hub import InferenceClient
9+ client = InferenceClient("${ model . id } ", token="${ accessToken || "{API_TOKEN}" } ")
10+ ` ;
11+
712export const snippetConversational = (
813 model : ModelDataMinimal ,
914 accessToken : string ,
@@ -161,18 +166,28 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
161166output = query(${ getModelInputSnippet ( model ) } )` ,
162167} ) ;
163168
164- export const snippetTextToImage = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
165- content : `def query(payload):
169+ export const snippetTextToImage = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => [
170+ {
171+ client : "huggingface_hub" ,
172+ content : `${ snippetImportInferenceClient ( model , accessToken ) }
173+ # output is a PIL.Image object
174+ image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ,
175+ } ,
176+ {
177+ client : "requests" ,
178+ content : `def query(payload):
166179 response = requests.post(API_URL, headers=headers, json=payload)
167180 return response.content
168181image_bytes = query({
169182 "inputs": ${ getModelInputSnippet ( model ) } ,
170183})
184+
171185# You can access the image with PIL.Image for example
172186import io
173187from PIL import Image
174188image = Image.open(io.BytesIO(image_bytes))` ,
175- } ) ;
189+ } ,
190+ ] ;
176191
177192export const snippetTabular = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
178193 content : `def query(payload):
@@ -288,12 +303,14 @@ export function getPythonInferenceSnippet(
288303 return snippets . map ( ( snippet ) => {
289304 return {
290305 ...snippet ,
291- content : `import requests
306+ content : snippet . content . includes ( "requests" )
307+ ? `import requests
292308
293309API_URL = "https://api-inference.huggingface.co/models/${ model . id } "
294310headers = {"Authorization": ${ accessToken ? `"Bearer ${ accessToken } "` : `f"Bearer {API_TOKEN}"` } }
295311
296- ${ snippet . content } `,
312+ ${ snippet . content } `
313+ : snippet . content ,
297314 } ;
298315 } ) ;
299316 }
0 commit comments