10 changes: 3 additions & 7 deletions vllm/model_executor/layers/logits_processor.py

@@ -8,7 +8,6 @@
 import torch.nn as nn

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,11 +50,7 @@ def __init__(self,
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-        parallel_config = get_current_vllm_config().parallel_config
-        self.use_all_gather = current_platform.is_tpu() \
-            or current_platform.is_neuron() \
-            or envs.VLLM_USE_V1 \
-            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa
+        self.use_all_gather = current_platform.use_all_gather()

     def forward(
         self,
@@ -83,7 +78,8 @@ def forward(
                 logits *= self.scale

             # Apply logits processors (if any).
-            if sampling_metadata is not None:
+            if sampling_metadata is not None and \
+                    sampling_metadata.seq_groups is not None:
                 logits = _apply_logits_processors(logits, sampling_metadata)

         return logits
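For context on how the new flag is consumed: a minimal sketch of the dispatch inside `LogitsProcessor._get_logits` (simplified from the surrounding file, not part of this diff; the real method also handles the quantized `lm_head` and prunes vocab padding):

    def _get_logits(self, hidden_states, lm_head, embedding_bias):
        # Each TP rank computes logits for its vocab shard.
        logits = lm_head.quant_method.apply(lm_head,
                                            hidden_states,
                                            bias=embedding_bias)
        if self.use_all_gather:
            # Every rank ends up with the full vocab dimension.
            logits = tensor_model_parallel_all_gather(logits)
        else:
            # Only the driver rank receives the full logits; others get None.
            logits = tensor_model_parallel_gather(logits)
        return logits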
13 changes: 13 additions & 0 deletions vllm/platforms/interface.py

@@ -330,6 +330,19 @@ def get_device_communicator_cls(cls) -> str:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

+    @classmethod
+    def use_all_gather(cls) -> bool:
+        """
+        Whether to use allgather in LogitsProcessor to gather the logits.
+        """
+        import vllm.envs as envs
+        from vllm.config import get_current_vllm_config
+
+        parallel_config = get_current_vllm_config().parallel_config
+        return (envs.VLLM_USE_V1
+                or parallel_config.distributed_executor_backend
+                == "external_launcher")
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
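The behavioral difference the hook encodes: `gather` materializes the full logits only on the destination rank, while all-gather replicates them on every rank, which V1, the external launcher, and the TPU/Neuron backends require because each process samples locally. An illustrative contrast in plain `torch.distributed` (hypothetical helper, not vllm code; the `tensor_model_parallel_*` wrappers add their own group and dim handling):

    import torch
    import torch.distributed as dist

    def gather_vocab_shards(shard: torch.Tensor, use_all_gather: bool,
                            dst: int = 0):
        world_size = dist.get_world_size()
        if use_all_gather:
            # All ranks receive every shard and can sample locally.
            parts = [torch.empty_like(shard) for _ in range(world_size)]
            dist.all_gather(parts, shard)
            return torch.cat(parts, dim=-1)
        # Plain gather: only `dst` passes a buffer list and gets a result.
        parts = ([torch.empty_like(shard) for _ in range(world_size)]
                 if dist.get_rank() == dst else None)
        dist.gather(shard, gather_list=parts, dst=dst)
        return torch.cat(parts, dim=-1) if parts is not None else None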
4 changes: 4 additions & 0 deletions vllm/platforms/neuron.py

@@ -55,3 +55,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
         return False
+
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        return True
4 changes: 4 additions & 0 deletions vllm/platforms/tpu.py

@@ -119,3 +119,7 @@ def is_pin_memory_available(cls):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa
+
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        return True
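A quick, hypothetical smoke test of the override (assumes a vllm install where `current_platform` resolves to the running hardware and, for the base-class default, that a vllm config is already set):

    from vllm.platforms import current_platform

    # True on TPU/Neuron via the overrides above; elsewhere it falls back
    # to the base-class check of VLLM_USE_V1 / external_launcher.
    print(current_platform.use_all_gather())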