1616 count_expert_num_tokens ,
1717 disable_inplace ,
1818)
19+ from vllm .platforms import current_platform
1920from vllm .utils .math_utils import cdiv
2021from vllm .v1 .worker .ubatching import (
2122 dbo_current_ubatch_id ,
@@ -709,11 +710,13 @@ def __init__(
709710 prepare_finalize : FusedMoEPrepareAndFinalize ,
710711 fused_experts : FusedMoEPermuteExpertsUnpermute ,
711712 shared_experts : torch .nn .Module | None = None ,
713+ shared_experts_stream : torch .cuda .Stream | None = None ,
712714 ):
713715 super ().__init__ ()
714716 self .prepare_finalize = prepare_finalize
715717 self .fused_experts = fused_experts
716718 self .shared_experts = shared_experts
719+ self .shared_experts_stream = shared_experts_stream
717720
718721 self ._post_init_setup ()
719722 assert (
@@ -890,6 +893,34 @@ def _slice_expert_tokens_metadata(
890893 expert_num_tokens_cpu = c_expert_num_tokens_cpu ,
891894 )
892895
def _maybe_setup_shared_experts_stream(
    self, hidden_states: torch.Tensor
) -> tuple[bool, torch.Tensor | None]:
    """Decide whether shared experts should run on a dedicated CUDA stream.

    Returns a ``(use_stream, hidden_states_clone)`` pair. The clone is
    produced only when the dedicated stream will actually be used, so it
    can safely overlap with the main fused MoE kernel.
    """
    stream = self.shared_experts_stream

    # Overlap shared experts with the fused MoE kernel only when a stream
    # was configured, the input lives on the GPU, and the token count is
    # small enough for the overlap to be profitable.
    use_stream = (
        self.shared_experts is not None
        and stream is not None
        and hidden_states.is_cuda
        and hidden_states.shape[0]
        <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
    )

    if not use_stream or stream is None:
        return use_stream, None

    # TODO: Optimize this (complicated)
    # Note: this clone adds overhead but is required for correctness with
    # multiple CUDA streams and CUDA graph capture.
    cloned = hidden_states.clone()
    # Record that the clone will be used by the separate stream so its
    # lifetime is correctly tracked by the CUDA caching allocator.
    cloned.record_stream(stream)
    stream.wait_stream(torch.cuda.current_stream())
    return use_stream, cloned
923+
893924 def _prepare (
894925 self ,
895926 hidden_states : torch .Tensor ,
@@ -1077,12 +1108,30 @@ def _finalize(
10771108 topk_weights : torch .Tensor ,
10781109 topk_ids : torch .Tensor ,
10791110 apply_router_weight_on_input : bool ,
1111+ hidden_states_clone : torch .Tensor | None = None ,
1112+ use_shared_experts_stream : bool = False ,
10801113 ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor ]:
10811114 """
10821115 The _finalize method is a wrapper around self.prepare_finalize.finalize
10831116 that handles DBO, async and shared expert overlap.
10841117 """
1085- shared_output : torch .Tensor | None = None
1118+
1119+ def maybe_run_shared_experts () -> torch .Tensor | None :
1120+ if self .shared_experts is None :
1121+ return None
1122+
1123+ if (
1124+ not use_shared_experts_stream
1125+ or self .shared_experts_stream is not None
1126+ and (not hidden_states .is_cuda or not torch .cuda .is_available ())
1127+ ):
1128+ # fall back to running on the current stream
1129+ return self .shared_experts (hidden_states )
1130+
1131+ assert hidden_states_clone is not None
1132+ # launch shared experts on the dedicated stream.
1133+ with torch .cuda .stream (self .shared_experts_stream ):
1134+ return self .shared_experts (hidden_states_clone )
10861135
10871136 if not self .prepare_finalize .supports_async ():
10881137 assert not dbo_enabled ()
@@ -1095,8 +1144,7 @@ def _finalize(
10951144 apply_router_weight_on_input ,
10961145 self .fused_experts .finalize_weight_and_reduce_impl (),
10971146 )
1098- if self .shared_experts is not None :
1099- shared_output = self .shared_experts (hidden_states )
1147+ shared_output = maybe_run_shared_experts ()
11001148 else :
11011149 finalize_ret = self .prepare_finalize .finalize_async (
11021150 output ,
@@ -1107,8 +1155,7 @@ def _finalize(
11071155 self .fused_experts .finalize_weight_and_reduce_impl (),
11081156 )
11091157
1110- if self .shared_experts is not None :
1111- shared_output = self .shared_experts (hidden_states )
1158+ shared_output = maybe_run_shared_experts ()
11121159
11131160 # TODO(lucas): refactor this in the alternative schedules followup
11141161 # currently unpack if we have hook + receiver pair or just
@@ -1131,12 +1178,28 @@ def _finalize(
11311178
11321179 receiver ()
11331180
1181+ self ._wait_for_shared_experts_stream (hidden_states , use_shared_experts_stream )
1182+
11341183 if self .shared_experts is None :
11351184 return output
11361185 else :
11371186 assert shared_output is not None
11381187 return shared_output , output
11391188
def _wait_for_shared_experts_stream(
    self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
) -> None:
    """Block the current stream on outstanding shared-expert work.

    Ensures any kernels enqueued on ``shared_experts_stream`` have
    completed before the shared_output tensor is consumed downstream.
    """
    stream = self.shared_experts_stream
    # Nothing to synchronize unless shared experts actually ran on the
    # dedicated stream for a CUDA tensor on a CUDA platform.
    should_sync = (
        self.shared_experts is not None
        and use_shared_experts_stream
        and stream is not None
        and hidden_states.is_cuda
        and current_platform.is_cuda()
    )
    if should_sync:
        torch.cuda.current_stream().wait_stream(stream)
1202+
11401203 def forward (
11411204 self ,
11421205 hidden_states : torch .Tensor ,
@@ -1183,6 +1246,10 @@ def forward(
11831246 else :
11841247 output = torch .zeros_like (hidden_states )
11851248
1249+ use_shared_experts_stream , hidden_states_clone = (
1250+ self ._maybe_setup_shared_experts_stream (hidden_states )
1251+ )
1252+
11861253 local_num_experts = w1 .size (0 )
11871254 if global_num_experts == - 1 :
11881255 global_num_experts = local_num_experts
@@ -1219,4 +1286,6 @@ def forward(
12191286 topk_weights ,
12201287 topk_ids ,
12211288 apply_router_weight_on_input ,
1289+ hidden_states_clone = hidden_states_clone ,
1290+ use_shared_experts_stream = use_shared_experts_stream ,
12221291 )
0 commit comments