 import torch
 from torch.nn.parameter import UninitializedParameter
 
-import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -342,14 +341,6 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
-        # For smuggling this layer into the fused moe custom op
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError("Duplicate layer name: {}".format(prefix))
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
-
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
@@ -361,7 +352,21 @@ def __init__(
                         if self.dp_size == 1 else get_dp_group().rank_in_group)
         self.global_num_experts = num_experts
 
-        if envs.VLLM_TEST_ENABLE_EP:
+        # Use expert parallelism instead of tensor parallelism?
+        vllm_config = get_current_vllm_config()
+        use_ep = (vllm_config.parallel_config.enable_expert_parallel
+                  and self.tp_size > 1)
+
+        # For smuggling this layer into the fused moe custom op
+        self.use_direct_call = self.dp_size == 1
+        if self.use_direct_call:
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError("Duplicate layer name: {}".format(prefix))
+            compilation_config.static_forward_context[prefix] = self
+            self.layer_name = prefix
+
+        if use_ep:
             # Set TP size to 1 to adjust for EP and adjust EP size and rank
             # for DP attention.
             self.ep_rank = tp_rank + self.tp_size * self.dp_rank
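For readers following the rank arithmetic in the last hunk, here is a minimal standalone sketch (the helper name below is hypothetical, not part of vLLM) of how `self.ep_rank = tp_rank + self.tp_size * self.dp_rank` flattens the TP and DP ranks into a single expert-parallel rank once `enable_expert_parallel` is set and `tp_size > 1`:

# Hypothetical helper, not vLLM code: mirrors the EP-rank computation in the diff above.
def ep_rank(tp_rank: int, tp_size: int, dp_rank: int) -> int:
    # Each DP replica contributes tp_size contiguous EP ranks, so a worker's TP rank
    # is offset by tp_size for every DP replica before it.
    return tp_rank + tp_size * dp_rank

# Example: tp_size=4, dp_size=2 gives 8 EP ranks in total;
# the worker with tp_rank=2 in DP replica 1 becomes EP rank 6.
assert ep_rank(tp_rank=2, tp_size=4, dp_rank=1) == 6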