diff --git a/README.md b/README.md index 722f0ed..a847f7a 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Requires: Linux (+ SSH & shared filesystem if using multiple machines) Dummy distributed training function: ```python +from __future__ import annotations import os import torch import torch.nn as nn @@ -59,15 +60,13 @@ Launching training with `torchrunx`: ```python import torchrunx -results = torchrunx.launch( - func = train, - kwargs = dict( - model = nn.Linear(10, 10), - num_steps = 10 - ), - # +results = torchrunx.Launcher( hostnames = ["localhost", "second_machine"], workers_per_host = 2 +).run( + train, + model = nn.Linear(10, 10), + num_steps = 10 ) trained_model: nn.Module = results.rank(0) @@ -75,10 +74,10 @@ torch.save(trained_model.state_dict(), "output/model.pth") ``` **See examples where we fine-tune LLMs (e.g. GPT-2 on WikiText) using:** - - [Accelerate](https://torchrun.xyz/examples/accelerate.html) - - [HF Transformers](https://torchrun.xyz/examples/transformers.html) + - [Transformers](https://torchrun.xyz/examples/transformers.html) - [DeepSpeed](https://torchrun.xyz/examples/deepspeed.html) - [PyTorch Lightning](https://torchrun.xyz/examples/lightning.html) + - [Accelerate](https://torchrun.xyz/examples/accelerate.html) **Refer to our [API](https://torchrun.xyz/api.html) and [Advanced Usage Guide](https://torchrun.xyz/advanced.html) for many more capabilities!** @@ -118,4 +117,4 @@ torch.save(trained_model.state_dict(), "output/model.pth") > - Automatic detection of SLURM environments. > - Start multi-node training from Python notebooks! -**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, fuller typing, and more!** +**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, and more!** diff --git a/docs/conf.py b/docs/conf.py index 8de5c00..242298a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,11 +24,12 @@ "sphinx_toolbox.github", ] +maximum_signature_line_length = 90 autodoc_member_order = "bysource" -autodoc_typehints = "description" -autodoc_typehints_description_target = "documented" -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = { + 'python': ('https://docs.python.org/3.9', None), +} from docs.linkcode_github import generate_linkcode_resolve_fn linkcode_resolve = generate_linkcode_resolve_fn(project, github_username, github_repository) diff --git a/docs/source/api.md b/docs/source/api.md index 95be4c9..6ee6d48 100644 --- a/docs/source/api.md +++ b/docs/source/api.md @@ -1,29 +1,6 @@ # API ```{eval-rst} -.. autofunction:: torchrunx.launch(func, args, kwargs, ...) -``` - -We provide the {obj}`torchrunx.Launcher` class as an alias to {obj}`torchrunx.launch`. - -```{eval-rst} -.. autoclass:: torchrunx.Launcher - :members: -``` - -## Results - -```{eval-rst} -.. autoclass:: torchrunx.LaunchResult +.. automodule:: torchrunx :members: ``` - -## Exceptions - -```{eval-rst} -.. autoexception:: torchrunx.AgentFailedError -``` - -```{eval-rst} -.. autoexception:: torchrunx.WorkerFailedError -``` diff --git a/docs/source/features/cli.md b/docs/source/features/cli.md index bce898f..d8e33e7 100644 --- a/docs/source/features/cli.md +++ b/docs/source/features/cli.md @@ -1,16 +1,16 @@ # CLI Integration -We can use {mod}`torchrunx.Launcher` to populate arguments from the CLI (e.g. with [tyro](https://brentyi.github.io/tyro/)): +We can automatically populate {mod}`torchrunx.Launcher` arguments using most CLI tools (those that generate interfaces from Data Classes, e.g. [tyro](https://brentyi.github.io/tyro/)): ```python -import torchrunx as trx +import torchrunx import tyro def distributed_function(): - pass + ... if __name__ == "__main__": - launcher = tyro.cli(trx.Launcher) + launcher = tyro.cli(torchrunx.Launcher) launcher.run(distributed_function) ``` diff --git a/pyproject.toml b/pyproject.toml index 8c781a7..7925c7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ ] description = "Automatically initialize distributed PyTorch environments" readme = "README.md" -license = {file = "LICENSE"} +license = { file = "LICENSE" } urls = { Repository = "https://github.com/apoorvkh/torchrunx.git", Documentation = "https://torchrun.xyz" } requires-python = ">=3.9" dependencies = [ @@ -21,12 +21,17 @@ dependencies = [ # torch.distributed depends on numpy # torch<=2.2 needs numpy<2 "numpy>=1.20", + "typing-extensions>=4.9.0", ] [dependency-groups] dev = ["ruff==0.9.5", "pyright[nodejs]==1.1.393", "pytest==8.3.4"] test-extras = ["submitit", "transformers"] -docs = ["sphinx==7.4.7", "furo==2024.8.6", "myst-parser==3.0.1", "sphinx-toolbox==3.8.2"] - +docs = [ + "sphinx==7.4.7", + "furo==2024.8.6", + "myst-parser==3.0.1", + "sphinx-toolbox==3.8.2", +] [tool.ruff] include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"] @@ -36,6 +41,8 @@ src = ["src", "tests"] [tool.ruff.lint] select = ["ALL"] ignore = [ + "TC003", # no type checking blocks for stdlib + "D104", # package docstrings "ANN401", # self / cls / Any annotations "BLE001", # blind exceptions "TD", # todo syntax diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index 3856f58..0342f4b 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,12 +1,10 @@ -"""API for our torchrunx library.""" - -from .launcher import Launcher, LaunchResult, launch +from .launcher import DEFAULT_ENV_VARS_FOR_COPY, Launcher, LaunchResult from .utils.errors import AgentFailedError, WorkerFailedError -__all__ = [ - "AgentFailedError", - "LaunchResult", +__all__ = [ # noqa: RUF022 + "DEFAULT_ENV_VARS_FOR_COPY", "Launcher", + "LaunchResult", + "AgentFailedError", "WorkerFailedError", - "launch", ] diff --git a/src/torchrunx/integrations/__init__.py b/src/torchrunx/integrations/__init__.py index 58cebc9..e69de29 100644 --- a/src/torchrunx/integrations/__init__.py +++ b/src/torchrunx/integrations/__init__.py @@ -1 +0,0 @@ -"""Utilities for integrations with other libraries.""" diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 622ca0e..3dba4c7 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -2,187 +2,133 @@ from __future__ import annotations -__all__ = ["LaunchResult", "Launcher", "launch"] +__all__ = ["DEFAULT_ENV_VARS_FOR_COPY", "LaunchResult", "Launcher"] import fnmatch -import ipaddress import itertools import logging import os -import shlex import socket -import subprocess -import sys -from dataclasses import dataclass +import typing +from dataclasses import dataclass, field from functools import partial -from logging import Handler from multiprocessing import Event, Process -from pathlib import Path -from typing import Any, Callable, Literal +from typing import Generic, TypeVar -import fabric import torch.distributed as dist -from typing_extensions import Self +from typing_extensions import ParamSpec, Self from .utils.comm import ( LauncherAgentGroup, LauncherPayload, get_open_port, ) -from .utils.environment import auto_hosts, slurm_hosts -from .utils.errors import ( - ExceptionFromWorker, - WorkerFailedError, +from .utils.environment import ( + build_launch_command, + execute_command, + resolve_environment, ) +from .utils.errors import ExceptionFromWorker, WorkerFailedError from .utils.logging import LoggingServerArgs, start_logging_server +DEFAULT_ENV_VARS_FOR_COPY = ( + "PATH", + "LD_LIBRARY", + "LIBRARY_PATH", + "PYTHON*", + "CUDA*", + "TORCH*", + "PYTORCH*", + "NCCL*", +) -def launch( - func: Callable, - args: tuple | None = None, - kwargs: dict[str, Any] | None = None, - *, - hostnames: list[str] | Literal["auto", "slurm"] = "auto", - workers_per_host: int | list[int] | Literal["auto"] = "auto", - ssh_config_file: str | os.PathLike | None = None, - backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto", - timeout: int = 600, - default_env_vars: tuple[str, ...] = ( - "PATH", - "LD_LIBRARY", - "LIBRARY_PATH", - "PYTHON*", - "CUDA*", - "TORCH*", - "PYTORCH*", - "NCCL*", - ), - extra_env_vars: tuple[str, ...] = (), - env_file: str | os.PathLike | None = None, - propagate_exceptions: bool = True, - handler_factory: Callable[[], list[Handler]] | Literal["auto"] | None = "auto", -) -> LaunchResult: - """Distribute and parallelize a function onto specified nodes and workers. - - Arguments: - func: Function to replicate on each node/worker. - args: Positional arguments for ``func``. Default: :py:obj:`None`. - kwargs: Keyword arguments for ``func``. Default: :py:obj:`None`. - hostnames: Nodes on which to launch the function. - Default: ``"auto"`` (infer from localhost or SLURM). - workers_per_host: Number of processes to run (e.g. # of GPUs) per node. - Default: ``"auto"`` (number of GPUs per host). - ssh_config_file: Path to an SSH configuration file for connecting to nodes. - Default: ``"~/.ssh/config"`` or ``"/etc/ssh/ssh_config"``. - backend: `Backend `_ - for worker process group. Set `None` to disable. - Default: ``"auto"`` (NCCL if GPU or GLOO if CPU). - timeout: Worker process group timeout (seconds). - Default: ``600``. - default_env_vars: Environment variables to copy from the launcher process to workers. - Supports bash pattern matching syntax. - Default: ``("PATH", "LD_LIBRARY", "LIBRARY_PATH", "PYTHON*", "CUDA*", "TORCH*", - "PYTORCH*", "NCCL*")``. - extra_env_vars: Additional user-specified environment variables to copy. - Default: ``()``. - env_file: Path to a file (e.g., ``.env``) with additional environment variables to copy. - Default: :py:obj:`None`. - propagate_exceptions: Raise exceptions from worker processes in the launcher. - If false, raises :exc:`WorkerFailedError` instead. - Default: :py:obj:`True`. - handler_factory: Function to customize processing of agent and worker logs with handlers. - Default: ``"auto"`` (see `custom logging `_). - - Raises: - RuntimeError: If there are configuration issues. - Exception: Any exception raised in a worker process is propagated. - WorkerFailedError: If a worker fails (e.g. from a segmentation fault) - or raises an exception and ``propagate_exceptions=False``. - AgentFailedError: If an agent fails, e.g. from an OS signal. - """ - return ( - Launcher( - hostnames=hostnames, - workers_per_host=workers_per_host, - ssh_config_file=ssh_config_file, - backend=backend, - timeout=timeout, - default_env_vars=default_env_vars, - extra_env_vars=extra_env_vars, - env_file=env_file, - propagate_exceptions=propagate_exceptions, - ) - .set_handler_factory(handler_factory) - .run( - func, - args, - kwargs, - ) - ) +FunctionP = ParamSpec("FunctionP") +FunctionR = TypeVar("FunctionR") @dataclass class Launcher: - """Alias class for :func:`launch`. Refer to that function for documentation.""" - - hostnames: list[str] | Literal["auto", "slurm"] = "auto" - """Node hostnames to use in distributed execution. "auto" and "slurm" attempt to detect this - for you based on your environmental variables.""" - workers_per_host: int | list[int] | Literal["auto"] = "auto" - """Number of worker processes per node. You can specify a constant number of workers for all - nodes (int), a different number of workers for each node (list[int]), or automatically determine - it per-node ("auto").""" + """For configuring the function launch environment.""" + + hostnames: list[str] | typing.Literal["auto", "slurm"] = "auto" + """Nodes on which to launch the function. By default, infer from localhost or SLURM.""" + workers_per_host: int | list[int] | typing.Literal["auto"] = "auto" + """Number of processes to run per node. By default, number of GPUs per host.""" ssh_config_file: str | os.PathLike | None = None - """Path to custom SSH Config for passwordless SSH into each node.""" - backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto" - """A torch.distributed backend to use for inter-process communication. "auto" will use NCCL if - GPUs are detected, otherwise GLOO.""" + """For connecting to nodes. By default, ``"~/.ssh/config"`` or ``"/etc/ssh/ssh_config"``.""" + backend: typing.Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto" + """`Backend `_ + for worker process group or ``None``. By default, NCCL if GPUs detected, else GLOO.""" timeout: int = 600 - """The torch.distributed communication timeout of the worker process group, in seconds.""" - default_env_vars: tuple[str, ...] = ( - "PATH", - "LD_LIBRARY", - "LIBRARY_PATH", - "PYTHON*", - "CUDA*", - "TORCH*", - "PYTORCH*", - "NCCL*", - ) - """Environmental variables to clone from the launcher process to worker processes, - supporting unix pattern matching.""" - extra_env_vars: tuple[str, ...] = () - """Additional environmental variables to set in the worker process environments, - formatted identically to the defaul_env_vars field.""" + """Worker process group timeout (seconds).""" + copy_env_vars: tuple[str, ...] = DEFAULT_ENV_VARS_FOR_COPY + """Environment variables to copy from the launcher process to workers. + Supports Unix pattern matching syntax.""" + extra_env_vars: dict[str, str] | None = None + """Additional environment variables to load onto workers.""" env_file: str | os.PathLike | None = None - """A bash style .env file that will be sourced by worker processes.""" + """Path to a ``.env`` file, containing environment variables to load onto workers.""" propagate_exceptions: bool = True - """Whether worker exceptions should be raised by the launcher.""" + """Whether to raise specific worker exceptions or :exc:`torchrunx.WorkerFailedError`.""" - def __post_init__(self) -> None: - """Initializing ``handler_factory``. Inclusion in ``__init__`` inhibits CLI generation.""" - self.handler_factory: Callable[[], list[Handler]] | Literal["auto"] | None = "auto" + handler_factory: typing.Callable[[], list[logging.Handler]] | typing.Literal["auto"] | None = ( + field(default="auto", init=False) + ) def set_handler_factory( - self, factory: Callable[[], list[Handler]] | Literal["auto"] | None + self, factory: typing.Callable[[], list[logging.Handler]] | typing.Literal["auto"] | None ) -> Self: - """Setter for log handler factory.""" + """Provide a ``factory`` to set custom handling of agent and worker logs. + + Parameters: + factory: Factory function to generate :obj:`logging.Handler` objects. + + See `custom logging `_. + """ self.handler_factory = factory return self - def run( # noqa: C901, PLR0912 + def run( # noqa: C901, PLR0912, PLR0915 self, - func: Callable, - args: tuple | None = None, - kwargs: dict[str, Any] | None = None, - ) -> LaunchResult: - """Launch a function using class configuration.""" + func: typing.Callable[FunctionP, FunctionR], + *args: FunctionP.args, + **kwargs: FunctionP.kwargs, + ) -> LaunchResult[FunctionR]: + """Distribute a function onto specified nodes and parallelize across workers. + + Raises: + RuntimeError: Configuration issues. + Exception: Exceptions raised in worker processes are propagated + (if ``propagate_exceptions=True``). + WorkerFailedError: If a worker fails (e.g. from a segmentation fault) + or raises an exception with ``propagate_exceptions=False``. + AgentFailedError: If an agent fails, e.g. from an OS signal. + """ if not dist.is_available(): msg = "The torch.distributed package is not available." raise RuntimeError(msg) - hostnames: list[str] = _resolve_hostnames(self.hostnames) - workers_per_host: list[int] = _resolve_workers_per_host(hostnames, self.workers_per_host) + ### + + hostnames, workers_per_host, backend = resolve_environment( + self.hostnames, self.workers_per_host, self.backend, self.ssh_config_file + ) + ssh_config_file = self.ssh_config_file + timeout = self.timeout + + env_vars = { + k: v + for k, v in os.environ.items() + if any(fnmatch.fnmatch(k, e) for e in self.copy_env_vars) + } + if self.extra_env_vars is not None: + env_vars.update(self.extra_env_vars) + env_file = self.env_file + + propagate_exceptions = self.propagate_exceptions + handler_factory = self.handler_factory + + ### launcher_hostname = socket.getfqdn() launcher_port = get_open_port() @@ -192,19 +138,31 @@ def run( # noqa: C901, PLR0912 stop_logging_event = None log_process = None launcher_agent_group = None + + _cumulative_workers = [0, *itertools.accumulate(workers_per_host)] + worker_global_ranks = [ + list(range(_cumulative_workers[n], _cumulative_workers[n + 1])) + for n in range(len(hostnames)) + ] + payload = LauncherPayload( + fn=partial(func, *args, **kwargs), + hostnames=hostnames, + worker_global_ranks=worker_global_ranks, + worker_world_size=sum(workers_per_host), + backend=backend, + timeout=timeout, + ) agent_payloads = None try: # Start logging server (recieves LogRecords from agents/workers) logging_server_args = LoggingServerArgs( - handler_factory=self.handler_factory, + handler_factory=handler_factory, logging_hostname=launcher_hostname, logging_port=logging_port, hostnames=hostnames, workers_per_host=workers_per_host, - log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")), - log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001 ) stop_logging_event = Event() @@ -220,24 +178,24 @@ def run( # noqa: C901, PLR0912 # Start agents on each node for i, hostname in enumerate(hostnames): - _execute_command( - command=_build_launch_command( + execute_command( + command=build_launch_command( launcher_hostname=launcher_hostname, launcher_port=launcher_port, logger_port=logging_port, world_size=world_size, rank=i + 1, - env_vars=(self.default_env_vars + self.extra_env_vars), - env_file=self.env_file, + env_vars=env_vars, + env_file=env_file, ), hostname=hostname, - ssh_config_file=self.ssh_config_file, + ssh_config_file=ssh_config_file, ) # Initialize launcher-agent process group # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1]) - launcher_agent_group = LauncherAgentGroup( + launcher_agent_group = LauncherAgentGroup[FunctionR]( launcher_hostname=launcher_hostname, launcher_port=launcher_port, world_size=world_size, @@ -246,22 +204,6 @@ def run( # noqa: C901, PLR0912 # Sync initial payloads between launcher and agents - _cumulative_workers = [0, *itertools.accumulate(workers_per_host)] - - worker_global_ranks = [ - list(range(_cumulative_workers[n], _cumulative_workers[n + 1])) - for n in range(len(hostnames)) - ] - - payload = LauncherPayload( - fn=partial(func, *(args or ()), **(kwargs or {})), - hostnames=hostnames, - worker_global_ranks=worker_global_ranks, - worker_world_size=sum(workers_per_host), - backend=self.backend, - timeout=self.timeout, - ) - launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload) # Monitor agent statuses (until failed or done) @@ -272,16 +214,17 @@ def run( # noqa: C901, PLR0912 # raises specific exception if any agent fails for s in agent_statuses: - for value in s.return_values: - if isinstance(value, ExceptionFromWorker): - if self.propagate_exceptions: - raise value.exception - raise WorkerFailedError from value.exception - if isinstance(value, WorkerFailedError): - raise value + for v in s.return_values: + if isinstance(v, ExceptionFromWorker): + if propagate_exceptions: + raise v.exception + raise WorkerFailedError from v.exception + if isinstance(v, WorkerFailedError): + raise v if all(s.state == "done" for s in agent_statuses): - break + return_values: list[list[FunctionR]] = [s.return_values for s in agent_statuses] # pyright: ignore [reportAssignmentType] + return LaunchResult.from_returns(hostnames, return_values) finally: if stop_logging_event is not None: stop_logging_event.set() @@ -294,152 +237,31 @@ def run( # noqa: C901, PLR0912 # cleanup: SIGTERM all agents if agent_payloads is not None: for agent_payload, agent_hostname in zip(agent_payloads, hostnames): - _execute_command( + execute_command( command=f"kill {agent_payload.process_id}", hostname=agent_hostname, - ssh_config_file=self.ssh_config_file, + ssh_config_file=ssh_config_file, ) - # if launch is successful: return objects from workers - return_values = [s.return_values for s in agent_statuses] - return LaunchResult(hostnames=hostnames, return_values=return_values) - @dataclass -class LaunchResult: +class LaunchResult(Generic[FunctionR]): """Container for objects returned from workers after successful launches.""" - def __init__(self, hostnames: list[str], return_values: list[list[Any]]) -> None: - """Initialize from corresponding lists of hostnames and worker return values.""" - self.results: dict[str, list[Any]] = dict(zip(hostnames, return_values)) + results: dict[str, list[FunctionR]] # [hostname][local_rank] -> FunctionR + + @classmethod + def from_returns(cls, hostnames: list[str], return_values: list[list[FunctionR]]) -> Self: # noqa: D102 + return cls(results=dict(zip(hostnames, return_values))) - def index(self, hostname: str, locak_rank: int) -> Any: + def index(self, hostname: str, locak_rank: int) -> FunctionR: """Get return value from worker by host and local rank.""" return self.results[hostname][locak_rank] - def rank(self, i: int) -> Any: + def rank(self, i: int) -> FunctionR: """Get return value from worker by global rank.""" for results_per_host in self.results.values(): if i < len(results_per_host): return results_per_host[i] i -= len(results_per_host) raise IndexError - - -def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: - if hostnames == "auto": - return auto_hosts() - if hostnames == "slurm": - return slurm_hosts() - return hostnames - - -def _resolve_workers_per_host( - hostnames: list[str], - workers_per_host: int | list[int] | Literal["auto"], -) -> list[int]: - if isinstance(workers_per_host, int): - return [workers_per_host] * len(hostnames) - - if workers_per_host == "auto": - python = shlex.quote(sys.executable) - command = f"{python} -c \"import torch; print(torch.cuda.device_count(), end='')\"" - gpus_per_host = [ - int(_execute_command(command, hostname, return_stdout_stderr=True)[0]) - for hostname in hostnames - ] - if any(g == 0 for g in gpus_per_host): - msg = 'workers_per_host="auto", but no GPUs detected on at least one host.' - raise RuntimeError(msg) - return gpus_per_host - - return workers_per_host - - -def _build_launch_command( - launcher_hostname: str, - launcher_port: int, - logger_port: int, - world_size: int, - rank: int, - env_vars: tuple[str, ...], - env_file: str | os.PathLike | None, -) -> str: - # shlex.quote prevents shell injection here (resolves S602 in execute_command) - - commands = [] - - current_dir = shlex.quote(str(Path.cwd())) - commands.append("cd " + current_dir) - - env_exports = [] - for k, v in os.environ.items(): - if any(fnmatch.fnmatch(k, e) for e in env_vars): - env_exports.append(shlex.quote(f"{k}={v}")) - - if len(env_exports) > 0: - commands.append("export " + " ".join(env_exports)) - - if env_file is not None: - commands.append("source " + shlex.quote(str(env_file))) - - python = shlex.quote(sys.executable) - launcher_hostname = shlex.quote(launcher_hostname) - - commands.append( - f"{python} -u -m torchrunx " - f"--launcher-hostname {launcher_hostname} " - f"--launcher-port {launcher_port} " - f"--logger-port {logger_port} " - f"--world-size {world_size} " - f"--rank {rank}", - ) - - return " && ".join(commands) - - -def _execute_command( - command: str, - hostname: str, - *, - ssh_config_file: str | os.PathLike | None = None, - return_stdout_stderr: bool = False, -) -> tuple[str, str]: - is_localhost = True - _hostname_or_ip = hostname - try: - _ip = ipaddress.ip_address(_hostname_or_ip) - except ValueError: - _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) - if not _ip.is_loopback: - # compare local interface addresses between host and localhost - _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] - _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] - is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 - - if is_localhost: - # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations) - # Made sure to shlex.quote arguments in build_command to prevent shell injection - process = subprocess.Popen( # noqa: S602 - command, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - - if return_stdout_stderr: - stdout, stderr = process.communicate() - return stdout, stderr - else: - runtime_ssh_path = ssh_config_file - if isinstance(ssh_config_file, os.PathLike): - runtime_ssh_path = str(ssh_config_file) - - with fabric.Connection( - host=hostname, - config=fabric.Config(runtime_ssh_path=runtime_ssh_path), - ) as conn: - promise = conn.run(command, asynchronous=True, hide=True) - - if return_stdout_stderr: - results = promise.join() - return results.stdout, results.stderr - - return ("", "") diff --git a/src/torchrunx/utils/__init__.py b/src/torchrunx/utils/__init__.py index dc4af98..d6b94d1 100644 --- a/src/torchrunx/utils/__init__.py +++ b/src/torchrunx/utils/__init__.py @@ -1,5 +1,3 @@ -"""Utility classes and functions.""" - from .logging import add_filter_to_handler, file_handler, stream_handler __all__ = ["add_filter_to_handler", "file_handler", "stream_handler"] diff --git a/src/torchrunx/utils/comm.py b/src/torchrunx/utils/comm.py index 7634c9c..da68563 100644 --- a/src/torchrunx/utils/comm.py +++ b/src/torchrunx/utils/comm.py @@ -15,7 +15,7 @@ import socket from contextlib import closing from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar import cloudpickle import torch.distributed as dist @@ -34,8 +34,12 @@ def get_open_port() -> int: return s.getsockname()[1] +ObjectT = TypeVar("ObjectT", bound=Any) +FunctionR = TypeVar("FunctionR") + + @dataclass -class LauncherAgentGroup: +class LauncherAgentGroup(Generic[FunctionR]): """Initializes a GLOO distributed process group between launcher and all agents.""" launcher_hostname: str @@ -62,25 +66,24 @@ def __post_init__(self) -> None: timeout=datetime.timedelta(seconds=30), ) - def _serialize(self, obj: Any) -> bytes: - return cloudpickle.dumps(obj) - - def _deserialize(self, serialized: bytes) -> Any: - return cloudpickle.loads(serialized) - - def _all_gather(self, obj: Any) -> list: + def _all_gather(self, obj: ObjectT) -> list[ObjectT]: """Gather object from each rank to list (in rank-order). Raises: AgentFailedError: if any agent fails (observed by this communication). """ try: - rank_obj = self._serialize((self.rank, obj)) - rank_obj_list = [b""] * self.world_size - # raises RuntimeError if timeout - dist.all_gather_object(object_list=rank_obj_list, obj=rank_obj, group=self.group) - rank_obj_list = sorted([self._deserialize(o) for o in rank_obj_list]) - return [obj for _, obj in sorted(rank_obj_list)] + rank_obj = cloudpickle.dumps((self.rank, obj)) + all_gather_list = [b""] * self.world_size + + dist.all_gather_object( + object_list=all_gather_list, obj=rank_obj, group=self.group + ) # raises RuntimeError if timeout + + rank_obj_list: list[tuple[int, ObjectT]] = sorted( + [cloudpickle.loads(o) for o in all_gather_list] + ) + return [obj for _, obj in rank_obj_list] except RuntimeError as e: # occurs if launcher or any agent dies and communication times out raise AgentFailedError from e @@ -91,13 +94,17 @@ def sync_payloads( ) -> tuple[LauncherPayload, list[AgentPayload]]: """All-gather payloads across launcher and all agents.""" payloads = self._all_gather(payload) - launcher_payload = payloads[0] - agent_payloads = payloads[1:] + launcher_payload: LauncherPayload = payloads[0] # pyright: ignore [reportAssignmentType] + agent_payloads: list[AgentPayload] = payloads[1:] # pyright: ignore [reportAssignmentType] return launcher_payload, agent_payloads - def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]: + def sync_agent_statuses( + self, status: AgentStatus[FunctionR] | None + ) -> list[AgentStatus[FunctionR]]: """All-gather agent statuses across launcher and all agents.""" - return self._all_gather(status)[1:] # [0] is launcher (status=None) + # only launcher has status = None + agent_statuses: list[AgentStatus[FunctionR]] = self._all_gather(status)[1:] # pyright: ignore [reportAssignmentType] + return agent_statuses def shutdown(self) -> None: """Terminate process group.""" @@ -126,7 +133,7 @@ class AgentPayload: @dataclass -class AgentStatus: +class AgentStatus(Generic[FunctionR]): """Status of each agent (to be synchronized in LauncherAgentGroup). Attributes: @@ -135,7 +142,7 @@ class AgentStatus: """ state: Literal["running", "failed", "done"] - return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field( + return_values: list[FunctionR | WorkerFailedError | ExceptionFromWorker] = field( default_factory=list ) # indexed by local rank diff --git a/src/torchrunx/utils/environment.py b/src/torchrunx/utils/environment.py index ca5f0e7..b5fd5fd 100644 --- a/src/torchrunx/utils/environment.py +++ b/src/torchrunx/utils/environment.py @@ -2,10 +2,62 @@ from __future__ import annotations -__all__ = ["auto_hosts", "in_slurm_job", "slurm_hosts"] +from typing import Literal, Union +from typing_extensions import TypeAlias + +__all__ = [ + "auto_hosts", + "build_launch_command", + "execute_command", + "get_gpus_per_host", + "in_slurm_job", + "slurm_hosts", +] + +import ipaddress import os +import shlex +import socket import subprocess +import sys +from pathlib import Path + +import fabric + +Hostnames: TypeAlias = list[str] +WorkersPerHost: TypeAlias = list[int] +Backend: TypeAlias = Union[Literal["nccl", "gloo", "mpi", "ucc"], None] + + +def resolve_environment( + hostnames: list[str] | Literal["auto", "slurm"], + workers_per_host: int | list[int] | Literal["auto"], + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None, + ssh_config_file: str | os.PathLike | None, +) -> tuple[Hostnames, WorkersPerHost, Backend]: + if hostnames == "auto": + hostnames = auto_hosts() + elif hostnames == "slurm": + hostnames = slurm_hosts() + + if isinstance(workers_per_host, int): + workers_per_host = [workers_per_host] * len(hostnames) + + if workers_per_host == "auto" or backend == "auto": + gpus_per_host: list[int] = get_gpus_per_host(hostnames, ssh_config_file) + gpus_on_every_host: bool = all(g > 0 for g in gpus_per_host) + + if workers_per_host == "auto": + if not gpus_on_every_host: + msg = 'workers_per_host="auto", but no GPUs detected on at least one host.' + raise RuntimeError(msg) + workers_per_host = gpus_per_host + + if backend == "auto": + backend = "nccl" if gpus_per_host else "gloo" + + return hostnames, workers_per_host, backend def auto_hosts() -> list[str]: @@ -27,3 +79,103 @@ def slurm_hosts() -> list[str]: raise RuntimeError(msg) return subprocess.check_output(["scontrol", "show", "hostnames"]).decode().strip().split("\n") + + +def get_gpus_per_host(hostnames: list[str], ssh_config_file: str | os.PathLike | None) -> list[int]: + """Count the number of GPUs on each host.""" + python = shlex.quote(sys.executable) + command = f"{python} -c \"import torch; print(torch.cuda.device_count(), end='')\"" + return [ + int( + execute_command( + command, hostname, ssh_config_file=ssh_config_file, return_stdout_stderr=True + )[0] + ) + for hostname in hostnames + ] + + +def build_launch_command( + launcher_hostname: str, + launcher_port: int, + logger_port: int, + world_size: int, + rank: int, + env_vars: dict[str, str], + env_file: str | os.PathLike | None, +) -> str: + """Generator for command to launch torchrunx on an agent.""" + # shlex.quote prevents shell injection here (resolves S602 in execute_command) + + commands = [] + + commands.append(f"cd {shlex.quote(str(Path.cwd()))}") + + env_exports = [shlex.quote(f"{k}={v}") for k, v in env_vars.items()] + if len(env_exports) > 0: + commands.append("export " + " ".join(env_exports)) + + if env_file is not None: + commands.append("source " + shlex.quote(str(env_file))) + + python = shlex.quote(sys.executable) + launcher_hostname = shlex.quote(launcher_hostname) + + commands.append( + f"{python} -u -m torchrunx " + f"--launcher-hostname {launcher_hostname} " + f"--launcher-port {launcher_port} " + f"--logger-port {logger_port} " + f"--world-size {world_size} " + f"--rank {rank}", + ) + + return " && ".join(commands) + + +def execute_command( + command: str, + hostname: str, + *, + ssh_config_file: str | os.PathLike | None = None, + return_stdout_stderr: bool = False, +) -> tuple[str, str]: + """Run a command on local or remote host (using SSH).""" + is_localhost = True + _hostname_or_ip = hostname + try: + _ip = ipaddress.ip_address(_hostname_or_ip) + except ValueError: + _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip)) + if not _ip.is_loopback: + # compare local interface addresses between host and localhost + _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)] + _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)] + is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0 + + if is_localhost: + # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.9/library/subprocess.html#security-considerations) + # Made sure to shlex.quote arguments in build_command to prevent shell injection + process = subprocess.Popen( # noqa: S602 + command, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + if return_stdout_stderr: + stdout, stderr = process.communicate() + return stdout, stderr + else: + runtime_ssh_path = ssh_config_file + if isinstance(ssh_config_file, os.PathLike): + runtime_ssh_path = str(ssh_config_file) + + with fabric.Connection( + host=hostname, + config=fabric.Config(runtime_ssh_path=runtime_ssh_path), + ) as conn: + promise = conn.run(command, asynchronous=True, hide=True) + + if return_stdout_stderr: + results = promise.join() + return results.stdout, results.stderr + + return ("", "") diff --git a/src/torchrunx/utils/logging.py b/src/torchrunx/utils/logging.py index 643c968..f139941 100644 --- a/src/torchrunx/utils/logging.py +++ b/src/torchrunx/utils/logging.py @@ -16,6 +16,7 @@ import datetime import logging +import os import pickle import signal import struct @@ -25,22 +26,19 @@ from io import StringIO from logging import Handler, Logger from logging.handlers import SocketHandler +from multiprocessing.synchronize import Event as EventClass from pathlib import Path from socketserver import StreamRequestHandler, ThreadingTCPServer -from typing import TYPE_CHECKING, Callable, Literal +from typing import Callable, Literal import cloudpickle from typing_extensions import Self -if TYPE_CHECKING: - import os - from multiprocessing.synchronize import Event as EventClass - ## Handler utilities def add_filter_to_handler( - handler: Handler, + handler: logging.Handler, hostname: str, local_rank: int | None, # None indicates agent log_level: int = logging.NOTSET, @@ -64,9 +62,29 @@ def _filter(record: WorkerLogRecord) -> bool: handler.addFilter(_filter) # pyright: ignore [reportArgumentType] +def default_handlers( + hostnames: list[str], + workers_per_host: list[int], + log_level: int = logging.INFO, +) -> list[logging.Handler]: + """Default :mod:`logging.Handler`s for ``log_handlers="auto"`` in :mod:`torchrunx.launch`. + + Logs for ``host[0]`` and its ``local_rank[0]`` worker are written to launcher process stdout. + Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname, + local_rank). + """ + log_dir = Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")) + log_level = logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")] # noqa: SLF001 + return [ + stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level), + stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level), + *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), + ] + + def stream_handler( hostname: str, local_rank: int | None, log_level: int = logging.NOTSET -) -> Handler: +) -> logging.Handler: """Handler builder function for writing logs from specified hostname/rank to stdout.""" handler = logging.StreamHandler(stream=sys.stdout) add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) @@ -86,7 +104,7 @@ def file_handler( local_rank: int | None, file_path: str | os.PathLike, log_level: int = logging.NOTSET, -) -> Handler: +) -> logging.Handler: """Handler builder function for writing logs from specified hostname/rank to a file.""" handler = logging.FileHandler(file_path) add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) @@ -101,7 +119,7 @@ def file_handlers( workers_per_host: list[int], log_dir: str | os.PathLike = Path("torchrunx_logs"), log_level: int = logging.NOTSET, -) -> list[Handler]: +) -> list[logging.Handler]: """Handler builder function for writing logs for all workers/agents to a directory. Files are named with hostname and the local_rank (for workers). @@ -121,25 +139,6 @@ def file_handlers( return handlers -def default_handlers( - hostnames: list[str], - workers_per_host: list[int], - log_dir: str | os.PathLike = Path("torchrunx_logs"), - log_level: int = logging.INFO, -) -> list[Handler]: - """Default :mod:`logging.Handler`s for ``log_handlers="auto"`` in :mod:`torchrunx.launch`. - - Logs for ``host[0]`` and its ``local_rank[0]`` worker are written to launcher process stdout. - Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname, - local_rank). - """ - return [ - stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level), - stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level), - *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), - ] - - ## Launcher utilities @@ -193,15 +192,13 @@ class LoggingServerArgs: logging_port: int hostnames: list[str] workers_per_host: list[int] - log_dir: str | os.PathLike - log_level: int def serialize(self) -> bytes: """Serialize :class:`LoggingServerArgs` for passing to a new process.""" return cloudpickle.dumps(self) - @staticmethod - def deserialize(serialized: bytes) -> LoggingServerArgs: + @classmethod + def from_bytes(cls, serialized: bytes) -> Self: """Deserialize bytes to :class:`LoggingServerArgs`.""" return cloudpickle.loads(serialized) @@ -211,7 +208,7 @@ def start_logging_server( stop_event: EventClass, ) -> None: """Serve :class:`_LogRecordSocketReceiver` until stop event triggered.""" - args = LoggingServerArgs.deserialize(serialized_args) + args = LoggingServerArgs.from_bytes(serialized_args) log_handlers = [] if args.handler_factory is None: @@ -220,8 +217,6 @@ def start_logging_server( log_handlers = default_handlers( hostnames=args.hostnames, workers_per_host=args.workers_per_host, - log_dir=args.log_dir, - log_level=args.log_level, ) elif isinstance(args.handler_factory, Callable): log_handlers = args.handler_factory() diff --git a/src/torchrunx/worker.py b/src/torchrunx/worker.py index e730752..422f9cc 100644 --- a/src/torchrunx/worker.py +++ b/src/torchrunx/worker.py @@ -7,12 +7,13 @@ import os import sys import traceback -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Callable, Literal import cloudpickle import torch import torch.distributed as dist +from typing_extensions import Self from .utils.errors import ExceptionFromWorker from .utils.logging import log_records_to_socket, redirect_stdio_to_logger @@ -38,29 +39,24 @@ class WorkerArgs: hostname: str timeout: int - def serialize(self) -> SerializedWorkerArgs: + def serialize(self) -> bytes: """Arguments must be serialized (to bytes) before passed to spawned workers.""" - return SerializedWorkerArgs(worker_args=self) + return cloudpickle.dumps(asdict(self)) + @classmethod + def from_bytes(cls, b: bytes) -> Self: + """Deserialize the bytes back into a WorkerArgs object.""" + return cls(**cloudpickle.loads(b)) -class SerializedWorkerArgs: - """We use cloudpickle as a serialization backend (as it supports nearly all Python types).""" - def __init__(self, worker_args: WorkerArgs) -> None: - self.bytes = cloudpickle.dumps(worker_args) - - def deserialize(self) -> WorkerArgs: - return cloudpickle.loads(self.bytes) - - -def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: +def worker_entrypoint(serialized_worker_args: bytes) -> Any | ExceptionFromWorker: """Function called by spawned worker processes. Workers first prepare a process group (for communicating with all other workers). They then invoke the user-provided function. Logs are transmitted to the launcher process. """ - worker_args: WorkerArgs = serialized_worker_args.deserialize() + worker_args = WorkerArgs.from_bytes(serialized_worker_args) # Start logging to the logging server (i.e. the launcher) diff --git a/tests/test_ci.py b/tests/test_ci.py index 98cce6f..27e4223 100644 --- a/tests/test_ci.py +++ b/tests/test_ci.py @@ -32,11 +32,10 @@ def dist_func() -> torch.Tensor: tmp = tempfile.mkdtemp() os.environ["TORCHRUNX_DIR"] = tmp - r = trx.launch( - dist_func, + r = trx.Launcher( workers_per_host=2, - backend="gloo", # log_dir="./test_logs" - ) + backend="gloo", + ).run(dist_func) assert torch.all(r.rank(0) == r.rank(1)) @@ -55,10 +54,11 @@ def dist_func() -> None: time.sleep(1) - trx.launch( - dist_func, + trx.Launcher( workers_per_host=num_workers, backend="gloo", + ).run( + dist_func, ) after_timestamp = datetime.datetime.now() @@ -95,10 +95,11 @@ def error_func() -> NoReturn: os.environ["TORCHRUNX_DIR"] = tmp with pytest.raises(ValueError) as excinfo: # noqa: PT011 - trx.launch( - error_func, + trx.Launcher( workers_per_host=1, backend="gloo", + ).run( + error_func, ) assert "abcdefg" in str(excinfo.value) diff --git a/tests/test_func.py b/tests/test_func.py index f0474b9..1a0fc5c 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -9,10 +9,9 @@ def test_launch() -> None: - result = trx.launch( - func=simple_matmul, + result = trx.Launcher( hostnames="slurm", - ) + ).run(simple_matmul) result_values = reduce(add, result.results.values()) diff --git a/tests/test_submitit.py b/tests/test_submitit.py index 1f639df..75a91f1 100644 --- a/tests/test_submitit.py +++ b/tests/test_submitit.py @@ -53,7 +53,7 @@ def main() -> None: def launch() -> None: - trx.launch(main, hostnames="slurm") + trx.Launcher(hostnames="slurm").run(main) def test_submitit() -> None: diff --git a/tests/test_train_gpu.py b/tests/test_train_gpu.py index e1cbad3..6be44c0 100644 --- a/tests/test_train_gpu.py +++ b/tests/test_train_gpu.py @@ -32,10 +32,9 @@ def worker() -> None: def test_distributed_train() -> None: - trx.launch( - worker, + trx.Launcher( backend="nccl", - ) + ).run(worker) if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index e244261..13fe003 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.13'", @@ -58,11 +59,11 @@ wheels = [ [[package]] name = "babel" -version = "2.17.0" +version = "2.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852 } +sdist = { url = "https://files.pythonhosted.org/packages/2a/74/f1bc80f23eeba13393b7222b11d95ca3af2c1e28edca18af487137eefed9/babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316", size = 9348104 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 }, + { url = "https://files.pythonhosted.org/packages/ed/20/bc79bc575ba2e2a7f70e8a1155618bb1301eaa5132a8271373a6903f73f8/babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b", size = 9587599 }, ] [[package]] @@ -1810,6 +1811,7 @@ dependencies = [ { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "torch" }, + { name = "typing-extensions" }, ] [package.dev-dependencies] @@ -1835,6 +1837,7 @@ requires-dist = [ { name = "fabric", specifier = ">=3.2" }, { name = "numpy", specifier = ">=1.20" }, { name = "torch", specifier = ">=2.0" }, + { name = "typing-extensions", specifier = ">=4.9.0" }, ] [package.metadata.requires-dev]