[EPLB] Add EPLB support for hunyuan_v1 #23078
Changes from all commits: 95d19de, 9a1e959, 0106052, 5890066, 666c0b3, ee058b6, b64b2c3, 251c269, 9c7dbac, 97daa62, 44a4d79, b69f3eb, ba251a0
@@ -23,7 +23,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only HunYuan model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union

 import regex as re
@@ -33,8 +34,8 @@

 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -56,7 +57,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_layers)

@@ -355,10 +356,16 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         layer_id: int = -1,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()

+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -379,15 +386,32 @@ def __init__(
                              config.moe_intermediate_size, int) else
                              config.moe_intermediate_size[layer_id])

+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        eplb_config = vllm_config.parallel_config.eplb_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         self.experts = FusedMoE(
-            num_experts=config.num_experts,
+            num_experts=self.n_routed_experts,
             top_k=top_k,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
             reduce_results=False,
             renormalize=top_k > 1,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )

         self.gate = ReplicatedLinear(config.hidden_size,
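For reference, the partitioning bookkeeping added above works out as follows. A standalone sketch with illustrative numbers (not taken from this PR):

```python
# Standalone sketch of the expert partitioning done in __init__ above.
# Numbers are made up for illustration: 64 routed experts, 4 redundant
# replicas, expert parallelism over 4 ranks.
n_logical_experts = 64
n_redundant_experts = 4
ep_size = 4

n_physical_experts = n_logical_experts + n_redundant_experts    # 68
n_local_physical_experts = n_physical_experts // ep_size        # 17 per rank

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"rank {ep_rank}: physical experts [{start}, {end})")
# rank 0: physical experts [0, 17)
# rank 1: physical experts [17, 34)
# rank 2: physical experts [34, 51)
# rank 3: physical experts [51, 68)
```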
@@ -446,6 +470,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         layer_id: int = -1,
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         assert layer_id >= 0
@@ -509,6 +534,7 @@ def __init__(
                 quant_config=quant_config,
                 layer_id=layer_id,
                 prefix=f"{prefix}.mlp",
+                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = HunYuanMLP(
@@ -562,6 +588,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        eplb_config = vllm_config.parallel_config.eplb_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        self.num_redundant_experts = eplb_config.num_redundant_experts

         self.config = config
         self.quant_config = quant_config
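Both values are read from the parallel config, so the model simply follows whatever the engine was configured with. A rough sketch of populating them programmatically — the attribute paths mirror this diff, but constructing `VllmConfig()` with defaults and bypassing the usual engine arguments is an assumption, not something shown in this PR:

```python
# Sketch only: attribute paths follow the diff above; building the config
# directly like this (instead of via engine args) is an assumption.
from vllm.config import VllmConfig

vllm_config = VllmConfig()
vllm_config.parallel_config.enable_eplb = True
vllm_config.parallel_config.eplb_config.num_redundant_experts = 16
```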
@@ -588,6 +617,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
+                enable_eplb=enable_eplb,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -674,6 +704,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
+                num_redundant_experts=self.num_redundant_experts,
             )
         else:
             return []
@@ -803,25 +834,43 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    # this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(
                         param,
                         loaded_weight,
-                        name,
+                        name_mapped,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        return_success=True,
                     )
-                    break
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
                     # Remapping the name of FP8 kv-scale.
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
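The reworked loop relies on `weight_loader(..., return_success=True)` reporting whether the expert actually lives on this rank, and on Python's for/else so that an expert weight with no local replica is skipped instead of falling through to the generic loading path. A tiny self-contained illustration of that control flow (the loader here is a stand-in, not a vLLM API):

```python
# Illustration of the for/else + success-flag pattern used above.
# try_load is a stand-in, not a vLLM API.
def try_load(candidates: list[int]) -> None:
    for cand in candidates:
        success = cand > 0   # stand-in for weight_loader(..., return_success=True)
        if success:
            print(f"loaded via candidate {cand}")
            break
    else:
        # Runs only if the loop never hit `break`, i.e. nothing loaded locally.
        print("no candidate succeeded on this rank; skip")

try_load([0, 0, 3])   # loaded via candidate 3
try_load([0, 0, 0])   # no candidate succeeded on this rank; skip
```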
@@ -841,7 +890,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loaded_params


-class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
+class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -882,6 +931,64 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.lm_head = PPMissingLayer()

+        # Set MoE hyperparameters
+        self.expert_weights = []
+        self.num_expert_groups = 1
+        self.moe_layers: list[FusedMoE] = []
+        example_layer = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, HunYuanDecoderLayer)
+            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
+                example_layer = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_layer is None:
+            raise RuntimeError("No HunYuanMoE layer found in model.layers.")
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            self.expert_weights.append(layer.get_expert_weights())
+            # Register the expert weights.
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def forward(
         self,
         input_ids: torch.Tensor,
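For intuition, the arithmetic in `update_physical_experts_metadata` pairs with the assertion that the per-rank expert count never changes: only the total physical count (and hence the number of redundant replicas) moves. A small numeric sketch with made-up values:

```python
# Numeric sketch of update_physical_experts_metadata above (values made up).
# The assert pins the per-rank count, so a change in the total physical
# expert count reflects a change in how many EP ranks host experts.
num_logical_experts = 64
num_local_physical_experts = 17                  # fixed per rank (the assert)

# Previously 4 ranks: 4 * 17 = 68 physical experts, 4 of them redundant.
# After rescaling to 5 ranks, the hook would be called with:
num_physical_experts = 5 * num_local_physical_experts               # 85
num_redundant_experts = num_physical_experts - num_logical_experts  # 21
```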
Review comment: The `self.expert_weights` list is appended to in this loop without being cleared first. If `set_eplb_state` is called multiple times (e.g., during re-initialization or complex state updates), this will lead to an accumulation of expert weights. This is likely not the intended behavior and can cause issues during expert rebalancing. You should clear the list at the beginning of this method to ensure it only contains the weights from the current state, similar to how it's done in other MoE model implementations in this repository.
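A minimal version of the suggested change, clearing the list before re-registering (a sketch of the reviewer's suggestion, not code from the PR):

```python
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    # Reset so repeated calls do not accumulate stale weight references.
    self.expert_weights.clear()
    for layer_idx, layer in enumerate(self.moe_layers):
        self.expert_weights.append(layer.get_expert_weights())
        # Register the expert weights.
        layer.set_eplb_state(
            moe_layer_idx=layer_idx,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )
```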