Support type checking #86

Merged
merged 8 commits into from
Feb 16, 2025
19 changes: 9 additions & 10 deletions README.md
@@ -28,6 +28,7 @@ Requires: Linux (+ SSH & shared filesystem if using multiple machines)
Dummy distributed training function:

```python
from __future__ import annotations
import os
import torch
import torch.nn as nn
@@ -59,26 +60,24 @@ Launching training with `torchrunx`:
```python
import torchrunx

results = torchrunx.launch(
func = train,
kwargs = dict(
model = nn.Linear(10, 10),
num_steps = 10
),
#
results = torchrunx.Launcher(
hostnames = ["localhost", "second_machine"],
workers_per_host = 2
).run(
train,
model = nn.Linear(10, 10),
num_steps = 10
)

trained_model: nn.Module = results.rank(0)
torch.save(trained_model.state_dict(), "output/model.pth")
```
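
Passing the function's arguments directly to `run`, rather than bundling them in a `kwargs` dict, gives a type checker the chance to validate them against the function's signature and to infer the type returned by `rank`. The following is a minimal sketch of that idea, assuming a `ParamSpec`-based signature. `MiniLauncher` and `MiniResult` are hypothetical stand-ins that run everything in the current process; they are not torchrunx's actual implementation, which this diff does not show.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

try:
    from typing import ParamSpec  # available in Python >= 3.10
except ImportError:  # Python 3.9 needs the backport
    from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


@dataclass
class MiniResult(Generic[R]):
    """Hypothetical stand-in for a launch result: one return value per rank."""

    values: list[R]

    def rank(self, i: int) -> R:
        return self.values[i]


@dataclass
class MiniLauncher:
    """Hypothetical stand-in for a launcher; no processes are spawned."""

    workers: int = 2

    def run(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> MiniResult[R]:
        # A real launcher would dispatch to worker processes. Here the function
        # is simply called once per "worker" so the sketch stays runnable.
        return MiniResult([func(*args, **kwargs) for _ in range(self.workers)])


def train(scale: float, num_steps: int) -> float:
    return scale * num_steps


results = MiniLauncher(workers=2).run(train, scale=0.5, num_steps=10)
loss: float = results.rank(0)  # a type checker can infer `float` here
```

With this shape, a call such as `run(train, num_steps="ten")` should be flagged by a type checker like pyright, whereas a `kwargs = dict(...)` argument is opaque to it.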

**See examples where we fine-tune LLMs (e.g. GPT-2 on WikiText) using:**
- [Accelerate](https://torchrun.xyz/examples/accelerate.html)
- [HF Transformers](https://torchrun.xyz/examples/transformers.html)
- [Transformers](https://torchrun.xyz/examples/transformers.html)
- [DeepSpeed](https://torchrun.xyz/examples/deepspeed.html)
- [PyTorch Lightning](https://torchrun.xyz/examples/lightning.html)
- [Accelerate](https://torchrun.xyz/examples/accelerate.html)

**Refer to our [API](https://torchrun.xyz/api.html) and [Advanced Usage Guide](https://torchrun.xyz/advanced.html) for many more capabilities!**

@@ -118,4 +117,4 @@ torch.save(trained_model.state_dict(), "output/model.pth")
> - Automatic detection of SLURM environments.
> - Start multi-node training from Python notebooks!

**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, fuller typing, and more!**
**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, and more!**
7 changes: 4 additions & 3 deletions docs/conf.py
@@ -24,11 +24,12 @@
"sphinx_toolbox.github",
]

maximum_signature_line_length = 90
autodoc_member_order = "bysource"
autodoc_typehints = "description"
autodoc_typehints_description_target = "documented"

intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}
intersphinx_mapping = {
'python': ('https://docs.python.org/3.9', None),
}

from docs.linkcode_github import generate_linkcode_resolve_fn
linkcode_resolve = generate_linkcode_resolve_fn(project, github_username, github_repository)
25 changes: 1 addition & 24 deletions docs/source/api.md
@@ -1,29 +1,6 @@
# API

```{eval-rst}
.. autofunction:: torchrunx.launch(func, args, kwargs, ...)
```

We provide the {obj}`torchrunx.Launcher` class as an alias to {obj}`torchrunx.launch`.

```{eval-rst}
.. autoclass:: torchrunx.Launcher
:members:
```

## Results

```{eval-rst}
.. autoclass:: torchrunx.LaunchResult
.. automodule:: torchrunx
:members:
```

## Exceptions

```{eval-rst}
.. autoexception:: torchrunx.AgentFailedError
```

```{eval-rst}
.. autoexception:: torchrunx.WorkerFailedError
```
8 changes: 4 additions & 4 deletions docs/source/features/cli.md
@@ -1,16 +1,16 @@
# CLI Integration

We can use {mod}`torchrunx.Launcher` to populate arguments from the CLI (e.g. with [tyro](https://brentyi.github.io/tyro/)):
We can automatically populate {mod}`torchrunx.Launcher` arguments using most CLI tools (those that generate interfaces from Data Classes, e.g. [tyro](https://brentyi.github.io/tyro/)):

```python
import torchrunx as trx
import torchrunx
import tyro

def distributed_function():
pass
...

if __name__ == "__main__":
launcher = tyro.cli(trx.Launcher)
launcher = tyro.cli(torchrunx.Launcher)
launcher.run(distributed_function)
```
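
To illustrate what "generating an interface from a data class" means, here is a minimal standard-library sketch that derives command-line flags from dataclass fields. `LauncherConfig` and `parse_dataclass` are hypothetical names invented for this example; the fields are not the real `torchrunx.Launcher` fields, and tools such as tyro handle far more (nested types, help text, validation).

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class LauncherConfig:
    """Hypothetical config used only for this illustration."""

    hostname: str = "localhost"
    workers_per_host: int = 1


def parse_dataclass(cls, argv):
    """Build one --flag per dataclass field and return a populated instance."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        flag = "--" + f.name.replace("_", "-")
        # The field's annotation doubles as the argparse converter (str, int, ...).
        parser.add_argument(flag, dest=f.name, type=f.type, default=f.default)
    return cls(**vars(parser.parse_args(argv)))


# Equivalent to running: python script.py --workers-per-host 4
config = parse_dataclass(LauncherConfig, ["--workers-per-host", "4"])
print(config)
```

Fields that are not passed on the command line keep their dataclass defaults, so `config.hostname` stays `"localhost"` here.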

13 changes: 10 additions & 3 deletions pyproject.toml
@@ -11,7 +11,7 @@ authors = [
]
description = "Automatically initialize distributed PyTorch environments"
readme = "README.md"
license = {file = "LICENSE"}
license = { file = "LICENSE" }
urls = { Repository = "https://github.com/apoorvkh/torchrunx.git", Documentation = "https://torchrun.xyz" }
requires-python = ">=3.9"
dependencies = [
@@ -21,12 +21,17 @@ dependencies = [
# torch.distributed depends on numpy
# torch<=2.2 needs numpy<2
"numpy>=1.20",
"typing-extensions>=4.9.0",
]
[dependency-groups]
dev = ["ruff==0.9.5", "pyright[nodejs]==1.1.393", "pytest==8.3.4"]
test-extras = ["submitit", "transformers"]
docs = ["sphinx==7.4.7", "furo==2024.8.6", "myst-parser==3.0.1", "sphinx-toolbox==3.8.2"]

docs = [
"sphinx==7.4.7",
"furo==2024.8.6",
"myst-parser==3.0.1",
"sphinx-toolbox==3.8.2",
]

[tool.ruff]
include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
@@ -36,6 +41,8 @@ src = ["src", "tests"]
[tool.ruff.lint]
select = ["ALL"]
ignore = [
"TC003", # no type checking blocks for stdlib
"D104", # package docstrings
"ANN401", # self / cls / Any annotations
"BLE001", # blind exceptions
"TD", # todo syntax
12 changes: 5 additions & 7 deletions src/torchrunx/__init__.py
@@ -1,12 +1,10 @@
"""API for our torchrunx library."""

from .launcher import Launcher, LaunchResult, launch
from .launcher import DEFAULT_ENV_VARS_FOR_COPY, Launcher, LaunchResult
from .utils.errors import AgentFailedError, WorkerFailedError

__all__ = [
"AgentFailedError",
"LaunchResult",
__all__ = [ # noqa: RUF022
"DEFAULT_ENV_VARS_FOR_COPY",
"Launcher",
"LaunchResult",
"AgentFailedError",
"WorkerFailedError",
"launch",
]
1 change: 0 additions & 1 deletion src/torchrunx/integrations/__init__.py
@@ -1 +0,0 @@
"""Utilities for integrations with other libraries."""