Add Arcee (AFM) model support to vLLM #21264
Conversation
Signed-off-by: alyosha-swamy <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request adds support for the Arcee (AFM) model to vLLM. The implementation correctly identifies the key architectural difference—the ReLU² activation in the MLP—and reuses existing components like LlamaAttention effectively. The changes are well-structured and include necessary updates to documentation and model registration.
I've identified one area for improvement in the ArceeModel.load_weights method concerning inefficient imports and incomplete support for quantization scale loading. Addressing this will improve model loading performance and ensure correctness for features like AWQ quantization. Overall, this is a solid contribution.
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
    """Load pre-trained weights from HuggingFace format into the model."""
    # Mapping for merging or renaming weight parameters from HF into our model
    stacked_params_mapping = [
        # Each tuple: (combined_param_name, hf_subparam_name, index_or_key)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        # Note: No gate_proj since AFM has no gated MLP
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Skip rotary cache parameters if present (not actual model weights)
        if "rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
            continue
        # Handle quantization KV cache scales if present
        if hasattr(self, "quant_config") and self.quant_config is not None:
            # If name corresponds to a quantization scale parameter, remap and load it
            from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name
            if "scale" in name:
                maybe_name = maybe_remap_kv_scale_name(name, params_dict)
                if maybe_name is None:
                    continue
                name = maybe_name
        # Pipeline parallel: skip parameters not on this rank
        from vllm.model_executor.models.utils import is_pp_missing_parameter
        from vllm.model_executor.model_loader.weight_utils import default_weight_loader
        if is_pp_missing_parameter(name, self):
            continue

        # Attempt to map and load merged parameters
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            mapped_name = name.replace(weight_name, param_name)
            if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                # Skip any unexpected biases (e.g., from certain quantization or GPTQ checkpoints)
                break
            if mapped_name in params_dict:
                param = params_dict[mapped_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)  # load the shard into the combined param
                loaded_params.add(mapped_name)
            else:
                logging.warning(f"Unexpected parameter in checkpoint: {name}")
            break
        else:
            # No special mapping, try direct load
            if name in params_dict:
                # For tied embeddings, skip loading lm_head if it will be tied
                if name.startswith("lm_head.") and getattr(self.config, "tie_word_embeddings", False):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            else:
                # Silently skip any unmatched parameters (e.g., vision tower weights in multimodal models)
                logging.debug(f"Ignoring unmatched checkpoint parameter: {name}")
    return loaded_params
The load_weights method in ArceeModel has a couple of issues:
- Inefficient Imports: Imports are performed inside the main loop, so they are re-executed for every weight parameter. They should be moved to the top of the method.
- Incomplete Quantization Support: The logic for handling quantization scales is incomplete. It's missing the handling for AWQ KV cache scales, which is present in other models like Llama. This will cause issues when using AWQ-quantized versions of this model.
I've provided a refactored version of the method that addresses these points by moving imports out of the loop and adding the correct logic for handling quantization scales, aligning it with the implementation in LlamaModel.
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
    """Load pre-trained weights from HuggingFace format into the model."""
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader, maybe_remap_kv_scale_name)
    from vllm.model_executor.models.utils import is_pp_missing_parameter

    # Mapping for merging or renaming weight parameters from HF into our model
    stacked_params_mapping = [
        # Each tuple: (combined_param_name, hf_subparam_name, index_or_key)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        # Note: No gate_proj since AFM has no gated MLP
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Skip rotary cache parameters if present (not actual model weights)
        if "rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
            continue
        # Handle quantization KV cache scales
        if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)):
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                             loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        if "scale" in name:
            # Remapping the name of FP8 kv-scale.
            maybe_name = maybe_remap_kv_scale_name(name, params_dict)
            if maybe_name is None:
                continue
            name = maybe_name
        # Pipeline parallel: skip parameters not on this rank
        if is_pp_missing_parameter(name, self):
            continue
        # Attempt to map and load merged parameters
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            mapped_name = name.replace(weight_name, param_name)
            if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                # Skip any unexpected biases (e.g., from certain quantization or GPTQ checkpoints)
                break
            if mapped_name in params_dict:
                param = params_dict[mapped_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)  # load the shard into the combined param
                loaded_params.add(mapped_name)
            else:
                logging.warning(f"Unexpected parameter in checkpoint: {name}")
            break
        else:
            # No special mapping, try direct load
            if name in params_dict:
                # For tied embeddings, skip loading lm_head if it will be tied
                if name.startswith("lm_head.") and getattr(self.config, "tie_word_embeddings", False):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            else:
                # Silently skip any unmatched parameters (e.g., vision tower weights in multimodal models)
                logging.debug(f"Ignoring unmatched checkpoint parameter: {name}")
    return loaded_params
[New Model] Support Arcee (Arcee Foundational Models)
1. Purpose (Why this PR?)
Add inference support for Arcee Foundational Model (AFM) so that users can serve it with vLLM in both Python and API-server workflows. AFM uses a unique ReLU² activation in its MLP layers, differentiating it from standard Llama-based models.
2. Model details
3. Implementation overview
- ArceeForCausalLM class in vllm/model_executor/models/arcee.py, with a custom ArceeMLP using ReLU² activation
- Registered in _TEXT_GENERATION_MODELS in vllm/model_executor/models/registry.py (see the sketch after this list)
- Updated docs/models/supported_models.md with an Arcee entry in the text generation table
- Reuses LlamaAttention from the existing Llama implementation for the attention layers
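For reference, the registration is a single dictionary entry. The snippet below is an approximate sketch of what that registry hook looks like, based on the general shape of vLLM's model registry rather than the actual diff in this PR:

# Approximate sketch (not copied from the PR diff): _TEXT_GENERATION_MODELS
# maps the HF architecture name to (module name, class name) under
# vllm/model_executor/models/.
_TEXT_GENERATION_MODELS = {
    # ... existing entries ...
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),
}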
4. Performance / sanity check

Expected: Coherent completion about life's meaning
Observed: " a question that has been asked throughout the history of mankind. The search for an answer to this question has inspired countless works of art, literature, and philosophy. Whether we consider the existentialist ideas of Albert Camus or the religious perspectives of spiritual leaders"
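For anyone repeating the sanity check locally, a minimal offline-inference script along these lines should work with the vLLM Python API; the prompt text and sampling parameters are illustrative assumptions, not values taken from this PR:

from vllm import LLM, SamplingParams

# Released checkpoint named in this PR; custom modeling code on the Hub
# requires trust_remote_code=True, matching the serve command in the test plan.
llm = LLM(model="arcee-ai/AFM-4.5B-Base", trust_remote_code=True)
params = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["The meaning of life is"], params)
print(outputs[0].outputs[0].text)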
5. Test plan ✔️
- pytest tests/models/test_arcee.py
- python -c "from vllm import LLM; llm = LLM('arcee-ai/AFM-4.5B-Base')"
- vllm serve arcee-ai/AFM-4.5B-Base --trust-remote-code
- curl localhost:8000/v1/completions (see the example request after this list)
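To make the last item concrete, here is one way the completions endpoint could be exercised from Python once vllm serve is running; the prompt and max_tokens values are placeholders rather than values from this PR:

import requests

# Assumes `vllm serve arcee-ai/AFM-4.5B-Base --trust-remote-code` is running
# locally on the default port 8000 (matching the curl command above).
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "arcee-ai/AFM-4.5B-Base",
        "prompt": "The meaning of life is",
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["text"])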
6. Documentation

- docs/models/supported_models.md updated under Text Generation models
- ArceeForCausalLM listed with example model arcee-ai/AFM-4.5B-Base

Checklist
- pre-commit run --all-files (ruff formatting)
- Tests pass (pytest -q)

Notes for reviewers
The key architectural difference from standard Llama models is the MLP activation function. Arcee uses ReLU² (squared ReLU) instead of SiLU:
- ArceeMLP implements: x = torch.pow(torch.relu(x), 2) (a minimal sketch follows this list)
- No gated MLP (no gate_proj), only up_proj and down_proj
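To make the layer shape concrete, below is a minimal sketch of an MLP with squared-ReLU activation. It uses plain torch.nn.Linear purely to illustrate the math; the actual ArceeMLP in this PR is built on vLLM's tensor-parallel linear layers, so the names and constructor details here are illustrative assumptions:

import torch
from torch import nn

class ArceeMLPSketch(nn.Module):
    """Simplified sketch of an MLP with squared-ReLU activation (no gate_proj)."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.up_proj(x)
        x = torch.pow(torch.relu(x), 2)  # ReLU² instead of Llama's SiLU-gated product
        return self.down_proj(x)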
The model has been tested with an internal HF repo during development, but the official model is arcee-ai/AFM-4.5B-Base.

Test result
All outputs are coherent and contextually appropriate.