6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -96,6 +96,7 @@
    VLLM_DP_SIZE: int = 1
    VLLM_DP_MASTER_IP: str = ""
    VLLM_DP_MASTER_PORT: int = 0
    VLLM_USE_AITER_MOE: bool = False


def get_default_cache_root():
@@ -630,6 +631,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
    # Whether to use S3 path for model loading in CI via RunAI Streamer
    "VLLM_CI_USE_S3":
    lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",

    # Whether to use AITER fused MoE kernels on ROCm
    "VLLM_USE_AITER_MOE":
    lambda: (os.environ.get("VLLM_USE_AITER_MOE", "False").lower() in
             ("true", "1")),
}

# end-env-vars-definition
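One note on the parser above, illustrated outside the diff: only `"true"` (case-insensitive) and `"1"` enable the flag, so a value like `VLLM_USE_AITER_MOE=yes` leaves it off. A minimal sketch of that rule:

```python
# Minimal sketch of the flag's parsing rule, mirroring the lambda above:
# only "true" (any case) and "1" enable AITER MoE.
import os

for value in ("1", "true", "True", "0", "yes"):
    os.environ["VLLM_USE_AITER_MOE"] = value
    enabled = os.environ.get("VLLM_USE_AITER_MOE", "False").lower() in ("true",
                                                                        "1")
    print(value, "->", enabled)  # "yes" prints False: it is not recognized
```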
47 changes: 46 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -18,6 +18,7 @@
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)
is_hip = current_platform.is_rocm()


@triton.jit
@@ -1164,8 +1165,52 @@
                  w2_zp: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
                  a2_scale: Optional[torch.Tensor] = None,
                  block_shape: Optional[List[int]] = None) -> torch.Tensor:
                  block_shape: Optional[List[int]] = None,
                  expert_mask: Optional[torch.Tensor] = None
                  ) -> torch.Tensor:

    if is_hip and envs.VLLM_USE_AITER_MOE:
        import aiter
        from aiter.fused_moe_bf16_asm import moe_sorting_ck

        # The AITER path requires block-scale FP8 weights and scales.
        assert block_shape is not None
        assert w1_scale is not None and w2_scale is not None

        local_E = E = w1.shape[0]
        if expert_mask is not None:
            E = expert_mask.numel()
        topk = topk_ids.shape[1]
        model_dim = w1.shape[-1]
        dtype = hidden_states.dtype
        scale_blk_k = block_shape[1]

        # Sort tokens by expert so each expert's rows are contiguous for the
        # asm kernel; expert_mask restricts the sort to local experts.
        (
            sorted_token_ids,
            sorted_weight_buf,
            sorted_expert_ids,
            num_valid_ids,
            out_asm,
        ) = moe_sorting_ck(
            topk_ids, topk_weights, E, model_dim, dtype, expert_mask=expert_mask
        )

        # Quantize activations to FP8 with per-token-group scales of width
        # scale_blk_k, then run the fused block-scale MoE kernel.
        a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
        aiter.fmoe_fp8_blockscale_g1u1(
            out_asm,
            a1,
            w1,
            w2,
            sorted_token_ids,
            sorted_weight_buf,
            sorted_expert_ids,
            num_valid_ids,
            topk,
            w1_scale.view(local_E, -1),
            w2_scale.view(local_E, -1),
            a1_scale.t().contiguous(),
            block_shape[0],
            block_shape[1],
            None,
        )
        return out_asm

    # Not using AITER MoE: fall back to the Triton fused-experts path.
    if inplace:
        torch.ops.vllm.inplace_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, activation,
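A word on the expert-parallel bookkeeping in the branch above: `w1.shape[0]` counts only the experts resident on this rank (`local_E`), while `expert_mask`, when provided, has one entry per global expert (`E`). A hypothetical sketch of that layout, with made-up sizes and ids that are illustrative rather than from the diff:

```python
# Hypothetical illustration of the local_E / E relationship assumed by the
# AITER branch: expert_mask flags which of the E global experts live on this
# rank; w1/w2 store only those local_E experts.
import torch

E = 8                                # global expert count
local_ids = [2, 3, 6]                # experts owned by this rank
expert_mask = torch.zeros(E, dtype=torch.int32)
expert_mask[local_ids] = 1

local_E = len(local_ids)             # == w1.shape[0] in the kernel path
assert expert_mask.numel() == E      # == the E passed to moe_sorting_ck
```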
16 changes: 15 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -29,6 +29,7 @@
else:
    fused_moe_pallas = None  # type: ignore
logger = init_logger(__name__)
is_hip = current_platform.is_rocm()


class FusedMoeWeightScaleSupported(Enum):
@@ -287,6 +288,8 @@ def __init__(
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        num_shared_experts: Optional[int] = 0,
        routed_scaling_factor: Optional[float] = 1.0,
    ):
        super().__init__()

@@ -364,6 +367,7 @@ def __init__(
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)
        self.aiter_shuffled = False

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
@@ -669,6 +673,14 @@ def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        if is_hip and envs.VLLM_USE_AITER_MOE:
            from aiter.ops.shuffle import shuffle_weight

            # Shuffle weights into AITER's (16, 16) tiled layout once,
            # lazily: checkpoint weights are loaded only after __init__.
            if not self.aiter_shuffled:
                self.w13_weight.data = shuffle_weight(self.w13_weight,
                                                      (16, 16))
                self.w2_weight.data = shuffle_weight(self.w2_weight, (16, 16))
                self.aiter_shuffled = True

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
@@ -698,7 +710,9 @@ def forward(self, hidden_states: torch.Tensor,
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:
            num_experts: int,
            num_shared_experts: Optional[int] = 0,
    ) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
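The `aiter_shuffled` guard above implements a lazy, one-shot layout conversion: the shuffle mutates the weights in place, and it cannot run in `__init__` because the checkpoint weights are loaded later. A minimal sketch of the pattern, with a stand-in transform rather than AITER's `shuffle_weight`:

```python
# Minimal sketch of the lazy one-shot reshuffle pattern used in forward():
# an in-place, non-idempotent layout change is guarded by a flag so later
# calls become no-ops. The transform below is a stand-in, not AITER's.
import torch

class Weights:
    def __init__(self, w: torch.Tensor):
        self.w = w
        self.shuffled = False

    def maybe_shuffle(self) -> None:
        if self.shuffled:        # already converted: skip on later calls
            return
        self.w = self.w.contiguous()  # stand-in for shuffle_weight(w, (16, 16))
        self.shuffled = True

ws = Weights(torch.randn(4, 8))
ws.maybe_shuffle()
ws.maybe_shuffle()               # second call does nothing
```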
9 changes: 7 additions & 2 deletions vllm/model_executor/models/deepseek_v2.py
@@ -135,7 +135,10 @@ def __init__(
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias)
            e_score_correction_bias=self.gate.e_score_correction_bias,
            num_shared_experts=config.n_shared_experts,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
@@ -715,7 +718,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)
            num_experts=self.config.n_routed_experts,
            num_shared_experts=self.config.n_shared_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()