Skip to content

Commit df44df0

Browse files
authored
[Feature] Shared Experts Overlap with FI deepgemm swap kernel, 2.2% throughput improvement and 3.6% TTFT improvement (#28879)
Signed-off-by: yewentao256 <[email protected]>
1 parent 87cbbdf commit df44df0

File tree

4 files changed

+119
-33
lines changed

4 files changed

+119
-33
lines changed

vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def make(
5050
prepare_finalize,
5151
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
5252
shared_experts,
53+
getattr(moe_layer, "shared_experts_stream", None),
5354
),
5455
)
5556

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,45 @@ def update_expert_map(self):
850850
dp_size=get_dp_group().world_size,
851851
)
852852

853+
def _maybe_setup_shared_experts_stream(
    self,
    hidden_states: torch.Tensor,
    has_separate_shared_experts: bool,
    use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
    """Decide whether shared experts overlap on ``self.shared_experts_stream``.

    The side stream is selected only when all of the following hold: there
    are separate shared experts, the chunked implementation is not in use,
    ``self.shared_experts_stream`` is set, and ``hidden_states.shape[0]`` is
    at most ``envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD``.

    Args:
        hidden_states: Input activations for the MoE layer.
        has_separate_shared_experts: Whether shared experts are run
            separately from the routed experts.
        use_chunked_impl: Whether the chunked implementation is active.

    Returns:
        ``(use_shared_experts_stream, hidden_states_clone)``. The clone is
        ``None`` whenever the side stream is not used.
    """
    use_shared_experts_stream = (
        has_separate_shared_experts
        and not use_chunked_impl
        and self.shared_experts_stream is not None
        and (
            hidden_states.shape[0]
            <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
        )
    )
    if not use_shared_experts_stream:
        return use_shared_experts_stream, None

    side_stream = self.shared_experts_stream
    assert side_stream is not None

    # Clone BEFORE switching streams to avoid a race condition where the
    # routed-expert kernel may mutate hidden_states.
    cloned_states = hidden_states.clone()

    # Record that the clone will be used by the side stream so that its
    # memory is not reclaimed while that stream may still be reading it.
    # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
    # NOTE: We don't need shared_output.record_stream(current_stream())
    # because we sync the streams before using shared_output.
    cloned_states.record_stream(side_stream)

    # Mark the sync start point for the side stream here, since we want the
    # shared experts to run in parallel with the router/gate that the
    # caller applies next.
    side_stream.wait_stream(current_stream())

    return use_shared_experts_stream, cloned_states
891+
853892
def _load_per_tensor_weight_scale(
854893
self,
855894
shard_id: str,
@@ -1819,36 +1858,12 @@ def forward_impl(
18191858

18201859
use_chunked_impl = self.use_dp_chunking
18211860

1822-
use_shared_experts_stream = (
1823-
has_separate_shared_experts
1824-
and not use_chunked_impl
1825-
and self.shared_experts_stream is not None
1826-
and (
1827-
hidden_states.shape[0]
1828-
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
1861+
use_shared_experts_stream, hidden_states_clone = (
1862+
self._maybe_setup_shared_experts_stream(
1863+
hidden_states, has_separate_shared_experts, use_chunked_impl
18291864
)
18301865
)
18311866

1832-
if use_shared_experts_stream:
1833-
assert self.shared_experts_stream is not None
1834-
1835-
# Clone BEFORE switching streams to avoid race condition
1836-
# where routed_expert kernel may mutate hidden_states.
1837-
hidden_states_clone = hidden_states.clone()
1838-
1839-
# Record that the clone will be used by shared_experts_stream
1840-
# to avoid gc issue from deallocation of hidden_states_clone
1841-
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
1842-
# NOTE: We dont need shared_output.record_stream(current_stream())
1843-
# because we synch the streams before using shared_output.
1844-
hidden_states_clone.record_stream(self.shared_experts_stream)
1845-
1846-
# Mark sync start point for the separate shared experts
1847-
# stream here since we want to run in parallel with the
1848-
# router/gate (next op below)
1849-
assert self.shared_experts_stream is not None
1850-
self.shared_experts_stream.wait_stream(current_stream())
1851-
18521867
# If router/gate provided, then apply it here.
18531868
# (Note: This code runs only when "overlapped mode" is on to allow
18541869
# parallel execution of shared experts with the FusedMoE via

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
count_expert_num_tokens,
1717
disable_inplace,
1818
)
19+
from vllm.platforms import current_platform
1920
from vllm.utils.math_utils import cdiv
2021
from vllm.v1.worker.ubatching import (
2122
dbo_current_ubatch_id,
@@ -709,11 +710,13 @@ def __init__(
709710
prepare_finalize: FusedMoEPrepareAndFinalize,
710711
fused_experts: FusedMoEPermuteExpertsUnpermute,
711712
shared_experts: torch.nn.Module | None = None,
713+
shared_experts_stream: torch.cuda.Stream | None = None,
712714
):
713715
super().__init__()
714716
self.prepare_finalize = prepare_finalize
715717
self.fused_experts = fused_experts
716718
self.shared_experts = shared_experts
719+
self.shared_experts_stream = shared_experts_stream
717720

718721
self._post_init_setup()
719722
assert (
@@ -890,6 +893,34 @@ def _slice_expert_tokens_metadata(
890893
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
891894
)
892895

896+
def _maybe_setup_shared_experts_stream(
    self, hidden_states: torch.Tensor
) -> tuple[bool, torch.Tensor | None]:
    """Decide whether shared experts run on a separate CUDA stream.

    Running them on ``self.shared_experts_stream`` lets them overlap with
    the main fused MoE kernel. The side stream is selected only when shared
    experts and a stream are both configured, ``hidden_states`` is a CUDA
    tensor, and ``hidden_states.shape[0]`` is at most
    ``envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD``.

    Returns:
        ``(use_shared_experts_stream, hidden_states_clone)``. The clone is
        ``None`` whenever the side stream is not used.
    """
    use_shared_experts_stream = (
        self.shared_experts is not None
        and self.shared_experts_stream is not None
        and hidden_states.is_cuda
        and (
            hidden_states.shape[0]
            <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
        )
    )

    # The explicit None check repeats part of the flag above; it is kept so
    # the stream is known to be non-None below.
    if not use_shared_experts_stream or self.shared_experts_stream is None:
        return use_shared_experts_stream, None

    side_stream = self.shared_experts_stream

    # TODO: Optimize this (complicated)
    # Note: this clone adds overhead but is required for correctness with
    # multiple CUDA streams and CUDA graph capture.
    cloned_states = hidden_states.clone()

    # Record that the clone will be used by the separate stream so its
    # lifetime is correctly tracked.
    cloned_states.record_stream(side_stream)

    # The side stream must not start before work already queued on the
    # current stream has finished.
    side_stream.wait_stream(torch.cuda.current_stream())

    return use_shared_experts_stream, cloned_states
923+
893924
def _prepare(
894925
self,
895926
hidden_states: torch.Tensor,
@@ -1077,12 +1108,30 @@ def _finalize(
10771108
topk_weights: torch.Tensor,
10781109
topk_ids: torch.Tensor,
10791110
apply_router_weight_on_input: bool,
1111+
hidden_states_clone: torch.Tensor | None = None,
1112+
use_shared_experts_stream: bool = False,
10801113
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
10811114
"""
10821115
The _finalize method is a wrapper around self.prepare_finalize.finalize
10831116
that handles DBO, async and shared expert overlap.
10841117
"""
1085-
shared_output: torch.Tensor | None = None
1118+
1119+
def maybe_run_shared_experts() -> torch.Tensor | None:
    """Run the shared experts, on the dedicated stream when overlap is on.

    Returns ``None`` when there are no shared experts. Otherwise returns
    their output, computed either on the current stream from
    ``hidden_states`` or on ``self.shared_experts_stream`` from
    ``hidden_states_clone``.
    """
    if self.shared_experts is None:
        return None

    # Fall back to running on the current stream when overlap was not
    # requested.
    if not use_shared_experts_stream:
        return self.shared_experts(hidden_states)

    # Also fall back when a stream exists but the input is not a CUDA
    # tensor or CUDA is unavailable.
    # NOTE(review): the original condition was written as `a or b and c`;
    # this is its explicit grouping. The `is not None` test looks like it
    # may have been meant as `is None` - confirm with the author.
    if self.shared_experts_stream is not None and not (
        hidden_states.is_cuda and torch.cuda.is_available()
    ):
        return self.shared_experts(hidden_states)

    assert hidden_states_clone is not None
    # Launch shared experts on the dedicated stream.
    with torch.cuda.stream(self.shared_experts_stream):
        return self.shared_experts(hidden_states_clone)
10861135

10871136
if not self.prepare_finalize.supports_async():
10881137
assert not dbo_enabled()
@@ -1095,8 +1144,7 @@ def _finalize(
10951144
apply_router_weight_on_input,
10961145
self.fused_experts.finalize_weight_and_reduce_impl(),
10971146
)
1098-
if self.shared_experts is not None:
1099-
shared_output = self.shared_experts(hidden_states)
1147+
shared_output = maybe_run_shared_experts()
11001148
else:
11011149
finalize_ret = self.prepare_finalize.finalize_async(
11021150
output,
@@ -1107,8 +1155,7 @@ def _finalize(
11071155
self.fused_experts.finalize_weight_and_reduce_impl(),
11081156
)
11091157

1110-
if self.shared_experts is not None:
1111-
shared_output = self.shared_experts(hidden_states)
1158+
shared_output = maybe_run_shared_experts()
11121159

11131160
# TODO(lucas): refactor this in the alternative schedules followup
11141161
# currently unpack if we have hook + receiver pair or just
@@ -1131,12 +1178,28 @@ def _finalize(
11311178

11321179
receiver()
11331180

1181+
self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
1182+
11341183
if self.shared_experts is None:
11351184
return output
11361185
else:
11371186
assert shared_output is not None
11381187
return shared_output, output
11391188

1189+
def _wait_for_shared_experts_stream(
    self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
) -> None:
    """Make the current stream wait for the shared-experts stream.

    Ensures that any work enqueued on ``self.shared_experts_stream`` has
    completed before the shared_output tensor is consumed.

    The wait is gated on the same facts that decide whether the shared
    experts are launched on the side stream: shared experts exist, overlap
    was requested, a stream is configured, and the input is a CUDA tensor.
    It deliberately adds no further platform predicate. A wait gate that is
    stricter than the launch gate can only skip a synchronisation that is
    needed, whereas waiting on a stream with no pending work is harmless.

    Args:
        hidden_states: The activations that were passed to the layer; only
            its ``is_cuda`` flag is read.
        use_shared_experts_stream: Whether the caller chose to run the
            shared experts on the side stream.
    """
    if self.shared_experts is None or not use_shared_experts_stream:
        return

    side_stream = self.shared_experts_stream
    if side_stream is None or not hidden_states.is_cuda:
        return

    torch.cuda.current_stream().wait_stream(side_stream)
1202+
11401203
def forward(
11411204
self,
11421205
hidden_states: torch.Tensor,
@@ -1183,6 +1246,10 @@ def forward(
11831246
else:
11841247
output = torch.zeros_like(hidden_states)
11851248

1249+
use_shared_experts_stream, hidden_states_clone = (
1250+
self._maybe_setup_shared_experts_stream(hidden_states)
1251+
)
1252+
11861253
local_num_experts = w1.size(0)
11871254
if global_num_experts == -1:
11881255
global_num_experts = local_num_experts
@@ -1219,4 +1286,6 @@ def forward(
12191286
topk_weights,
12201287
topk_ids,
12211288
apply_router_weight_on_input,
1289+
hidden_states_clone=hidden_states_clone,
1290+
use_shared_experts_stream=use_shared_experts_stream,
12221291
)

vllm/model_executor/layers/fused_moe/prepare_finalize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def prepare(
4545
assert topk == 1, (
4646
"apply_router_weight_on_input is only implemented for topk=1"
4747
)
48-
a1.mul_(topk_weights.to(a1.dtype))
48+
# Note: do not use inplace for shared experts overlap
49+
a1 = a1 * topk_weights.to(a1.dtype)
4950

5051
a1q, a1q_scale = moe_kernel_quantize_input(
5152
a1,

0 commit comments

Comments
 (0)