Skip to content

Commit b87c21f

Browse files
authored
[Misc][Platform] Move use allgather to platform (#14010)
Signed-off-by: Mengqing Cao <[email protected]>
1 parent e584b85 commit b87c21f

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

vllm/model_executor/layers/logits_processor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch.nn as nn
99

1010
import vllm.envs as envs
11-
from vllm.config import get_current_vllm_config
1211
from vllm.distributed import (tensor_model_parallel_all_gather,
1312
tensor_model_parallel_gather)
1413
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,11 +50,7 @@ def __init__(self,
5150
# Soft cap the logits. Used in Gemma 2.
5251
self.soft_cap = soft_cap
5352
# Whether to use gather or all-gather to gather the logits.
54-
parallel_config = get_current_vllm_config().parallel_config
55-
self.use_all_gather = current_platform.is_tpu() \
56-
or current_platform.is_neuron() \
57-
or envs.VLLM_USE_V1 \
58-
or parallel_config.distributed_executor_backend == "external_launcher" # noqa
53+
self.use_all_gather = current_platform.use_all_gather()
5954

6055
def forward(
6156
self,
@@ -83,7 +78,8 @@ def forward(
8378
logits *= self.scale
8479

8580
# Apply logits processors (if any).
86-
if sampling_metadata is not None:
81+
if sampling_metadata is not None and \
82+
sampling_metadata.seq_groups is not None:
8783
logits = _apply_logits_processors(logits, sampling_metadata)
8884

8985
return logits

vllm/platforms/interface.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,19 @@ def get_device_communicator_cls(cls) -> str:
330330
"""
331331
return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa
332332

333+
@classmethod
def use_all_gather(cls) -> bool:
    """Whether LogitsProcessor should gather logits via all-gather.

    Default platform policy: all-gather is used when running the V1
    engine (``VLLM_USE_V1``) or when the distributed executor backend
    is ``"external_launcher"``. Platform subclasses (e.g. TPU, Neuron)
    override this to force all-gather unconditionally.
    """
    # Imported lazily to avoid import cycles at module load time.
    import vllm.envs as envs
    from vllm.config import get_current_vllm_config

    backend = (get_current_vllm_config()
               .parallel_config.distributed_executor_backend)
    if envs.VLLM_USE_V1:
        return True
    return backend == "external_launcher"
345+
333346

334347
class UnspecifiedPlatform(Platform):
335348
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/neuron.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
5555
def is_pin_memory_available(cls) -> bool:
5656
logger.warning("Pin memory is not supported on Neuron.")
5757
return False
58+
59+
@classmethod
def use_all_gather(cls) -> bool:
    """Neuron always gathers logits with all-gather (no plain gather)."""
    return True

vllm/platforms/tpu.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,7 @@ def is_pin_memory_available(cls):
119119
@classmethod
120120
def get_device_communicator_cls(cls) -> str:
121121
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
122+
123+
@classmethod
def use_all_gather(cls) -> bool:
    """TPU always gathers logits with all-gather (no plain gather)."""
    return True

0 commit comments

Comments (0)