@@ -4,11 +4,35 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
44import { getModelInputSnippet } from "./inputs.js" ;
55import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
66
7+ // Snippets shared between tasks
8+
79const snippetImportInferenceClient = ( model : ModelDataMinimal , accessToken : string ) : string =>
810 `from huggingface_hub import InferenceClient
911client = InferenceClient("${ model . id } ", token="${ accessToken || "{API_TOKEN}" } ")
1012` ;
1113
14+ const snippetBasic = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
15+ content : `def query(payload):
16+ response = requests.post(API_URL, headers=headers, json=payload)
17+ return response.json()
18+
19+ output = query({
20+ "inputs": ${ getModelInputSnippet ( model ) } ,
21+ })` ,
22+ } ) ;
23+
24+ const snippetFile = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
25+ content : `def query(filename):
26+ with open(filename, "rb") as f:
27+ data = f.read()
28+ response = requests.post(API_URL, headers=headers, data=data)
29+ return response.json()
30+
31+ output = query(${ getModelInputSnippet ( model ) } )` ,
32+ } ) ;
33+
34+ // Specific snippets
35+
1236const snippetConversational = (
1337 model : ModelDataMinimal ,
1438 accessToken : string ,
@@ -118,76 +142,31 @@ print(completion.choices[0].message)`,
118142 }
119143} ;
120144
121- const snippetZeroShotClassification = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
122- content : `def query(payload):
123- response = requests.post(API_URL, headers=headers, json=payload)
124- return response.json()
125-
126- output = query({
127- "inputs": ${ getModelInputSnippet ( model ) } ,
128- "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
129- })` ,
130- } ) ;
145+ const snippetDocumentQuestionAnswering = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => {
146+ const inputsAsStr = getModelInputSnippet ( model ) as string ;
147+ const inputsAsObj = JSON . parse ( inputsAsStr ) ;
131148
132- const snippetZeroShotImageClassification = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
133- content : `def query(data):
134- with open(data["image_path"], "rb") as f:
149+ return [
150+ {
151+ client : "huggingface_hub" ,
152+ content : `${ snippetImportInferenceClient ( model , accessToken ) }
153+ output = client.document_question_answering(${ inputsAsObj . image } , question=${ inputsAsObj . question } )` ,
154+ } ,
155+ {
156+ client : "requests" ,
157+ content : `def query(payload):
158+ with open(payload["image"], "rb") as f:
135159 img = f.read()
136- payload={
137- "parameters": data["parameters"],
138- "inputs": base64.b64encode(img).decode("utf-8")
139- }
160+ payload["image"] = base64.b64encode(img).decode("utf-8")
140161 response = requests.post(API_URL, headers=headers, json=payload)
141162 return response.json()
142163
143164output = query({
144- "image_path": ${ getModelInputSnippet ( model ) } ,
145- "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
146- })` ,
147- } ) ;
148-
149- const snippetBasic = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
150- content : `def query(payload):
151- response = requests.post(API_URL, headers=headers, json=payload)
152- return response.json()
153-
154- output = query({
155- "inputs": ${ getModelInputSnippet ( model ) } ,
165+ "inputs": ${ inputsAsStr } ,
156166})` ,
157- } ) ;
158-
159- const snippetFile = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
160- content : `def query(filename):
161- with open(filename, "rb") as f:
162- data = f.read()
163- response = requests.post(API_URL, headers=headers, data=data)
164- return response.json()
165-
166- output = query(${ getModelInputSnippet ( model ) } )` ,
167- } ) ;
168-
169- const snippetTextToImage = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => [
170- {
171- client : "huggingface_hub" ,
172- content : `${ snippetImportInferenceClient ( model , accessToken ) }
173- # output is a PIL.Image object
174- image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ,
175- } ,
176- {
177- client : "requests" ,
178- content : `def query(payload):
179- response = requests.post(API_URL, headers=headers, json=payload)
180- return response.content
181- image_bytes = query({
182- "inputs": ${ getModelInputSnippet ( model ) } ,
183- })
184-
185- # You can access the image with PIL.Image for example
186- import io
187- from PIL import Image
188- image = Image.open(io.BytesIO(image_bytes))` ,
189- } ,
190- ] ;
167+ } ,
168+ ] ;
169+ } ;
191170
192171const snippetTabular = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
193172 content : `def query(payload):
@@ -231,31 +210,56 @@ Audio(audio, rate=sampling_rate)`,
231210 }
232211} ;
233212
234- const snippetDocumentQuestionAnswering = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => {
235- const inputsAsStr = getModelInputSnippet ( model ) as string ;
236- const inputsAsObj = JSON . parse ( inputsAsStr ) ;
213+ const snippetTextToImage = ( model : ModelDataMinimal , accessToken : string ) : InferenceSnippet [ ] => [
214+ {
215+ client : "huggingface_hub" ,
216+ content : `${ snippetImportInferenceClient ( model , accessToken ) }
217+ # output is a PIL.Image object
218+ image = client.text_to_image(${ getModelInputSnippet ( model ) } )` ,
219+ } ,
220+ {
221+ client : "requests" ,
222+ content : `def query(payload):
223+ response = requests.post(API_URL, headers=headers, json=payload)
224+ return response.content
225+ image_bytes = query({
226+ "inputs": ${ getModelInputSnippet ( model ) } ,
227+ })
237228
238- return [
239- {
240- client : "huggingface_hub" ,
241- content : `${ snippetImportInferenceClient ( model , accessToken ) }
242- output = client.document_question_answering(${ inputsAsObj . image } , question=${ inputsAsObj . question } )` ,
243- } ,
244- {
245- client : "requests" ,
246- content : `def query(payload):
247- with open(payload["image"], "rb") as f:
229+ # You can access the image with PIL.Image for example
230+ import io
231+ from PIL import Image
232+ image = Image.open(io.BytesIO(image_bytes))` ,
233+ } ,
234+ ] ;
235+
236+ const snippetZeroShotClassification = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
237+ content : `def query(payload):
238+ response = requests.post(API_URL, headers=headers, json=payload)
239+ return response.json()
240+
241+ output = query({
242+ "inputs": ${ getModelInputSnippet ( model ) } ,
243+ "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
244+ })` ,
245+ } ) ;
246+
247+ const snippetZeroShotImageClassification = ( model : ModelDataMinimal ) : InferenceSnippet => ( {
248+ content : `def query(data):
249+ with open(data["image_path"], "rb") as f:
248250 img = f.read()
249- payload["image"] = base64.b64encode(img).decode("utf-8")
251+ payload={
252+ "parameters": data["parameters"],
253+ "inputs": base64.b64encode(img).decode("utf-8")
254+ }
250255 response = requests.post(API_URL, headers=headers, json=payload)
251256 return response.json()
252257
253258output = query({
254- "inputs": ${ inputsAsStr } ,
259+ "image_path": ${ getModelInputSnippet ( model ) } ,
260+ "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
255261})` ,
256- } ,
257- ] ;
258- } ;
262+ } ) ;
259263
260264const pythonSnippets : Partial <
261265 Record <
0 commit comments