Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,8 +1590,21 @@ class ParallelConfig:
the product of the tensor parallel size and data parallel size."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_rank_local: Optional[int] = None
"""Local rank of the data parallel group, defaults to global rank."""
_data_parallel_rank_local: Optional[int] = field(default=None, init=False)
"""Private field to store the local rank of the data parallel group."""

@property
def data_parallel_rank_local(self) -> int:
"""Local rank of the data parallel group, defaults to global rank."""
if self._data_parallel_rank_local is None:
return self.data_parallel_rank
return self._data_parallel_rank_local

@data_parallel_rank_local.setter
def data_parallel_rank_local(self, value: int) -> None:
"""Set the local rank of the data parallel group."""
self._data_parallel_rank_local = value

data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_master_port: int = 29500
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,10 @@ def _init_core_engines(
) -> None:

# Default case - single core engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
core_engine = new_core_engine(
dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
vllm_config.parallel_config.data_parallel_rank,
vllm_config.parallel_config.data_parallel_rank_local,
)
core_engines.append(core_engine)
self.core_engine = core_engine

Expand Down