Commit 3eac141

Enable AITER MoE for ROCm

Signed-off-by: qli88 <[email protected]>
1 parent 67fc426

File tree: 4 files changed, +74 and -4 lines

  vllm/envs.py
  vllm/model_executor/layers/fused_moe/fused_moe.py
  vllm/model_executor/layers/fused_moe/layer.py
  vllm/model_executor/models/deepseek_v2.py

vllm/envs.py

Lines changed: 6 additions & 0 deletions

@@ -96,6 +96,7 @@
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_USE_AITER_MOE: bool = False


 def get_default_cache_root():

@@ -630,6 +631,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3":
     lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
+
+    # flag to control if vllm should use AITER MoE
+    "VLLM_USE_AITER_MOE":
+    lambda: (os.environ.get("VLLM_USE_AITER_MOE", "False").lower() in
+             ("true", "1")),
 }

 # end-env-vars-definition
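
For reference, the flag is parsed like the other entries in environment_variables: the string value of VLLM_USE_AITER_MOE is lowered and compared against ("true", "1"). A minimal sketch of opting in from Python (assumes the usual lazy attribute lookup in vllm.envs; nothing here besides VLLM_USE_AITER_MOE is added by this commit):

import os

# Must be set before the value is first read; "1", "true", and "True" all enable it.
os.environ["VLLM_USE_AITER_MOE"] = "1"

import vllm.envs as envs
print(envs.VLLM_USE_AITER_MOE)  # -> True; the ROCm check (is_hip) is applied separately at call sites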

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 46 additions & 1 deletion

@@ -18,6 +18,7 @@
 from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)
+is_hip = current_platform.is_rocm()


 @triton.jit

@@ -1164,8 +1165,52 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
+                  block_shape: Optional[List[int]] = None,
+                  expert_mask: Optional[torch.Tensor] = None
+                  ) -> torch.Tensor:
+
+    if is_hip and envs.VLLM_USE_AITER_MOE:
+        from aiter.fused_moe_bf16_asm import moe_sorting_ck
+
+        local_E = E = w1.shape[0]
+        if expert_mask is not None:
+            E = expert_mask.numel()
+        topk = topk_ids.shape[1]
+        model_dim = w1.shape[-1]
+        dtype = hidden_states.dtype
+        scale_blk_k = block_shape[1]
+
+        (
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            out_asm,
+        ) = moe_sorting_ck(
+            topk_ids, topk_weights, E, model_dim, dtype, expert_mask=expert_mask
+        )
+
+        a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
+        aiter.fmoe_fp8_blockscale_g1u1(
+            out_asm,
+            a1,
+            w1,
+            w2,
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            topk,
+            w1_scale.view(local_E, -1),
+            w2_scale.view(local_E, -1),
+            a1_scale.t().contiguous(),
+            block_shape[0],
+            block_shape[1],
+            None,
+        )
+        return out_asm

+    # Not using AITER MoE
     if inplace:
         torch.ops.vllm.inplace_fused_experts(
             hidden_states, w1, w2, topk_weights, topk_ids, activation,
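
The new expert_mask argument lets this path distinguish the global expert count E from the local_E experts whose weights actually live on this rank, which is what moe_sorting_ck consumes under expert parallelism. A hedged sketch of one way such a mask could be built; the helper name and the contiguous per-rank slicing are illustrative, not part of this commit:

import torch

def make_expert_mask(global_num_experts: int, ep_rank: int, ep_size: int) -> torch.Tensor:
    # One entry per global expert: 1 where this rank owns the expert, 0 elsewhere.
    # fused_experts would then see E = expert_mask.numel() global experts while
    # w1/w2 only hold local_E = w1.shape[0] expert weights.
    experts_per_rank = global_num_experts // ep_size
    mask = torch.zeros(global_num_experts, dtype=torch.int32)
    mask[ep_rank * experts_per_rank:(ep_rank + 1) * experts_per_rank] = 1
    return mask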

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 1 deletion

@@ -29,6 +29,7 @@
 else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
+is_hip = current_platform.is_rocm()


 class FusedMoeWeightScaleSupported(Enum):

@@ -287,6 +288,8 @@ def __init__(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        num_shared_experts: Optional[int] = 0,
+        routed_scaling_factor: Optional[float] = 1.0,
     ):
         super().__init__()

@@ -364,6 +367,7 @@ def __init__(
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
+        self.aiter_shuffled = False

     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,

@@ -669,6 +673,14 @@ def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         assert self.quant_method is not None

+        if is_hip and envs.VLLM_USE_AITER_MOE:
+            from aiter.ops.shuffle import shuffle_weight
+
+            if not self.aiter_shuffled:
+                self.w13_weight.data = shuffle_weight(self.w13_weight, (16, 16))
+                self.w2_weight.data = shuffle_weight(self.w2_weight, (16, 16))
+                self.aiter_shuffled = True
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,

@@ -698,7 +710,9 @@ def forward(self, hidden_states: torch.Tensor,
     def make_expert_params_mapping(
             cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
             ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, str]]:
+            num_experts: int,
+            num_shared_experts: Optional[int] = 0,
+    ) -> List[Tuple[str, str, int, str]]:

         return [
             # (param_name, weight_name, expert_id, shard_id)
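
The aiter_shuffled flag is a shuffle-once guard: checkpoints load in the stock layout, and the weights are re-laid-out for AITER only on the first forward pass after the flag is enabled. A standalone sketch of the same lazy pattern (toy module, not vLLM code; the contiguous copy stands in for aiter.ops.shuffle.shuffle_weight):

import torch

class ShuffleOnceLinear(torch.nn.Module):
    """Toy module illustrating the one-time weight re-layout used above."""

    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.shuffled = False  # analogous to self.aiter_shuffled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.shuffled:
            # Runs exactly once; later forwards skip straight to the matmul.
            self.weight.data = self.weight.data.contiguous()
            self.shuffled = True
        return x @ self.weight.t()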

vllm/model_executor/models/deepseek_v2.py

Lines changed: 7 additions & 2 deletions

@@ -135,7 +135,10 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+            e_score_correction_bias=self.gate.e_score_correction_bias,
+            num_shared_experts=config.n_shared_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *

@@ -715,7 +718,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
+            num_experts=self.config.n_routed_experts,
+            num_shared_experts=self.config.n_shared_experts,
+        )

         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
