Skip to content

Commit a626ee5

Browse files
Wauplinmishig25
andauthored
Document python text to image snippets (#1016)
Supersedes #994. This PR adds an `huggingface_hub` snippet for `text-to-image` inference in Python. I added a test as done in #1003. Once this one is approved and merged, I'll move on with all other tasks that the `InferenceClient` supports. --------- Co-authored-by: Mishig <[email protected]>
1 parent f5de41b commit a626ee5

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

packages/tasks/src/snippets/python.spec.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,41 @@ stream = client.chat.completions.create(
104104
for chunk in stream:
105105
print(chunk.choices[0].delta.content, end="")`);
106106
});
107+
108+
it("text-to-image", async () => {
109+
const model: ModelDataMinimal = {
110+
id: "black-forest-labs/FLUX.1-schnell",
111+
pipeline_tag: "text-to-image",
112+
tags: [],
113+
inference: "",
114+
};
115+
const snippets = getPythonInferenceSnippet(model, "api_token") as InferenceSnippet[];
116+
117+
expect(snippets.length).toEqual(2);
118+
119+
expect(snippets[0].client).toEqual("huggingface_hub");
120+
expect(snippets[0].content).toEqual(`from huggingface_hub import InferenceClient
121+
client = InferenceClient("black-forest-labs/FLUX.1-schnell", token="api_token")
122+
123+
# output is a PIL.Image object
124+
image = client.text_to_image("Astronaut riding a horse")`);
125+
126+
expect(snippets[1].client).toEqual("requests");
127+
expect(snippets[1].content).toEqual(`import requests
128+
129+
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
130+
headers = {"Authorization": "Bearer api_token"}
131+
132+
def query(payload):
133+
response = requests.post(API_URL, headers=headers, json=payload)
134+
return response.content
135+
image_bytes = query({
136+
"inputs": "Astronaut riding a horse",
137+
})
138+
139+
# You can access the image with PIL.Image for example
140+
import io
141+
from PIL import Image
142+
image = Image.open(io.BytesIO(image_bytes))`);
143+
});
107144
});

packages/tasks/src/snippets/python.ts

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
44
import { getModelInputSnippet } from "./inputs.js";
55
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
66

7+
const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string =>
8+
`from huggingface_hub import InferenceClient
9+
client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")
10+
`;
11+
712
export const snippetConversational = (
813
model: ModelDataMinimal,
914
accessToken: string,
@@ -161,18 +166,28 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({
161166
output = query(${getModelInputSnippet(model)})`,
162167
});
163168

164-
export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({
165-
content: `def query(payload):
169+
export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => [
170+
{
171+
client: "huggingface_hub",
172+
content: `${snippetImportInferenceClient(model, accessToken)}
173+
# output is a PIL.Image object
174+
image = client.text_to_image(${getModelInputSnippet(model)})`,
175+
},
176+
{
177+
client: "requests",
178+
content: `def query(payload):
166179
response = requests.post(API_URL, headers=headers, json=payload)
167180
return response.content
168181
image_bytes = query({
169182
"inputs": ${getModelInputSnippet(model)},
170183
})
184+
171185
# You can access the image with PIL.Image for example
172186
import io
173187
from PIL import Image
174188
image = Image.open(io.BytesIO(image_bytes))`,
175-
});
189+
},
190+
];
176191

177192
export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({
178193
content: `def query(payload):
@@ -288,12 +303,14 @@ export function getPythonInferenceSnippet(
288303
return snippets.map((snippet) => {
289304
return {
290305
...snippet,
291-
content: `import requests
306+
content: snippet.content.includes("requests")
307+
? `import requests
292308
293309
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
294310
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}
295311
296-
${snippet.content}`,
312+
${snippet.content}`
313+
: snippet.content,
297314
};
298315
});
299316
}

0 commit comments

Comments
 (0)