Merged
64 changes: 62 additions & 2 deletions packages/tasks/src/local-apps.ts
@@ -42,8 +42,29 @@ export type LocalApp = {
}
);

-function isGgufModel(model: ModelData) {
-	return model.tags.includes("gguf");
+function isGgufModel(model: ModelData): boolean {
+	return model.config?.quantization_config?.quant_method === "gguf";
 }

function isAwqModel(model: ModelData): boolean {
return model.config?.quantization_config?.quant_method === "awq";
}

function isGptqModel(model: ModelData): boolean {
return model.config?.quantization_config?.quant_method === "gptq";
}

function isAqlmModel(model: ModelData): boolean {
return model.config?.quantization_config?.quant_method === "aqlm";
}

function isMarlinModel(model: ModelData): boolean {
return model.config?.quantization_config?.quant_method === "marlin";
}

function isFullModel(model: ModelData): boolean {
// Assuming a full model is identified by not having a quant_method
return !model.config?.quantization_config?.quant_method;
Collaborator

isFullModel creates a lot of false positives.

Instead, we could check against the supported architectures, as suggested by @simon-mo.

Something like this:

const VLLM_SUPPORTED_ARCHS = [
    "AquilaForCausalLM", "ArcticForCausalLM", "BaiChuanForCausalLM", "BloomForCausalLM", ...
];
model.config?.architectures?.some((arch) => VLLM_SUPPORTED_ARCHS.includes(arch));

cc @julien-c
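
For illustration, a minimal sketch of how that suggestion could be folded into local-apps.ts, following the file's existing isXxxModel helpers. The constant would need to be populated with vLLM's actual supported architectures (see the next comment); the entries shown and the helper name isVllmSupportedArchModel are placeholders, not part of this PR.

// Sketch only: a few illustrative entries; the real list should come from vLLM.
const VLLM_SUPPORTED_ARCHS: string[] = [
	"AquilaForCausalLM",
	"ArcticForCausalLM",
	"BloomForCausalLM",
	"LlamaForCausalLM",
	"MistralForCausalLM",
	// … remaining architectures reported by vLLM
];

// Hypothetical helper: true when any architecture declared in the model config is supported.
function isVllmSupportedArchModel(model: ModelData): boolean {
	return model.config?.architectures?.some((arch) => VLLM_SUPPORTED_ARCHS.includes(arch)) ?? false;
}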

Contributor

You can query the vLLM package for this list once you have it installed:

>>> from vllm import ModelRegistry
>>> ModelRegistry.get_supported_archs()
['AquilaModel', 'AquilaForCausalLM', 'BaiChuanForCausalLM', 'BaichuanForCausalLM', 'BloomForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'CohereForCausalLM', 'DbrxForCausalLM', 'DeciLMForCausalLM', 'DeepseekForCausalLM', 'FalconForCausalLM', 'GemmaForCausalLM', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTJForCausalLM', 'GPTNeoXForCausalLM', 'InternLMForCausalLM', 'InternLM2ForCausalLM', 'JAISLMHeadModel', 'LlamaForCausalLM', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'LLaMAForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'QuantMixtralForCausalLM', 'MptForCausalLM', 'MPTForCausalLM', 'MiniCPMForCausalLM', 'OlmoForCausalLM', 'OPTForCausalLM', 'OrionForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RWForCausalLM', 'StableLMEpochForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'ArcticForCausalLM', 'XverseForCausalLM', 'Phi3SmallForCausalLM', 'MistralModel']
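
Putting the two suggestions together, the vllm entry's displayOnModelPage could then take roughly the shape below. This is a sketch, not what the PR currently does; it assumes the hypothetical isVllmSupportedArchModel helper sketched above, with VLLM_SUPPORTED_ARCHS filled in from the ModelRegistry output.

// Sketch: show vLLM for supported quantized checkpoints, or for unquantized
// checkpoints whose declared architecture vLLM reports as supported.
displayOnModelPage: (model: ModelData) =>
	isAwqModel(model) ||
	isGptqModel(model) ||
	isAqlmModel(model) ||
	isMarlinModel(model) ||
	isVllmSupportedArchModel(model),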

}

const snippetLlamacpp = (model: ModelData): string[] => {
@@ -63,6 +84,38 @@ LLAMA_CURL=1 make
];
};

const snippetVllm = (model: ModelData): string[] => {
return [
`
## Deploy with Linux and Docker (Docker must be installed); for a gated model, first request access in the Hugging Face model repo:
docker run --runtime nvidia --gpus all \\
	--name my_vllm_container \\
	-v ~/.cache/huggingface:/root/.cache/huggingface \\
	--env "HUGGING_FACE_HUB_TOKEN=<secret>" \\
	-p 8000:8000 \\
	--ipc=host \\
	vllm/vllm-openai:latest \\
	--model ${model.id}
`,
`
## Load and run the model
docker exec -it my_vllm_container bash -c "python -m vllm.entrypoints.openai.api_server --model ${model.id} --dtype auto --api-key token-abc123"
`,
`
## Call the server using curl
curl -X POST "http://localhost:8000/v1/chat/completions" \\
	-H "Content-Type: application/json" \\
	-H "Authorization: Bearer token-abc123" \\
	--data '{
		"model": "${model.id}",
		"messages": [
			{"role": "user", "content": "Hello!"}
		]
	}'
`,
];
};

/**
* Add your new local app here.
*
@@ -82,6 +135,13 @@ export const LOCAL_APPS = {
displayOnModelPage: isGgufModel,
snippet: snippetLlamacpp,
},
vllm: {
prettyLabel: "vLLM",
docsUrl: "https://docs.vllm.ai",
mainTask: "text-generation",
displayOnModelPage: (model: ModelData) => isAwqModel(model) || isGptqModel(model) || isAqlmModel(model) || isMarlinModel(model) || isFullModel(model),
snippet: snippetVllm,
},
lmstudio: {
prettyLabel: "LM Studio",
docsUrl: "https://lmstudio.ai",
4 changes: 4 additions & 0 deletions packages/tasks/src/model-data.ts
@@ -52,6 +52,10 @@ export interface ModelData {
bits?: number;
load_in_4bit?: boolean;
load_in_8bit?: boolean;
/**
* awq, gptq, aqlm, marlin, … Used by vLLM
*/
quant_method?: string;
};
tokenizer_config?: TokenizerConfig;
adapter_transformers?: {