150 changes: 113 additions & 37 deletions packages/inference/src/snippets/python.ts
@@ -44,11 +44,11 @@ const snippetImportInferenceClient = (accessToken: string, provider: SnippetInfe
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="${provider}",
api_key="${accessToken || "{API_TOKEN}"}"
provider="${provider}",
api_key="${accessToken || "{API_TOKEN}"}",
)`;

export const snippetConversational = (
const snippetConversational = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider,
@@ -89,7 +89,7 @@ stream = client.chat.completions.create(
model="${model.id}",
messages=messages,
${configStr}
stream=True
stream=True,
)

for chunk in stream:
@@ -159,7 +159,7 @@ print(completion.choices[0].message)`,
}
};

export const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
const snippetZeroShotClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
@@ -176,12 +176,11 @@ output = query({
];
};

export const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
const snippetZeroShotImageClassification = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
content: `\
def query(data):
content: `def query(data):
with open(data["image_path"], "rb") as f:
img = f.read()
payload={
@@ -199,7 +198,7 @@ output = query({
];
};

export const snippetBasic = (
const snippetBasic = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider
@@ -213,9 +212,8 @@ export const snippetBasic = (
${snippetImportInferenceClient(accessToken, provider)}

result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
model="${model.id}",
inputs=${getModelInputSnippet(model)},
provider="${provider}",
model="${model.id}",
)

print(result)
@@ -237,7 +235,7 @@ output = query({
];
};

export const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
const snippetFile = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
@@ -253,7 +251,7 @@ output = query(${getModelInputSnippet(model)})`,
];
};

export const snippetTextToImage = (
const snippetTextToImage = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider,
@@ -268,7 +266,7 @@ ${snippetImportInferenceClient(accessToken, provider)}
# output is a PIL.Image object
image = client.text_to_image(
${getModelInputSnippet(model)},
model="${model.id}"
model="${model.id}",
)`,
},
...(provider === "fal-ai"
@@ -312,7 +310,7 @@ image = Image.open(io.BytesIO(image_bytes))`,
];
};

export const snippetTextToVideo = (
const snippetTextToVideo = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider
@@ -326,14 +324,14 @@ ${snippetImportInferenceClient(accessToken, provider)}

video = client.text_to_video(
${getModelInputSnippet(model)},
model="${model.id}"
model="${model.id}",
)`,
},
]
: [];
};

export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
const snippetTabular = (model: ModelDataMinimal): InferenceSnippet[] => {
return [
{
client: "requests",
@@ -349,7 +347,7 @@ response = query({
];
};

export const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
const snippetTextToAudio = (model: ModelDataMinimal): InferenceSnippet[] => {
// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs have diverged
// with the latest update to inference-api (IA).
// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
@@ -374,8 +372,7 @@ Audio(audio_bytes)`,
return [
{
client: "requests",
content: `\
def query(payload):
content: `def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

@@ -390,26 +387,97 @@ Audio(audio, rate=sampling_rate)`,
}
};

export const snippetDocumentQuestionAnswering = (model: ModelDataMinimal): InferenceSnippet[] => {
const snippetAutomaticSpeechRecognition = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider
): InferenceSnippet[] => {
return [
{
client: "huggingface_hub",
content: `${snippetImportInferenceClient(accessToken, provider)}
output = client.automatic_speech_recognition(${getModelInputSnippet(model)}, model="${model.id}")`,
},
snippetFile(model)[0],
];
};

const snippetDocumentQuestionAnswering = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider
): InferenceSnippet[] => {
const inputsAsStr = getModelInputSnippet(model) as string;
const inputsAsObj = JSON.parse(inputsAsStr);

return [
{
client: "huggingface_hub",
content: `${snippetImportInferenceClient(accessToken, provider)}
output = client.document_question_answering(
"${inputsAsObj.image}",
question="${inputsAsObj.question}",
model="${model.id}",
)`,
},
{
client: "requests",
content: `\
def query(payload):
content: `def query(payload):
with open(payload["image"], "rb") as f:
img = f.read()
payload["image"] = base64.b64encode(img).decode("utf-8")
payload["image"] = base64.b64encode(img).decode("utf-8")
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"inputs": ${getModelInputSnippet(model)},
"inputs": ${inputsAsStr},
})`,
},
];
};

export const pythonSnippets: Partial<
const snippetImageToImage = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider
): InferenceSnippet[] => {
const inputsAsStr = getModelInputSnippet(model) as string;
const inputsAsObj = JSON.parse(inputsAsStr);

return [
{
client: "huggingface_hub",
content: `${snippetImportInferenceClient(accessToken, provider)}
# output is a PIL.Image object
image = client.image_to_image(
"${inputsAsObj.image}",
prompt="${inputsAsObj.prompt}",
model="${model.id}",
)`,
},
{
client: "requests",
content: `def query(payload):
with open(payload["inputs"], "rb") as f:
img = f.read()
payload["inputs"] = base64.b64encode(img).decode("utf-8")
response = requests.post(API_URL, headers=headers, json=payload)
return response.content

image_bytes = query({
"inputs": "${inputsAsObj.image}",
"parameters": {"prompt": "${inputsAsObj.prompt}"},
})

# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`,
},
];
};

const pythonSnippets: Partial<
Record<
PipelineType,
(
@@ -435,7 +503,7 @@ export const pythonSnippets: Partial<
"image-text-to-text": snippetConversational,
"fill-mask": snippetBasic,
"sentence-similarity": snippetBasic,
"automatic-speech-recognition": snippetFile,
"automatic-speech-recognition": snippetAutomaticSpeechRecognition,
"text-to-image": snippetTextToImage,
"text-to-video": snippetTextToVideo,
"text-to-speech": snippetTextToAudio,
Expand All @@ -449,6 +517,7 @@ export const pythonSnippets: Partial<
"image-segmentation": snippetFile,
"document-question-answering": snippetDocumentQuestionAnswering,
"image-to-text": snippetFile,
"image-to-image": snippetImageToImage,
"zero-shot-image-classification": snippetZeroShotImageClassification,
};

@@ -471,17 +540,24 @@ export function getPythonInferenceSnippet(
return snippets.map((snippet) => {
return {
...snippet,
content:
snippet.client === "requests"
? `\
import requests

API_URL = "${openAIbaseUrl(provider)}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

${snippet.content}`
: snippet.content,
content: addImportsToSnippet(snippet.content, model, accessToken),
};
});
}
}

const addImportsToSnippet = (snippet: string, model: ModelDataMinimal, accessToken: string): string => {
if (snippet.includes("requests")) {
snippet = `import requests

API_URL = "https://router.huggingface.co/hf-inference/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

${snippet}`;
}
if (snippet.includes("base64")) {
snippet = `import base64
${snippet}`;
}
return snippet;
};
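With the new addImportsToSnippet helper, every generated "requests" snippet ships self-contained: the requests import plus the API_URL/headers block is prepended first, and an import base64 line is stacked on top whenever the snippet body base64-encodes a file. As a rough sketch of the composed image-to-image output (model id taken from the fixture case added below; the cat.png path and the prompt are illustrative placeholders, not values confirmed by this PR):

import base64
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/stabilityai/stable-diffusion-xl-refiner-1.0"
headers = {"Authorization": "Bearer api_token"}

def query(payload):
    # Read the input image and inline it as base64 before POSTing
    with open(payload["inputs"], "rb") as f:
        img = f.read()
    payload["inputs"] = base64.b64encode(img).decode("utf-8")
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

image_bytes = query({
    "inputs": "cat.png",  # hypothetical input file
    "parameters": {"prompt": "Turn the cat into a tiger."},  # hypothetical prompt
})

# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))

Note the import order: the base64 check runs after the requests check, so its import lands above the requests block.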
33 changes: 33 additions & 0 deletions packages/tasks-gen/scripts/generate-snippets-fixtures.ts
@@ -31,6 +31,17 @@ const TEST_CASES: {
providers: SnippetInferenceProvider[];
opts?: Record<string, unknown>;
}[] = [
{
testName: "automatic-speech-recognition",
model: {
id: "openai/whisper-large-v3-turbo",
pipeline_tag: "automatic-speech-recognition",
tags: [],
inference: "",
},
languages: ["py"],
providers: ["hf-inference"],
},
{
testName: "conversational-llm-non-stream",
model: {
@@ -79,6 +90,28 @@ const TEST_CASES: {
providers: ["hf-inference", "fireworks-ai"],
opts: { streaming: true },
},
{
testName: "document-question-answering",
model: {
id: "impira/layoutlm-invoices",
pipeline_tag: "document-question-answering",
tags: [],
inference: "",
},
languages: ["py"],
providers: ["hf-inference"],
},
{
testName: "image-to-image",
model: {
id: "stabilityai/stable-diffusion-xl-refiner-1.0",
pipeline_tag: "image-to-image",
tags: [],
inference: "",
},
languages: ["py"],
providers: ["hf-inference"],
},
{
testName: "text-to-image",
model: {
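The document-question-answering case above exercises the new snippetDocumentQuestionAnswering path, which JSON-parses the model input and splits it into an image argument and a question keyword argument. Assuming the parsed input carries image and question fields, the generated huggingface_hub fixture should come out along these lines (the invoice.png path and the question text are hypothetical stand-ins, not the actual fixture values):

from huggingface_hub import InferenceClient

client = InferenceClient(
    provider="hf-inference",
    api_key="api_token",
)
output = client.document_question_answering(
    "invoice.png",  # hypothetical image path
    question="What is the invoice number?",  # hypothetical question
    model="impira/layoutlm-invoices",
)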
@@ -0,0 +1,7 @@
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token",
)
output = client.automatic_speech_recognition("sample1.flac", model="openai/whisper-large-v3-turbo")
@@ -0,0 +1,12 @@
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"
headers = {"Authorization": "Bearer api_token"}

def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()

output = query("sample1.flac")
@@ -1,8 +1,8 @@
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token"
provider="hf-inference",
api_key="api_token",
)

messages = [
@@ -1,8 +1,8 @@
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="together",
api_key="api_token"
provider="together",
api_key="api_token",
)

messages = [