diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index a90a71159f72..079950d5c9b8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1416,8 +1416,8 @@ def get_expert_weights(self) -> Iterable[torch.Tensor]:
         return [
             weight.view(self.local_num_experts, -1)
             for name, weight in weights
-            if name not in NON_EXPERT_WEIGHTS
-            and not name.startswith("_shared_experts.")
+            if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size(
+                []) and not name.startswith("_shared_experts.")
         ]
 
     def set_eplb_state(
diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py
index a74a44bc2b51..af99d1213077 100644
--- a/vllm/model_executor/models/hunyuan_v1.py
+++ b/vllm/model_executor/models/hunyuan_v1.py
@@ -23,7 +23,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only HunYuan model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
 
 import regex as re
@@ -33,8 +34,8 @@
 
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -56,7 +57,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_layers)
 
@@ -355,10 +356,16 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         layer_id: int = -1,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -379,8 +386,23 @@ def __init__(
             config.moe_intermediate_size, int) else
             config.moe_intermediate_size[layer_id])
 
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        eplb_config = vllm_config.parallel_config.eplb_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         self.experts = FusedMoE(
-            num_experts=config.num_experts,
+            num_experts=self.n_routed_experts,
             top_k=top_k,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
@@ -388,6 +410,8 @@
             renormalize=top_k > 1,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )
 
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -446,6 +470,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         layer_id: int = -1,
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         assert layer_id >= 0
@@ -509,6 +534,7 @@ def __init__(
                 quant_config=quant_config,
                 layer_id=layer_id,
                 prefix=f"{prefix}.mlp",
+                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = HunYuanMLP(
@@ -562,6 +588,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        eplb_config = vllm_config.parallel_config.eplb_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        self.num_redundant_experts = eplb_config.num_redundant_experts
 
         self.config = config
         self.quant_config = quant_config
@@ -588,6 +617,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
+                enable_eplb=enable_eplb,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -674,6 +704,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
+                num_redundant_experts=self.num_redundant_experts,
             )
         else:
             return []
@@ -803,25 +834,43 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    # this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(
                         param,
                         loaded_weight,
-                        name,
+                        name_mapped,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        return_success=True,
                     )
-                    break
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
                     # Remapping the name of FP8 kv-scale.
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
@@ -841,7 +890,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loaded_params
 
 
-class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
+class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -882,6 +931,64 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.lm_head = PPMissingLayer()
 
+        # Set MoE hyperparameters
+        self.expert_weights = []
+        self.num_expert_groups = 1
+        self.moe_layers: list[FusedMoE] = []
+        example_layer = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, HunYuanDecoderLayer)
+            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
+                example_layer = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_layer is None:
+            raise RuntimeError("No HunYuanMoE layer found in model.layers.")
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            self.expert_weights.append(layer.get_expert_weights())
+            # Register the expert weights.
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def forward(
         self,
         input_ids: torch.Tensor,