
Commit f75c8ec

pdesupinski authored and markc-614 committed
Support NUMA Binding for Callable Entrypoints (pytorch#160163)
# Context

This is an extension of pytorch#149334.

# This PR

Add support for NUMA bindings with `Callable` entrypoints, such as `do_train`, rather than only `str` entrypoints such as `/usr/local/bin/python`.

Most notably, we utilize a hack in order to force `Process.start()` to use custom NUMA bindings for each subprocess. Please search for `HACK:` in the code for a description of the chosen implementation, and see pytorch#160006 for discussion of alternatives and why this is necessary.

Other changes:
* Remove the unnecessary `--preferred` option from all binding strategies. By default, Linux already allocates memory on the NUMA node local to the CPU which triggered the allocation. (See [MPOL_LOCAL](https://man7.org/linux/man-pages/man2/set_mempolicy.2.html).)
* Refactor so that the main API is `maybe_wrap_command_with_numa_bindings`, which computes bindings for a single rank at a time, rather than `maybe_wrap_with_numa_bindings`, which computed bindings for all ranks at once. This allows more code sharing between `Callable` and `str` entrypoints.

# Test Plan

## Automated

`$ pytest test/test_numa_binding.py`

## Manual

Using [this benchmark](https://gist.github.com/pdesupinski/bbe01ade455d86e989794f2c612e2d91), ran

```
$ PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -m torch.distributed.run --standalone --nproc-per-node=8 --numa-binding=node --run-path mlp_train.py 2>&1 | tee node_callable.txt &&
  PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -u -m torch.distributed.run --standalone --nproc-per-node=8 --run-path mlp_train.py 2>&1 | tee none_callable.txt
```

and observed:

* 6.6% remote memory accesses with `node` bindings
* 11.6% remote memory accesses without bindings

I also ran a similar check with `str` entrypoints, as before, just to be sure they still work.

NOTE: [`--run-path` triggers the code to be run inside a `Callable`.](https://github.com/pytorch/pytorch/blob/017259f9c65b6fad55fb9597d7077e2543eaae46/torch/distributed/run.py#L870)

Pull Request resolved: pytorch#160163
Approved by: https://github.com/d4l3k
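For orientation, here is a minimal sketch of the kind of per-rank wrapping described above, assuming `node` affinity. The `numa_node_of_gpu` helper and the exact `numactl` flags are illustrative assumptions, not the actual `torch.numa.binding` implementation:

```python
# Illustrative sketch only -- not the torch.numa.binding implementation.
# `numa_node_of_gpu` is a hypothetical lookup standing in for real topology
# discovery (e.g. reading the GPU's sysfs numa_node entry).


def numa_node_of_gpu(gpu_index: int) -> int:
    # Assume a machine with 4 GPUs per NUMA node, purely for illustration.
    return gpu_index // 4


def wrap_command_for_rank(command: tuple[str, ...], gpu_index: int) -> tuple[str, ...]:
    node = numa_node_of_gpu(gpu_index)
    # No --preferred/--membind is needed: with Linux's default MPOL_LOCAL
    # policy, memory is already allocated on the node of the CPU that touches it.
    return ("numactl", f"--cpunodebind={node}", *command)


print(wrap_command_for_rank(("/usr/local/bin/python", "train.py"), gpu_index=5))
# ('numactl', '--cpunodebind=1', '/usr/local/bin/python', 'train.py')
```

The key point, reflected in the removal of `--preferred`, is that only the CPU binding needs to be specified; memory placement then follows the local-allocation default.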
1 parent cc9f8d6 · commit f75c8ec

12 files changed: +424 −215 lines


docs/source/elastic/numa.rst

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@
 NUMA Binding Utilities
 ======================
 
-.. automodule:: torch.distributed.numa
+.. automodule:: torch.numa
    :members:
 
-.. automodule:: torch.distributed.numa.binding
+.. automodule:: torch.numa.binding
    :members:

test/test_numa_binding.py

Lines changed: 169 additions & 81 deletions
Large diffs are not rendered by default.

torch/distributed/elastic/agent/server/api.py

Lines changed: 1 addition & 8 deletions
@@ -27,7 +27,7 @@
 from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException
 from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
 from torch.distributed.elastic.utils.logging import get_logger
-from torch.distributed.numa.binding import NumaOptions
+from torch.numa.binding import NumaOptions
 
 
 __all__ = [
@@ -104,13 +104,6 @@ def __post_init__(self):
             self.entrypoint = self.fn
         assert self.entrypoint
 
-        if (
-            self.numa_options is not None
-            and not self.numa_options.should_fall_back_if_binding_fails
-            and not isinstance(self.entrypoint, str)
-        ):
-            raise ValueError("numa_options is only supported for str entrypoints.")
-
     def get_entrypoint_name(self):
         """Get the entry point name.

torch/distributed/elastic/multiprocessing/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -80,7 +80,7 @@ def trainer(a, b, c):
     to_map,
 )
 from torch.distributed.elastic.utils.logging import get_logger
-from torch.distributed.numa.binding import NumaOptions
+from torch.numa.binding import NumaOptions
 
 
 __all__ = [
@@ -227,6 +227,7 @@ def start_processes(
        log_line_prefixes=log_line_prefixes,
        start_method=start_method,
        logs_specs=logs_specs,
+       numa_options=numa_options,
    )
 
    try:

torch/distributed/elastic/multiprocessing/api.py

Lines changed: 7 additions & 5 deletions
@@ -37,7 +37,7 @@
     SubprocessHandler,
 )
 from torch.distributed.elastic.multiprocessing.tail_log import TailLog
-from torch.distributed.numa.binding import maybe_wrap_with_numa_bindings, NumaOptions
+from torch.numa.binding import NumaOptions
 
 
 IS_WINDOWS = sys.platform == "win32"
@@ -631,6 +631,7 @@ def __init__(
         start_method: str,
         logs_specs: LogsSpecs,
         log_line_prefixes: Optional[dict[int, str]] = None,
+        numa_options: Optional[NumaOptions] = None,
     ):
         super().__init__(
             name,
@@ -655,6 +656,8 @@ def __init__(
         # successfully. If any process died on event.wait() calling set() method will deadlock.
         self._worker_finished_event = mp.get_context(self.start_method).Event()
 
+        self._numa_options: Optional[NumaOptions] = numa_options
+
     def _start(self):
         if self._pc:
             raise ValueError(
@@ -676,6 +679,7 @@ def _start(self):
             join=False,
             daemon=False,
             start_method=self.start_method,
+            numa_options=self._numa_options,
         )
 
     def _is_done(self) -> bool:
@@ -814,10 +818,6 @@ def __init__(
         log_line_prefixes: Optional[dict[int, str]] = None,
         numa_options: Optional[NumaOptions] = None,
     ):
-        entrypoint, args = maybe_wrap_with_numa_bindings(
-            entrypoint=entrypoint, local_rank_to_args=args, numa_options=numa_options
-        )
-
         super().__init__(
             name,
             entrypoint,
@@ -831,6 +831,7 @@ def __init__(
         self._running_local_ranks: set[int] = set(range(self.nprocs))
         self._failures: dict[int, ProcessFailure] = {}
         self.subprocess_handlers: dict[int, SubprocessHandler] = {}
+        self._numa_options: Optional[NumaOptions] = numa_options
 
     def _start(self):
         if self.subprocess_handlers:
@@ -845,6 +846,7 @@ def _start(self):
                 stdout=self.stdouts[local_rank],
                 stderr=self.stderrs[local_rank],
                 local_rank_id=local_rank,
+                numa_options=self._numa_options,
             )
             for local_rank in range(self.nprocs)
         }

torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py

Lines changed: 4 additions & 0 deletions
@@ -3,10 +3,12 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from typing import Optional
 
 from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
     SubprocessHandler,
 )
+from torch.numa.binding import NumaOptions
 
 
 __all__ = ["get_subprocess_handler"]
@@ -19,6 +21,7 @@ def get_subprocess_handler(
     stdout: str,
     stderr: str,
     local_rank_id: int,
+    numa_options: Optional[NumaOptions] = None,
 ) -> SubprocessHandler:
     return SubprocessHandler(
         entrypoint=entrypoint,
@@ -27,4 +30,5 @@ def get_subprocess_handler(
         stdout=stdout,
         stderr=stderr,
         local_rank_id=local_rank_id,
+        numa_options=numa_options,
     )

torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,8 @@
 from subprocess import Popen
 from typing import Any, Optional
 
+from torch.numa.binding import maybe_wrap_command_with_numa_bindings, NumaOptions
+
 
 __all__ = ["SubprocessHandler"]
 
@@ -39,6 +41,7 @@ def __init__(
         stdout: Optional[str],
         stderr: Optional[str],
         local_rank_id: int,
+        numa_options: Optional[NumaOptions],
     ):
         self._stdout = open(stdout, "w") if stdout else None
         self._stderr = open(stderr, "w") if stderr else None
@@ -47,6 +50,15 @@ def __init__(
         env_vars.update(env)
 
         args_str = (entrypoint, *[str(e) for e in args])
+        args_str = (
+            maybe_wrap_command_with_numa_bindings(
+                command_args=args_str,
+                gpu_index=local_rank_id,
+                numa_options=numa_options,
+            )
+            or args_str
+        )
+
         self.local_rank_id = local_rank_id
         self.proc: Popen = self._popen(args_str, env_vars)
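As a usage sketch for the `str`-entrypoint path above: the `... or args_str` fallback implies that the wrapper returns `None` when no bindings apply and otherwise returns a new argument tuple. The exact `numactl` prefix below is an assumption for illustration, not the helper's documented output:

```python
# Assumed before/after shape of args_str in SubprocessHandler.__init__ (sketch).
args_str = ("/usr/local/bin/python", "-u", "train.py")

# With numa_options=None the wrapper is assumed to return None, so the
# original tuple is kept unchanged by `... or args_str`.
unwrapped = None or args_str
assert unwrapped == ("/usr/local/bin/python", "-u", "train.py")

# With bindings enabled and local_rank_id == 0, the tuple handed to Popen
# would look roughly like this (prefix illustrative):
wrapped = ("numactl", "--cpunodebind=0", "/usr/local/bin/python", "-u", "train.py")
```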

torch/distributed/launcher/api.py

Lines changed: 8 additions & 2 deletions
@@ -26,7 +26,7 @@
 from torch.distributed.elastic.rendezvous import RendezvousParameters
 from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
 from torch.distributed.elastic.utils.logging import get_logger
-from torch.distributed.numa.binding import NumaOptions
+from torch.numa.binding import NumaOptions
 
 
 __all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
@@ -107,7 +107,13 @@ def __post_init__(self):
         if self.logs_specs is None:
             self.logs_specs = DefaultLogsSpecs()
 
-        if self.numa_options is None and torch.cuda.is_available():
+        if (
+            self.numa_options is None
+            # NOTE: This filter isn't relevant for str entrypoints,
+            # but it's the default anyway.
+            and self.start_method == "spawn"
+            and torch.cuda.is_available()
+        ):
             self.numa_options = get_default_numa_options()
             logger.info("Using default numa options = %r", self.numa_options)
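For clarity, the condition above means default bindings are auto-enabled only when the caller passed no `numa_options`, the start method is `spawn` (the `Callable` path), and CUDA is available. A minimal restatement of just that rule, assuming nothing beyond stock `torch`:

```python
# Simplified restatement of the default-enable rule in LaunchConfig.__post_init__.
import torch


def should_apply_default_numa_options(numa_options, start_method: str) -> bool:
    # Auto-enable only when nothing was requested explicitly, the spawn
    # (Callable) path is in use, and GPUs are actually present.
    return (
        numa_options is None
        and start_method == "spawn"
        and torch.cuda.is_available()
    )
```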

torch/distributed/run.py

Lines changed: 1 addition & 1 deletion
@@ -382,7 +382,7 @@ def main():
 from torch.distributed.elastic.utils import macros
 from torch.distributed.elastic.utils.logging import get_logger
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
-from torch.distributed.numa.binding import (
+from torch.numa.binding import (
     AffinityMode as _AffinityMode,  # Signify as private with _
     NumaOptions as _NumaOptions,
 )

torch/multiprocessing/spawn.py

Lines changed: 52 additions & 6 deletions
@@ -2,6 +2,7 @@
 import logging
 import multiprocessing
 import multiprocessing.connection
+import multiprocessing.spawn as mp_spawn
 import os
 import pickle
 import signal
@@ -12,6 +13,11 @@
 from concurrent.futures import as_completed, ThreadPoolExecutor
 from typing import Optional
 
+from torch.numa.binding import (
+    maybe_get_temporary_python_executable_with_numa_bindings,
+    NumaOptions,
+)
+
 from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]
 
 
@@ -236,6 +242,7 @@ def start_processes(
    join=True,
    daemon=False,
    start_method="spawn",
+   numa_options: Optional[NumaOptions] = None,
 ):
    # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # this func will start processes in parallel if start_method is 'forkserver'.
@@ -251,11 +258,43 @@ def start_processes(
        # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
        start_parallel = False
 
+   if numa_options is not None and start_method != "spawn":
+       raise ValueError("NUMA binding is only compatible with spawn")
+
+   if numa_options is not None and start_parallel:
+       raise ValueError("NUMA binding is not compatible with parallel start")
+
    mp = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs
+   original_executable = mp_spawn.get_executable()
 
    def start_process(i):
+       # HACK: We want to force Process.start() to kick off the subprocess
+       # using a custom numactl command per rank. However, the API exposed
+       # by multiprocessing only allows us to override the executable for
+       # the entire context, and only with a single str rather than a tuple.
+       # Furthermore, there is no API for passing additional options, e.g.
+       # to make LOCAL_RANK available to the executable.
+       #
+       # In order to get around these limitations, we pre-compute
+       # the appropriate command containing NUMA bindings and store it in a
+       # temporary executable which passes Python args on to the original
+       # executable. Then, we call set_executable before and after each
+       # Process.start() call.
+       #
+       # This assumes that, under the hood, Process.start() for rank n
+       # will not call get_executable after start_process for rank n+1
+       # calls set_executable again. We guarantee this by
+       # raising an exception if `start_parallel`, above. (Not clear
+       # if there would be a race condition otherwise, but we want to be safe.)
+       temporary_executable_path = (
+           maybe_get_temporary_python_executable_with_numa_bindings(
+               python_executable_path=original_executable,
+               gpu_index=i,
+               numa_options=numa_options,
+           )
+       )
        # Each process is assigned a file to write tracebacks to. We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown). Note: this previously
@@ -267,12 +306,19 @@ def start_process(i):
        )
        tf.close()
        os.unlink(tf.name)
-       process = mp.Process(
-           target=_wrap,
-           args=(fn, i, args, tf.name),
-           daemon=daemon,
-       )
-       process.start()
+
+       try:
+           if temporary_executable_path is not None:
+               mp.set_executable(temporary_executable_path)
+           process = mp.Process(
+               target=_wrap,
+               args=(fn, i, args, tf.name),
+               daemon=daemon,
+           )
+           process.start()
+       finally:
+           if temporary_executable_path is not None:
+               mp.set_executable(original_executable)
        return i, process, tf.name
 
    if not start_parallel:
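To make the `HACK` above concrete, here is a self-contained sketch of the temporary-executable trick for a single rank. The shell-shim approach, its contents, and the helper names below are assumptions for illustration; the real logic lives in `torch.numa.binding.maybe_get_temporary_python_executable_with_numa_bindings`:

```python
# Illustrative sketch of the set_executable() dance for one rank.
import multiprocessing.spawn as mp_spawn
import os
import stat
import tempfile


def make_numa_shim(python_executable_path: str, numa_node: int) -> str:
    # Write a tiny executable shell shim that forwards all Python args to
    # the original interpreter, prefixed with a per-rank numactl binding.
    fd, path = tempfile.mkstemp(prefix="numa_shim_", suffix=".sh")
    with os.fdopen(fd, "w") as f:
        f.write(
            "#!/bin/sh\n"
            f'exec numactl --cpunodebind={numa_node} "{python_executable_path}" "$@"\n'
        )
    os.chmod(path, os.stat(path).st_mode | stat.S_IEXEC)
    return path


def start_rank(ctx, fn, rank: int, numa_node: int):
    # ctx is a spawn context, e.g. multiprocessing.get_context("spawn").
    original = mp_spawn.get_executable()
    shim = make_numa_shim(original, numa_node)
    try:
        # Point the spawn machinery at the shim just for this Process.start().
        ctx.set_executable(shim)
        process = ctx.Process(target=fn, args=(rank,))
        process.start()
    finally:
        # Restore immediately so later ranks get their own shim and
        # unrelated callers see the original interpreter.
        ctx.set_executable(original)
    return process
```

In the actual change, the helper returns `None` when `numa_options` is `None`, so `set_executable` is never touched and the original interpreter is used; the parallel-start path is rejected up front precisely because this swap is not safe to interleave across ranks.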
