
Commit 72a5101

wenscarl authored and tlrmchlsmth committed
Support mnnvl all2allv from Flashinfer (#21003)
Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
1 parent 7d9f44a commit 72a5101

File tree

10 files changed: +410 -40 lines changed

tests/kernels/moe/modular_kernel_tools/mk_objects.py

Lines changed: 3 additions & 2 deletions
@@ -222,7 +222,8 @@ def expert_info(kind) -> ExpertInfo:
     from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
         FlashInferExperts)
     from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-        FlashInferCutlassMoEPrepareAndFinalize)
+        FlashInferCutlassMoEPrepareAndFinalize,
+        create_flashinfer_prepare_finalize)

     register_prepare_and_finalize(
         FlashInferCutlassMoEPrepareAndFinalize,
@@ -373,7 +374,7 @@ def make_prepare_finalize(
         assert prepare_finalize is not None
         return prepare_finalize
     elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
-        return FlashInferCutlassMoEPrepareAndFinalize(
+        return create_flashinfer_prepare_finalize(
             use_dp=moe.moe_parallel_config.dp_size > 1)
     else:
         return MoEPrepareAndFinalizeNoEP()
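The call sites now go through a factory rather than the constructor. The factory's body is not part of this diff; below is a minimal sketch of the shape the call sites imply (the fallback logic is an assumption, not the commit's actual code):

def create_flashinfer_prepare_finalize(
        use_dp: bool) -> FlashInferCutlassMoEPrepareAndFinalize:
    # Assumed behavior: centralize construction so the factory can select
    # an all2allv-capable prepare/finalize when that backend is enabled,
    # falling back to the plain CUTLASS variant otherwise.
    return FlashInferCutlassMoEPrepareAndFinalize(use_dp=use_dp)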

vllm/distributed/device_communicators/all2all.py

Lines changed: 110 additions & 15 deletions
@@ -10,9 +10,15 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
+from vllm.utils.flashinfer import has_flashinfer_all2all

 from .base_device_communicator import All2AllManagerBase, Cache

+if has_flashinfer_all2all():
+    from flashinfer.comm import Mapping
+    from flashinfer.comm.mnnvl import MnnvlConfig
+    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+
 logger = init_logger(__name__)

@@ -47,24 +53,22 @@ def naive_multicast(self, x: torch.Tensor,

     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
-        cu_tokens_across_dp_cpu = get_forward_context(
-        ).dp_metadata.cu_tokens_across_dp_cpu
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states, router_logits = get_dp_group().all_gatherv(
+            [hidden_states, router_logits],
+            dim=0,
+            sizes=sizes,
+        )

-        hidden_states = self.naive_multicast(hidden_states,
-                                             cu_tokens_across_dp_cpu)
-        router_logits = self.naive_multicast(router_logits,
-                                             cu_tokens_across_dp_cpu)
         return hidden_states, router_logits

     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        cu_tokens_across_dp_cpu = get_forward_context(
-        ).dp_metadata.cu_tokens_across_dp_cpu
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-
-        all_hidden_states = self.dp_group.all_reduce(hidden_states)
-        hidden_states = all_hidden_states[start:end, :]
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
+                                                       dim=0,
+                                                       sizes=sizes)
         return hidden_states

     def destroy(self):
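For intuition, here is a standalone sketch of the variable-size semantics these collectives rely on, using hypothetical per-rank token counts (plain torch math, not vLLM's actual collectives):

import torch

sizes = [3, 1, 4]        # hypothetical token counts on DP ranks 0, 1, 2
hidden = 8

# all_gatherv: every rank ends up with all chunks concatenated in rank
# order along dim 0, sum(sizes) rows in total.
chunks = [torch.randn(s, hidden) for s in sizes]
gathered = torch.cat(chunks, dim=0)                # shape: (8, hidden)

# reduce_scatterv inverts that: ranks elementwise-sum their full-size
# buffers, then rank r keeps only its own sizes[r]-row slice.
offsets = torch.tensor([0] + sizes).cumsum(0)
rank = 1
kept = gathered[offsets[rank]:offsets[rank + 1]]   # shape: (1, hidden)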
@@ -300,4 +304,95 @@ def get_handle(self, kwargs):

     # DeepEP LL uses RDMA so no SMs are used for communication
     def max_sms_used(self) -> Optional[int]:
-        return 0
+        return 0
+
+
+class FlashInferAllToAllManager(All2AllManagerBase):
+    """
+    All2All communication based on flashinfer kernels.
+    """
+
+    def __init__(self, cpu_group):
+        assert has_flashinfer_all2all(
+        ), "flashinfer all2all module not found. Please install/check flashinfer"  # noqa
+        super().__init__(cpu_group)
+        logger.debug(
+            "Initialize for flashinfer All2All "
+            "rank=%d, world size=%d", self.rank, self.world_size)
+        self.initialized = False
+        self.alltoall_info = None
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        gpus_per_node: int,
+    ):
+        """Initialize workspace"""
+        if self.initialized:
+            return
+
+        self.cleanup()
+        logger.debug("making map: "
+                     "rank=%d, world size=%d", rank, world_size)
+        self.mapping = Mapping(
+            world_size,
+            rank,
+            gpus_per_node,
+            tp_size=world_size,
+        )
+
+        from vllm.distributed.device_communicators.mnnvl_compat import (
+            CustomCommunicator)
+        dp_config = MnnvlConfig(
+            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
+            fabric_page_size=1 << 29,  # 512 MiB
+            allocation_granularity=0  # Auto-detect
+        )
+
+        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
+            self.mapping, dp_config)
+        self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
+            self.mapping, dp_config)
+
+        self.world_size = world_size
+        self.rank = rank
+        self.gpus_per_node = gpus_per_node
+        self.initialized = True
+
+        logger.info("FlashInfer All2All initialized for rank %s, size %s",
+                    rank, world_size)
+
+    def ensure_alltoall_workspace_initialized(self):
+        """Ensure workspace is initialized"""
+        if not has_flashinfer_all2all():
+            return False
+
+        if self.world_size <= 1:
+            return False
+
+        if not self.initialized:
+            self.initialize(
+                world_size=self.world_size,
+                rank=self.rank,
+                gpus_per_node=torch.cuda.device_count(),
+            )
+        return self.initialized
+
+    def get_handle(self, kwargs):
+        return self
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.workspace_tensor is not None \
+                and self.prepare_workspace_tensor is not None:
+            try:
+                del self.workspace_tensor
+                del self.prepare_workspace_tensor
+            except Exception as e:
+                logger.warning("Failed to cleanup FlashInfer workspace: %s",
+                               e)
+            finally:
+                self.workspace_tensor = None
+                self.prepare_workspace_tensor = None
+                self.mapping = None
+                self.initialized = False
vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 5 additions & 0 deletions
@@ -114,6 +114,11 @@ def __init__(self,
             from .all2all import DeepEPLLAll2AllManager
             self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
             logger.info("Using DeepEP Low-Latency all2all manager.")
+        elif all2all_backend == "flashinfer_all2allv":
+            from .all2all import FlashInferAllToAllManager
+            self.all2all_manager = FlashInferAllToAllManager(
+                self.cpu_group)
+            logger.info("Using Flashinfer all2allv manager.")
         else:
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")

vllm/distributed/device_communicators/mnnvl_compat.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch.distributed as dist
+from flashinfer.comm.mnnvl import CommBackend as CommBackend
+
+from vllm.utils.flashinfer import has_flashinfer_all2all
+
+assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
+
+
+class CustomCommunicator(CommBackend):
+
+    def __init__(self, group):
+        self._group = group
+
+    def Get_rank(self) -> int:
+        return self._group.rank()
+
+    def Get_size(self) -> int:
+        return self._group.size()
+
+    def allgather(self, data: int):
+        gathered = [None] * self.Get_size()
+        dist.all_gather_object(gathered, data, group=self._group)
+        return gathered
+
+    def Split(self, color: int, key: int) -> 'CustomCommunicator':
+        return self
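A minimal sketch of exercising this adapter under torchrun, assuming a gloo-backed CPU group (vLLM actually passes get_dp_group().cpu_group):

import torch.distributed as dist

dist.init_process_group(backend="gloo")
comm = CustomCommunicator(dist.group.WORLD)
ranks = comm.allgather(comm.Get_rank())  # every rank sees [0, 1, ..., N-1]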

vllm/envs.py

Lines changed: 5 additions & 2 deletions
@@ -156,7 +156,8 @@
     VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
                                   "deepep_high_throughput",
                                   "deepep_low_latency",
-                                  "allgather_reducescatter"] = \
+                                  "allgather_reducescatter",
+                                  "flashinfer_all2allv"] = \
         "allgather_reducescatter"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
@@ -1209,12 +1210,14 @@ def get_vllm_port() -> Optional[int]:
     # - "pplx": use pplx kernels
     # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
+    # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
     "VLLM_ALL2ALL_BACKEND":
     env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
                      ["naive", "pplx",
                       "deepep_high_throughput",
                       "deepep_low_latency",
-                      "allgather_reducescatter"]),
+                      "allgather_reducescatter",
+                      "flashinfer_all2allv"]),

     # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
     # Both require compute capability 10.0 or above.
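To opt in, set the backend before launching; the flag value comes from this diff, while the rest of the command line is illustrative:

VLLM_ALL2ALL_BACKEND=flashinfer_all2allv vllm serve <model> --data-parallel-size 2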

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 3 additions & 4 deletions
@@ -8,7 +8,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    FlashInferCutlassMoEPrepareAndFinalize)
+    create_flashinfer_prepare_finalize)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
@@ -108,7 +108,7 @@ def workspace_shapes(
         of each tuple must be the number of tokens.
         """
         aq_m, aq_n = aq.shape
-        workspace2 = ()
+        workspace2 = (0, )
         output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
             torch.float8_e4m3fn else (aq_m, aq_n)
         workspace_dtype = a.dtype
@@ -192,9 +192,8 @@ def flashinfer_cutlass_moe_fp4(
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
-
     fused_experts = mk.FusedMoEModularKernel(
-        FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
+        create_flashinfer_prepare_finalize(use_dp=False),
         FlashInferExperts(
             out_dtype=hidden_states.dtype,
             quant_config=quant_config,
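The workspace2 change is subtle: in PyTorch an empty tuple is a 0-dim scalar shape rather than an empty buffer, so shape-driven workspace allocation would produce a 1-element tensor, while (0, ) yields a genuinely empty 1-D workspace. A quick illustration (assuming the shapes are consumed by torch.empty-style allocation):

import torch

torch.empty(()).numel()    # 1 -- a 0-dim scalar tensor, not empty
torch.empty((0,)).numel()  # 0 -- a truly empty 1-D workspace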
