
Commit 0ecfe56

Fix HPU tensor parallelism

Signed-off-by: Konrad Zawora <[email protected]>

1 parent 87a0c07 commit 0ecfe56

3 files changed (+20 -3 lines)

vllm/config.py

Lines changed: 1 addition & 1 deletion

@@ -1285,7 +1285,7 @@ def __post_init__(self) -> None:
             raise ValueError(f"worker-use-ray can't be used with "
                              f"distributed executor backend "
                              f"'{self.distributed_executor_backend}'.")
-        ray_only_devices = ["tpu", "hpu"]
+        ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
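With "hpu" removed from ray_only_devices, a multi-device HPU run is no longer forced onto the Ray distributed executor. As an illustrative sketch (model name and sizes are placeholders, not from the commit), a tensor-parallel HPU deployment can now request the multiprocessing backend through the standard vllm.LLM entry point:

    # Illustrative usage only; "facebook/opt-125m" is a placeholder model.
    from vllm import LLM

    llm = LLM(
        model="facebook/opt-125m",          # any HPU-supported model
        tensor_parallel_size=2,             # world_size > 1 used to force Ray on HPU
        distributed_executor_backend="mp",  # accepted on HPU after this change
    )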

vllm/executor/multiproc_worker_utils.py

Lines changed: 17 additions & 0 deletions

@@ -12,8 +12,10 @@
 
 import torch
 
+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils.importing import HAS_TRITON
 from vllm.utils import _check_multiproc_method, get_mp_context, run_method
 
@@ -284,6 +286,21 @@ def set_multiprocessing_worker_envs(parallel_config):
     process before worker processes are created"""
 
     _check_multiproc_method()
+    if (current_platform.is_hpu()
+            and parallel_config.distributed_executor_backend == 'mp'
+            and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'):
+        if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) is not None:
+            logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork might "
+                           "cause application hangs on exit. Using "
+                           "VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
+                           "as it was explicitly requested.")
+        else:
+            logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork might "
+                           "cause application hangs on exit. Setting "
+                           "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
+                           "To override that behavior, please set "
+                           "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
     # Configure thread parallelism if OMP_NUM_THREADS isn't set
     #
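The guard distinguishes the implicit default from an explicit user choice: envs.VLLM_WORKER_MULTIPROC_METHOD falls back to a default (understood here to be "fork") when the variable is unset, so only the raw os.environ lookup reveals whether the user set it themselves. A minimal sketch of that default-vs-explicit pattern, with the vLLM-specific pieces stubbed out as assumptions:

    import os

    DEFAULT = "fork"  # assumed default, mirroring vllm.envs behavior

    def effective_method() -> str:
        # What envs.VLLM_WORKER_MULTIPROC_METHOD evaluates to: the
        # environment value if set, otherwise the built-in default.
        return os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", DEFAULT)

    def explicitly_set() -> bool:
        # Only a raw lookup can tell "user chose fork" apart from
        # "nothing was set and the default happens to be fork".
        return os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") is not None

    if effective_method() == "fork" and not explicitly_set():
        # Safe to override: the user never asked for fork.
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"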

vllm/worker/hpu_worker.py

Lines changed: 2 additions & 2 deletions

@@ -130,7 +130,6 @@ def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> Optional[List[SamplerOutput]]:
-        assert execute_model_req is not None
         # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
         # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
         # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
@@ -144,7 +143,8 @@ def execute_model(
             'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
         log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
                                            '0') != '0' or log_cpu_fallbacks_all
-        if log_graph_compilation or log_cpu_fallbacks:
+        if log_graph_compilation or log_cpu_fallbacks and \
+                execute_model_req is not None:
             from habana_frameworks.torch.hpu.metrics import metric_localcontext
             seq_group_metadata_list = execute_model_req.seq_group_metadata_list
             is_prompt = any([
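Since Python gives `and` higher precedence than `or`, the new guard groups as `log_graph_compilation or (log_cpu_fallbacks and execute_model_req is not None)`. A quick standalone check of that grouping, with arbitrary stand-in values:

    # Demonstrates operator precedence in the new condition; the values
    # below are hypothetical, purely to show how the expression parses.
    log_graph_compilation = True
    log_cpu_fallbacks = False
    execute_model_req = None

    implicit = (log_graph_compilation
                or log_cpu_fallbacks and execute_model_req is not None)
    explicit = (log_graph_compilation
                or (log_cpu_fallbacks and execute_model_req is not None))
    assert implicit == explicit  # `and` binds tighter than `or`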
