Conversation

@2015aroras (Contributor) commented Sep 15, 2025:

Purpose

This PR adds the implementation for the FlexOlmo models. The HF implementation is being added concurrently, so the PR includes the config too.

Test Plan

The test plan is to check that basic generation (via examples/offline_inference/basic/generate.py) produces sensible output. I cannot run an HF vs vLLM comparison in a shareable manner because the HF implementation is being added concurrently. Nevertheless, I used a custom script to compare HF and vLLM activations and saw only minor numerical differences (which would eventually propagate and grow larger), while the generated output was identical.
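
A minimal sketch of this kind of generation check (not the example script itself), assuming the shanearora/Flex-reddit-2x7B-1T checkpoint referenced later in this PR:

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Assumed checkpoint id; swap in whichever FlexOlmo checkpoint is being tested.
llm = LLM(model="shanearora/Flex-reddit-2x7B-1T")

for output in llm.generate(prompts, sampling_params):
    print("-" * 50)
    print(f"Prompt: {output.prompt!r}")
    print(f"Generated text: {output.outputs[0].text!r}")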

Test Result

Result of running examples/offline_inference/basic/generate.py:

--------------------------------------------------
Prompt: 'Hello, my name is'
Generated text: ' Helen and your friend, I am Leonard, I have been interested in dancing throughout'
--------------------------------------------------
Prompt: 'The president of the United States is'
Generated text: ' commander in chief of the Army and Navy of the United States."\n\nIt is easy'
--------------------------------------------------
Prompt: 'The capital of France is'
Generated text: ' Paris, France**\n\nComplex equivalence relations give rise to new concepts. We have seen'
--------------------------------------------------
Prompt: 'The future of AI is'
Generated text: ' a long exploration through a chipped rabbit hole under the wing of a cocker'

Excerpt of the diff between HF and vLLM activations (using a custom script):

Input: San Francisco is 1 of 5
vLLM output:  cities in the United States that have a higher percentage of people with college degrees than
HF output:  cities in the United States that have a higher percentage of people with college degrees than
vLLM and HF output are the same!
Coord diff abs mean for key model.embed_tokens|input (HF/vllm 0/0) 0.0
Coord diff abs mean for key model.embed_tokens|output (HF/vllm 1/1) 0.0
No vllm state for model.rotary_emb|input (HF 2), continuing
No vllm state for model.rotary_emb|output (HF 3), continuing
Coord diff abs mean for key model.layers.0.self_attn.q_proj|input (HF/vllm 4/2) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_proj|output (HF/vllm 5/3) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_norm|input (HF/vllm 6/8) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_norm|output (HF/vllm 7/9) 0.0001232391077792272
Coord diff abs mean for key model.layers.0.self_attn.k_proj|input (HF/vllm 8/4) 0.0
Coord diff abs mean for key model.layers.0.self_attn.k_proj|output (HF/vllm 9/5) 0.0
Coord diff abs mean for key model.layers.0.self_attn.k_norm|input (HF/vllm 10/10) 0.0
Coord diff abs mean for key model.layers.0.self_attn.k_norm|output (HF/vllm 11/11) 0.000123702033306472
Coord diff abs mean for key model.layers.0.self_attn.v_proj|input (HF/vllm 12/6) 0.0
Coord diff abs mean for key model.layers.0.self_attn.v_proj|output (HF/vllm 13/7) 0.0
Coord diff abs mean for key model.layers.0.self_attn.o_proj|input (HF/vllm 14/18) 0.0007298562559299171
Coord diff abs mean for key model.layers.0.self_attn.o_proj|output (HF/vllm 15/19) 0.0036155611742287874
Coord diff abs mean for key model.layers.0.self_attn|input (HF/vllm 16/20) 0.0
Coord diff abs mean for key model.layers.0.self_attn|output (HF/vllm 17/21) 0.0036155611742287874
Coord diff abs mean for key model.layers.0.post_attention_layernorm|input (HF/vllm 18/22) 0.0036155611742287874
Coord diff abs mean for key model.layers.0.post_attention_layernorm|output (HF/vllm 19/23) 1.977770807570778e-05
Coord diff abs mean for key model.layers.0.mlp.gate|input (HF/vllm 20/24) 2.1024250600021333e-05
Coord diff abs mean for key model.layers.0.mlp.gate|output (HF/vllm 21/25) 0.00052642822265625
No vllm state for model.layers.0.mlp.experts.0.gate_proj|input (HF 22), continuing
No vllm state for model.layers.0.mlp.experts.0.gate_proj|output (HF 23), continuing
No vllm state for model.layers.0.mlp.experts.0.act_fn|input (HF 24), continuing
No vllm state for model.layers.0.mlp.experts.0.act_fn|output (HF 25), continuing
No vllm state for model.layers.0.mlp.experts.0.up_proj|input (HF 26), continuing
No vllm state for model.layers.0.mlp.experts.0.up_proj|output (HF 27), continuing
No vllm state for model.layers.0.mlp.experts.0.down_proj|input (HF 28), continuing
No vllm state for model.layers.0.mlp.experts.0.down_proj|output (HF 29), continuing
No vllm state for model.layers.0.mlp.experts.0|input (HF 30), continuing
No vllm state for model.layers.0.mlp.experts.0|output (HF 31), continuing
No vllm state for model.layers.0.mlp.experts.1.gate_proj|input (HF 32), continuing
No vllm state for model.layers.0.mlp.experts.1.gate_proj|output (HF 33), continuing
No vllm state for model.layers.0.mlp.experts.1.act_fn|input (HF 34), continuing
No vllm state for model.layers.0.mlp.experts.1.act_fn|output (HF 35), continuing
No vllm state for model.layers.0.mlp.experts.1.up_proj|input (HF 36), continuing
No vllm state for model.layers.0.mlp.experts.1.up_proj|output (HF 37), continuing
No vllm state for model.layers.0.mlp.experts.1.down_proj|input (HF 38), continuing
No vllm state for model.layers.0.mlp.experts.1.down_proj|output (HF 39), continuing
No vllm state for model.layers.0.mlp.experts.1|input (HF 40), continuing
No vllm state for model.layers.0.mlp.experts.1|output (HF 41), continuing
Coord diff abs mean for key model.layers.0.mlp|input (HF/vllm 42/27) 2.1024250600021333e-05
Coord diff abs mean for key model.layers.0.mlp|output (HF/vllm 43/28) 0.06523162871599197
Coord diff abs mean for key model.layers.0.post_feedforward_layernorm|input (HF/vllm 44/29) 0.06523162871599197
Coord diff abs mean for key model.layers.0.post_feedforward_layernorm|output (HF/vllm 45/30) 0.0002040487597696483
Coord diff abs mean for key model.layers.0|input (HF/vllm 46/31) 0.0
Coord diff abs mean for key model.layers.0|output (HF/vllm 47/32) 0.0002227535587735474
Coord diff abs mean for key model.layers.1.self_attn.q_proj|input (HF/vllm 48/33) 0.0002227535587735474
Coord diff abs mean for key model.layers.1.self_attn.q_proj|output (HF/vllm 49/34) 0.0005754886660724878
...
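
The custom script itself is not shareable; the following is a hypothetical sketch of the kind of comparison it performs: forward hooks record each module's inputs and outputs for both the HF and vLLM runs, and the mean absolute coordinate difference is reported per key, matching the log format above.

import torch
from torch import nn


def capture_activations(model: nn.Module) -> dict[str, torch.Tensor]:
    # Record every module's first input tensor and output tensor under keys
    # "name|input" / "name|output", mirroring the keys in the log above.
    states: dict[str, torch.Tensor] = {}

    def make_hook(name: str):
        def hook(module, inputs, output):
            if inputs and isinstance(inputs[0], torch.Tensor):
                states[f"{name}|input"] = inputs[0].detach().float().cpu()
            if isinstance(output, torch.Tensor):
                states[f"{name}|output"] = output.detach().float().cpu()
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
    return states


def report_coord_diff(hf_states: dict, vllm_states: dict) -> None:
    for key, hf_act in hf_states.items():
        if key not in vllm_states:
            print(f"No vllm state for {key}, continuing")
            continue
        # Assumes the vLLM activation can be reshaped to match the HF layout.
        vllm_act = vllm_states[key].reshape(hf_act.shape)
        diff = (hf_act - vllm_act).abs().mean().item()
        print(f"Coord diff abs mean for key {key} {diff}")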

@mergify bot added the documentation and new-model labels Sep 15, 2025
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[tuple[str,

Collaborator:
We can use AutoWeightsLoader, and move this function into FlexOlmoModel
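
A rough sketch of the suggested refactor, following the AutoWeightsLoader pattern used by other vLLM models (the exact FlexOlmo code may differ): the detailed weight-name handling moves into FlexOlmoModel.load_weights, and the top-level class just delegates.

from collections.abc import Iterable

import torch
from torch import nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class FlexOlmoForCausalLM(nn.Module):
    # ... init/forward elided; the per-parameter loading logic lives in
    # FlexOlmoModel.load_weights, which AutoWeightsLoader dispatches to.

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)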

Contributor Author:
Done in d1adc72. I had written the FlexOlmo implementation a while ago, so I missed improvements like this one.


# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(

Collaborator:
We can implement expert_params_mapping like https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py#L557, so this model can support BNB quantization directly.
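
Roughly what the suggested pattern looks like, modeled on qwen2_moe.py (treat this as a sketch rather than the final code; the config attribute name is an assumption): expose the expert mapping through a method on the model so quantization loaders such as BNB can query it.

from torch import nn

from vllm.model_executor.layers.fused_moe import FusedMoE


class FlexOlmoForCausalLM(nn.Module):
    # ... rest of the class elided ...

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # (param_name, weight_name, expert_id, shard_id) entries for weights,
        # fp8 weight scales, and fp8 activation scales.
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)  # assumed config attribute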

Contributor Author:
Done in 56b3cdf.

@2015aroras marked this pull request as ready for review September 16, 2025 03:49
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FlexOlmoForCausalLM": _HfExamplesInfo("shanearora/Flex-reddit-2x7B-1T"),

Collaborator (@jeejeelee, Sep 16, 2025):
Suggested change:
- "FlexOlmoForCausalLM": _HfExamplesInfo("shanearora/Flex-reddit-2x7B-1T"),
+ "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),

Ignore this, I misunderstood

@hmellor (Member) commented Sep 16, 2025:

Is the Transformers impl compatible with the Transformers backend? If yes then the vLLM PR can be reduced to 3 lines, similar to #22665

@2015aroras (Contributor Author) commented Sep 16, 2025:

> Is the Transformers impl compatible with the Transformers backend? If yes then the vLLM PR can be reduced to 3 lines, similar to #22665

It probably is compatible (I just put out the transformers PR), but the performance is quite poor compared to a native implementation like this. I see that there's ongoing work to address MoE performance, so if you'd prefer the 3-line approach then I can try to do that instead.

@hmellor (Member) commented Sep 17, 2025:

Ah, I didn't see this was an MoE model. You're right that the performance of MoEs on the Transformers backend is not good today.

In my PoC PR #22650 I managed to get performance to be within 1% of the native vLLM implementation 🚀

@jeejeelee (Collaborator):
So let's move forward

@jeejeelee requested a review from Isotr0py September 18, 2025 10:57
Comment on lines 463 to 464
class FlexOlmoForCausalLM(nn.Module, SupportsPP):

Member (@Isotr0py, Sep 18, 2025):
I haven't had bandwidth to review this PR in detail yet; I will take a look ASAP (probably tomorrow).

But at a glance, it seems FlexOlmo's implementation is quite similar to the existing OlMoE?

Contributor Author:
The main difference is that RMS norm is applied after attention/feedforward rather than before.
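
A toy illustration of that difference (not the PR's actual modules): pre-norm applies RMSNorm to the sublayer input, while the FlexOlmo layout normalizes the sublayer output before the residual add, which is why post_attention_layernorm and post_feedforward_layernorm appear in the activation dump above.

import torch
from torch import nn


def pre_norm_block(x: torch.Tensor, norm: nn.Module, sublayer: nn.Module) -> torch.Tensor:
    # OlMoE-style: normalize the input, run the sublayer, add the residual.
    return x + sublayer(norm(x))


def post_norm_block(x: torch.Tensor, norm: nn.Module, sublayer: nn.Module) -> torch.Tensor:
    # FlexOlmo-style: run the sublayer on the raw input, normalize its output,
    # then add the residual.
    return x + norm(sublayer(x))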

Collaborator:
If that's the case, we can integrate OlMoE to avoid redundant code, and we can refer to motif's implementation approach.

Contributor Author:
I'm not seeing how motif's implementation relates to our situation, but it sounds like you want FlexOlmo's implementation to be done within olmoe.py (which sounds doable enough).

Collaborator:
motif inherits from llama and only implements its decoder layer. Your model would integrate with olmoe, implement your decoder layer, and modify some of olmoe's code.

Contributor Author (@2015aroras, Sep 19, 2025):
Done in 5b3c09e and 2ba0a2d. Modified both olmoe and FlexOlmo to leverage inheritance. Both models appear to work fine; not sure if this is what you had in mind.
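
A hypothetical outline of the inheritance-based layout described here (class names other than FlexOlmoForCausalLM are assumptions; the committed code may be organized differently):

from vllm.model_executor.models.olmoe import (OlmoeDecoderLayer, OlmoeForCausalLM,
                                              OlmoeModel)


class FlexOlmoDecoderLayer(OlmoeDecoderLayer):
    # Override the forward pass (and the norm modules) so RMSNorm is applied to
    # the attention/feedforward outputs rather than their inputs.
    ...


class FlexOlmoModel(OlmoeModel):
    # Reuse OlMoE's embedding/attention/MoE stack but build it with
    # FlexOlmoDecoderLayer.
    ...


class FlexOlmoForCausalLM(OlmoeForCausalLM):
    # Same head and weight-loading behaviour as OlMoE, wired to FlexOlmoModel.
    ...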

Collaborator:
That's it! You got it right

@2015aroras (Contributor Author):
Gentle ping

| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `FlexOlmoForCausalLM` | FlexOlmo | TBA | | ✅︎ | ✅︎ |

Collaborator:
QQ: Is this model still not announced?

Member (@Isotr0py) left a comment:
Sorry for the delay. LGTM!
