@alyosha-swamy alyosha-swamy commented Jul 20, 2025

[New Model] Support Arcee (Arcee Foundational Models)

1. Purpose (Why this PR?)

Add inference support for Arcee Foundational Model (AFM) so that users can serve it with vLLM in both Python and API-server workflows. AFM uses a unique ReLU² activation in its MLP layers, differentiating it from standard Llama-based models.

2. Model details

| Field | Value / Reference |
| --- | --- |
| Source repo / HF id | huggingface.co/arcee-ai/AFM-4.5B-Base |
| Architecture | Llama-style decoder-only transformer with ReLU² MLP activation |
| Context length | 64k tokens |
| Hidden size / #layers | 4096 / 32 |
| License | CC BY-NC 4.0 |
| Special quirks | Uses ReLU² (squared ReLU) activation instead of SiLU in MLP layers |
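
As a quick sanity check on the quirks above, the activation setting can be read straight from the HF config. A minimal sketch (assumes the checkpoint exposes a `hidden_act` field set to `"relu2"`, matching the implementation quoted in the review comments below):

```python
# Sketch: inspect the AFM config to confirm the ReLU^2 activation setting.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("arcee-ai/AFM-4.5B-Base", trust_remote_code=True)
print(config.hidden_act)                             # expected: "relu2"
print(config.hidden_size, config.num_hidden_layers)  # expected: 4096, 32
```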

3. Implementation overview

  • Added ArceeForCausalLM class in vllm/model_executor/models/arcee.py with custom ArceeMLP using ReLU² activation
  • Registered model in _TEXT_GENERATION_MODELS in vllm/model_executor/models/registry.py (see the sketch after this list)
  • Updated docs/models/supported_models.md with Arcee entry in text generation table
  • Reused LlamaAttention from existing Llama implementation for attention layers
  • Implemented proper LoRA and Pipeline Parallelism support
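
For reference, the registry change is a single entry mapping the architecture name to its module and class. A rough sketch (surrounding entries and exact placement in the dict are illustrative):

```python
# vllm/model_executor/models/registry.py (sketch of the added entry only)
_TEXT_GENERATION_MODELS = {
    # ... existing entries ...
    "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"),  # module arcee.py, class ArceeForCausalLM
    # ... existing entries ...
}
```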

4. Performance / sanity check

```bash
$ python -m vllm.entrypoints.openai.api_server --model arcee-ai/AFM-4.5B-Base --trust-remote-code
$ curl http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
    "model": "arcee-ai/AFM-4.5B-Base",
    "prompt": "The future of artificial intelligence is",
    "max_tokens": 50
}'
```

Expected: Coherent completion about AI

Observed: "The future of artificial intelligence is bright and full of possibilities. As AI continues to evolve, we can expect to see significant advancements in areas such as natural language processing, computer vision, and machine learning..."
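
The Python workflow mentioned in the purpose section can be exercised the same way. A minimal sketch using vLLM's offline LLM API (sampling settings are illustrative):

```python
# Sketch: offline inference with AFM via vLLM's Python API.
from vllm import LLM, SamplingParams

llm = LLM(model="arcee-ai/AFM-4.5B-Base", trust_remote_code=True)
params = SamplingParams(max_tokens=50, temperature=0.7)

outputs = llm.generate(["The future of artificial intelligence is"], params)
print(outputs[0].outputs[0].text)
```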

5. Test plan ✔️

| Test | Command | Expected |
| --- | --- | --- |
| Unit (see sketch below) | `pytest tests/models/test_arcee.py` | All tests pass |
| Model loading | `python -c "from vllm import LLM; llm = LLM('arcee-ai/AFM-4.5B-Base')"` | Model loads without errors |
| Integration | `vllm serve arcee-ai/AFM-4.5B-Base --trust-remote-code` | Server starts, responds to requests |
| Generation | `curl localhost:8000/v1/completions` | 200 OK + valid completions |
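
The unit test in the plan above is not reproduced here; the following is only a sketch of what a generation smoke test could look like (the actual tests/models/test_arcee.py may instead use vLLM's shared model-test utilities):

```python
# Sketch: illustrative smoke test; names and structure are not taken from the PR.
import pytest
from vllm import LLM, SamplingParams


@pytest.mark.parametrize("prompt", ["Hello, world!", "The meaning of life is"])
def test_afm_generates_text(prompt):
    llm = LLM(model="arcee-ai/AFM-4.5B-Base", trust_remote_code=True)
    out = llm.generate([prompt], SamplingParams(max_tokens=16))
    assert out[0].outputs[0].text.strip()  # completion should be non-empty
```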

6. Documentation

  • Added row to docs/models/supported_models.md under Text Generation models
  • Model listed as ArceeForCausalLM with example model arcee-ai/AFM-4.5B-Base
  • Marked as supporting LoRA (✅), Pipeline Parallel (✅), and V1 (✅)

Checklist

  • I ran pre-commit run --all-files (ruff formatting)
  • All CI tests pass locally (pytest -q)
  • The PR description follows vLLM's "Essential Elements" template
  • No breaking changes for existing model classes

Notes for reviewers

The key architectural difference from standard Llama models is the MLP activation function. Arcee uses ReLU² (squared ReLU) instead of SiLU (a sketch follows the list below):

  • ArceeMLP implements: x = torch.pow(torch.relu(x), 2)
  • No gating mechanism (no gate_proj), only up_proj and down_proj
  • All other components (attention, layer norm, etc.) reuse existing Llama implementations
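
A plain-PyTorch sketch of that structure (nn.Linear stands in for vLLM's tensor-parallel layers; dimensions and bias settings are illustrative, not taken from the PR):

```python
import torch
import torch.nn as nn


class ArceeMLPSketch(nn.Module):
    """Illustrative only: AFM-style MLP with ReLU^2 and no gate_proj."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ReLU^2: square of ReLU, applied elementwise; no gating branch.
        return self.down_proj(torch.pow(torch.relu(self.up_proj(x)), 2))
```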

The model has been tested with an internal HF repo during development, but the official model is arcee-ai/AFM-4.5B-Base.

Test result

| seq | Prompt | vLLM Output |
| --- | --- | --- |
| 0 | "Hello, world!" | "Hello, world! Welcome to the exciting realm of programming..." |
| 1 | "The meaning of life is" | "The meaning of life is a profound question that has puzzled philosophers..." |
| 2 | "In 2025, technology will" | "In 2025, technology will continue to reshape our daily lives with advances in AI..." |

All outputs are coherent and contextually appropriate.

@alyosha-swamy alyosha-swamy requested a review from hmellor as a code owner July 20, 2025 20:47
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation (Improvements or additions to documentation) and new-model (Requests to new models) labels Jul 20, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The code introduces a new model, ArceeForCausalLM, with a unique ReLU² activation. The implementation is well-structured, but some import statements are misplaced and a minor performance improvement can be made.

```python
if hidden_act != "relu2":
    raise ValueError(f"Unsupported activation: {hidden_act}. Only 'relu2' is supported for AFM.")
# Define ReLU^2 activation: (ReLU(x))^2 elementwise
self.act_fn = lambda x: torch.pow(torch.relu(x), 2)
```

Severity: high

Consider using torch.relu(x).square() instead of torch.pow(torch.relu(x), 2) for better readability and potentially better performance. Tensor.square() is equivalent to x ** 2 and may dispatch to a more optimized kernel.

```python
self.act_fn = lambda x: torch.relu(x).square()
```
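
If it helps, the equivalence of the two formulations is easy to sanity-check:

```python
# Quick numerical check that relu(x)**2 and relu(x).square() agree.
import torch

x = torch.randn(4, 8)
assert torch.allclose(torch.pow(torch.relu(x), 2), torch.relu(x).square())
```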

```python
attention_bias = config.qkv_bias

# Self-Attention (using LLaMA's attention structure)
from vllm.model_executor.models.llama import LlamaAttention  # import here to avoid circular import
```

Severity: high

Move the import statement from vllm.model_executor.models.llama import LlamaAttention to the top of the file to adhere to PEP 8 guidelines. This improves code readability and makes dependencies explicit at the beginning of the file.

```python
self.aux_hidden_state_layers: Tuple[int, ...] = tuple()

# Prepare factory for empty intermediate tensors (for pipeline scheduling)
from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory
```

Severity: high

Move the import statement from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory to the top of the file to adhere to PEP 8 guidelines. This improves code readability and makes dependencies explicit at the beginning of the file.

```python
# Handle quantization KV cache scales if present
if hasattr(self, "quant_config") and self.quant_config is not None:
    # If name corresponds to a quantization scale parameter, remap and load it
    from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name
```

Severity: high

Move the import statements from vllm.model_executor.model_loader.weight_utils import default_weight_loader, maybe_remap_kv_scale_name and from vllm.model_executor.models.utils import is_pp_missing_parameter to the top of the file to adhere to PEP 8 guidelines. This improves code readability and makes dependencies explicit at the beginning of the file.

```python
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

# Import DEFAULT_VOCAB_PADDING_SIZE
from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE
```

Severity: high

Move the import statement from vllm.model_executor.layers.vocab_parallel_embedding import DEFAULT_VOCAB_PADDING_SIZE to the top of the file to adhere to PEP 8 guidelines. This improves code readability and makes dependencies explicit at the beginning of the file.

```python
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
    """Load weights into the model (delegates to inner model and handles tied embeddings)."""
    # Use AutoWeightsLoader for consistency with vLLM's loading mechanism
    from vllm.model_executor.models.utils import AutoWeightsLoader
```

Severity: high

Move the import statement from vllm.model_executor.models.utils import AutoWeightsLoader to the top of the file to adhere to PEP 8 guidelines. This improves code readability and makes dependencies explicit at the beginning of the file.

@hmellor hmellor mentioned this pull request Jul 21, 2025