
Conversation

666even666 (Contributor) commented Aug 18, 2025

Purpose

This pull request adds EPLB support for the hunyuan_v1 model, which helps improve overall throughput during LLM serving.

#20468

Test Plan

import json
import os
import argparse
from vllm import LLM, SamplingParams

prompt = "Where is the capital of Russia? Please provide a brief explanation."

RESULT_FILE = "eplb_test_output.json"

sampling_params = SamplingParams(
    temperature=0.0,
    top_p=1.0,
    top_k=1,
    max_tokens=100
)

def run_inference(enable_eplb: bool, num_redundant_experts: int = 0):
    print(f"Running inference with EPLB={enable_eplb}, redundant experts={num_redundant_experts}")

    llm = LLM(
        model="tencent/Hunyuan-A13B-Instruct-FP8",
        tensor_parallel_size=2,
        enable_expert_parallel=True,
        enable_eplb=enable_eplb,
        num_redundant_experts=num_redundant_experts if enable_eplb else 0,
        eplb_window_size=1000,
        eplb_step_interval=100,
        eplb_log_balancedness=True,
        enforce_eager=True,
        trust_remote_code=True
    )

    result = llm.generate([prompt], sampling_params)
    output_text = result[0].outputs[0].text.strip()

    print("Output:")
    print(output_text)
    print("-" * 50)

    return output_text

def save_result(key: str, value: str):
    if os.path.exists(RESULT_FILE):
        with open(RESULT_FILE, "r") as f:
            results = json.load(f)
    else:
        results = {}

    results[key] = value

    with open(RESULT_FILE, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Output saved to {RESULT_FILE}")

def load_results():
    if os.path.exists(RESULT_FILE):
        with open(RESULT_FILE, "r") as f:
            return json.load(f)
    return {}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["eplb", "normal", "compare"], required=True)
    args = parser.parse_args()

    if args.mode == "eplb":
        outputs = run_inference(enable_eplb=True, num_redundant_experts=32)
        save_result("eplb", outputs)
    elif args.mode == "normal":
        outputs = run_inference(enable_eplb=False)
        save_result("normal", outputs)
    elif args.mode == "compare":
        # Handle the "compare" choice declared above.
        results = load_results()
        print("EPLB output:  ", results.get("eplb"))
        print("Normal output:", results.get("normal"))
        print("Outputs match:", results.get("eplb") == results.get("normal"))

Execute:

python eplb_test.py --mode eplb

Test Result

after rearrange

(Optional) Documentation Update




👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds EPLB support for the hunyuan_v1 model, which is a great enhancement for improving throughput in MoE models. The changes primarily involve adapting the model to the MixtureOfExperts interface and handling the complexities of expert weight loading with redundancy. The implementation is mostly solid, but I've identified a couple of issues in the new interface methods that could lead to incorrect behavior, especially in dynamic environments. One is a potential issue with state accumulation in set_eplb_state, and the other is a critical assertion in update_physical_experts_metadata that would prevent dynamic scaling of workers. Please see my detailed comments for suggestions on how to fix these.
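
For readers unfamiliar with that interface, here is a rough sketch of what adapting a model to MixtureOfExperts involves, reconstructed only from the methods discussed in this thread (expert_weights, set_eplb_state, update_physical_experts_metadata); the actual protocol in vLLM contains more fields, so treat the names and signatures below as illustrative rather than authoritative:

from typing import Protocol

import torch


class MixtureOfExperts(Protocol):
    # One entry per MoE layer, each holding that layer's expert weight
    # tensors so EPLB can move physical experts between ranks.
    expert_weights: list

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Register the shared EPLB state tensors with every MoE layer."""
        ...

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Refresh expert counts after a rebalance or a change in EP size."""
        ...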

gemini-code-assist bot (Contributor) commented:

critical

This assertion prevents dynamic scaling of the number of expert parallel workers. The number of local physical experts can change during runtime, for example, when scaling the cluster up or down. This assertion will fail in such scenarios, preventing the model from adapting to the new configuration. This should be removed to allow for dynamic changes in the number of local experts.
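
To make the concern concrete, here is a hypothetical before/after sketch of the flagged pattern (the actual body of update_physical_experts_metadata in this PR may differ; only the assertion is the point):

def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    # Problematic: hard-codes the assumption that the per-rank expert count
    # never changes, so scaling the EP worker count up or down trips it.
    # assert self.num_local_physical_experts == num_local_physical_experts

    # Preferred: simply adopt the new configuration.
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts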

Comment on lines +958 to +972
gemini-code-assist bot (Contributor) commented:

high

The self.expert_weights list is appended to in this loop without being cleared first. If set_eplb_state is called multiple times (e.g., during re-initialization or complex state updates), this will lead to an accumulation of expert weights. This is likely not the intended behavior and can cause issues during expert rebalancing. You should clear the list at the beginning of this method to ensure it only contains the weights from the current state, similar to how it's done in other MoE model implementations in this repository.

Suggested change

(current)

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        for layer_idx, layer in enumerate(self.moe_layers):
            self.expert_weights.append(layer.get_expert_weights())
            # Register the expert weights.
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

(suggested)

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        self.expert_weights.clear()
        for layer_idx, layer in enumerate(self.moe_layers):
            self.expert_weights.append(layer.get_expert_weights())
            # Register the expert weights.
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

666even666 (Contributor Author) commented:

@abmfy This PR adds EPLB support to hunyuan. I used tencent/Hunyuan-A13B-Instruct-FP8 for testing and had to skip w13_input_scale and w2_input_scale in get_expert_weights as they are scalars and cannot be reshaped to (self.local_num_experts, -1). Let me know if you have other suggestions :)

abmfy (Member) commented Aug 20, 2025

Thanks for the contribution!

I haven’t had time to run an accuracy test yet, but you shouldn’t skip the scale tensors -- otherwise, the weights and scales of the physical experts will become misaligned after a rearrangement.

If I recall correctly, the scale tensor looks like this:

w13_input_scale = torch.nn.Parameter(
    torch.ones(num_experts, dtype=torch.float32),
    requires_grad=False,
)

So it’s still a tensor with a dimension of num_local_experts.

Could you attach an accuracy test on GSM8K after fixing this issue?
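
As a side note, the misalignment described above is easy to picture with toy tensors (illustrative shapes and names only, not code from the PR):

import torch

num_local_experts, hidden = 2, 4
w13_weight = torch.randn(num_local_experts, hidden)  # per-expert weights
w13_weight_scale = torch.tensor([0.5, 2.0])          # per-expert scales

perm = torch.tensor([1, 0])        # EPLB swaps the two physical experts
w13_weight = w13_weight[perm]      # the weights are rearranged...
# ...but if the scale tensor is skipped, expert 0 is now dequantized with
# expert 1's scale: exactly the weight/scale misalignment abmfy warns about.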

666even666 (Contributor Author) commented Aug 21, 2025

Thanks for the review @abmfy! I do see that w13_input_scale is set to torch.ones(num_experts, dtype=torch.float32) in fp8.py. However, when I print out tensor shapes in layer.py, it shows that w13_input_scale is a scalar, and if I don't skip w13_input_scale, I always end up with an error complaining that a scalar cannot be reshaped to (self.local_num_experts, -1).

I will try to debug this issue. Let me know if you have any insights in the meantime :)

Comment on lines 1347 to 1349
abmfy (Member) commented:

Hi @666even666, it looks like in fp8.py#L797-L800 the input_scales are converted into a single-element tensor when the activation scheme is static, which is the case for Hunyuan-FP8 models.

In this case, these scales don’t need to be transferred since they’re shared across the whole layer. However, we shouldn’t skip them unconditionally, as that might break models using the dynamic activation scheme. I think we should instead add a check to skip only one-element tensors (with shape []). What do you think?
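
A minimal sketch of the check being proposed, with illustrative names (only the one-element filter is the point; the real get_expert_weights in the PR may organize this differently):

def get_expert_weights(self):
    weights = []
    for _, param in self.named_parameters():
        # Under the static activation scheme, scales such as w13_input_scale
        # collapse to a single-element tensor shared by the whole layer, so
        # they cannot (and need not) be viewed per expert. Skip them.
        if param.numel() == 1:
            continue
        weights.append(param.view(self.local_num_experts, -1))
    return weights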

666even666 (Contributor Author) commented Aug 25, 2025

Thanks a lot for spotting it @abmfy! I have added the scalar check as suggested. BTW, is there a recommended way to run an accuracy test on GSM8K (e.g. an existing test script/process)?

abmfy (Member) commented Aug 25, 2025

I've been using this one and comparing the numbers before/after enabling EPLB:

PORT=${PORT:-8080}
MODEL="MODEL_NAME_HERE"

OPENAI_API_KEY=EMPTY lm_eval \
    --model local-completions \
    --model_args "base_url=http://localhost:$PORT/v1/completions,model=$MODEL,num_concurrent=4096,trust_remote_code=True" \
    --tasks gsm8k \
    --output_path results/gsm8k.json

You'll need to uv pip install lm_eval[api] first

666even666 (Contributor Author) commented:

Thank you so much for the suggestion! I have attached test results below :)

666even666 (Contributor Author) commented:

Hi @abmfy

Here is the accuracy test result on GSM8K

Without EPLB
(screenshot: hunyuaneplb)

With EPLB
(screenshot: hunyuan)

abmfy (Member) left a comment

LGTM! Thanks for the contribution 🎉

One last thing: have you confirmed that the balancedness score increases with log_balancedness: true set in EplbConfig? If so, I think we're good to go and can merge once CI passes.

cc @simon-mo Could you please approve to trigger CI? Thank you!
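
For context on the question above, a sketch of the balancedness notion usually meant here; the exact statistic vLLM logs may differ, but a common definition is mean per-rank load divided by max per-rank load, so 1.0 means perfectly balanced:

import torch

def balancedness(per_rank_load: torch.Tensor) -> float:
    # per_rank_load: tokens (or expert activations) handled by each EP rank
    # within the current EPLB window.
    load = per_rank_load.float()
    return (load.mean() / load.max()).item()

print(balancedness(torch.tensor([120, 80])))   # ~0.83 -> imbalanced
print(balancedness(torch.tensor([100, 100])))  # 1.0   -> balanced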


mergify bot commented Sep 4, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @666even666.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 4, 2025
@mergify mergify bot removed the needs-rebase label Sep 4, 2025
666even666 (Contributor Author) commented Sep 4, 2025

@abmfy Thanks for the review! From the screenshots I took, I do see the balancedness score increase, but not monotonically. Do you think this is expected? (I am using an EP size of 2.)
(screenshots: rearrange, test)

abmfy (Member) commented Sep 5, 2025


The figures here aren’t very informative since we only used a single prompt with EP=2. Ideally, we should run benchmarks with larger batches and a larger-scale EP setup; otherwise, the improvements brought by EPLB may not be very apparent.

That said, I believe the model adaptation is working well, so we’re good to proceed.

Thanks again for the contribution! 🎉

Signed-off-by: Yiwen Chen <[email protected]>

mergify bot commented Sep 5, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @666even666.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 5, 2025
666even666 (Contributor Author) commented:

Ahh I see. That makes a lot of sense. Thanks a lot @abmfy for your help and support!!

@mergify mergify bot removed the needs-rebase label Sep 5, 2025
666even666 (Contributor Author) commented:

Looks like we still need another review to merge the PR. @abmfy Do you know someone who can help review the change? :)

abmfy (Member) commented Sep 15, 2025

@simon-mo Could you please take a look and help merge this if it doesn’t affect other parts of the codebase? Thanks a lot!

@simon-mo simon-mo enabled auto-merge (squash) September 16, 2025 00:27
@github-actions github-actions bot added the ready label Sep 16, 2025
@666even666 666even666 requested a review from mgoin as a code owner September 16, 2025 02:16
@robertgshaw2-redhat robertgshaw2-redhat changed the title [Feature][EPLB] Add EPLB support for hunyuan_v1 [EPLB] Add EPLB support for hunyuan_v1 Sep 16, 2025
@simon-mo simon-mo merged commit 9d8a2d8 into vllm-project:main Sep 18, 2025
54 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 18, 2025
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025

Labels: eplb, ready