diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 9782c3bfea..39e973b511 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1373,27 +1373,58 @@ export const transformers = (model: ModelData): string[] => { } const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : ""; - let autoSnippet: string; + const autoSnippet = []; if (info.processor) { - const varName = + const processorVarName = info.processor === "AutoTokenizer" ? "tokenizer" : info.processor === "AutoFeatureExtractor" ? "extractor" : "processor"; - autoSnippet = [ + autoSnippet.push( "# Load model directly", `from transformers import ${info.processor}, ${info.auto_model}`, "", - `${varName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - ].join("\n"); + `${processorVarName} = ${info.processor}.from_pretrained("${model.id}"` + remote_code_snippet + ")", + `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")" + ); + if (model.tags.includes("conversational")) { + if (model.tags.includes("image-text-to-text")) { + autoSnippet.push( + "messages = [", + [ + " {", + ' "role": "user",', + ' "content": [', + ' {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},', + ' {"type": "text", "text": "What animal is on the candy?"}', + " ]", + " },", + ].join("\n"), + "]" + ); + } else { + autoSnippet.push("messages = [", ' {"role": "user", "content": "Who are you?"},', "]"); + } + autoSnippet.push( + "inputs = ${processorVarName}.apply_chat_template(", + " messages,", + " add_generation_prompt=True,", + " tokenize=True,", + " return_dict=True,", + ' return_tensors="pt",', + ").to(model.device)", + "", + "outputs = model.generate(**inputs, max_new_tokens=40)", + 'print(${processorVarName}.decode(outputs[0][inputs["input_ids"].shape[-1]:]))' + ); + } } else { - autoSnippet = [ + autoSnippet.push( "# Load model directly", `from transformers import ${info.auto_model}`, - `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ")", - ].join("\n"); + `model = ${info.auto_model}.from_pretrained("${model.id}"` + remote_code_snippet + ', torch_dtype="auto"),' + ); } if (model.pipeline_tag && LIBRARY_TASK_MAPPING.transformers?.includes(model.pipeline_tag)) { @@ -1437,9 +1468,9 @@ export const transformers = (model: ModelData): string[] => { ); } - return [pipelineSnippet.join("\n"), autoSnippet]; + return [pipelineSnippet.join("\n"), autoSnippet.join("\n")]; } - return [autoSnippet]; + return [autoSnippet.join("\n")]; }; export const transformersJS = (model: ModelData): string[] => {