Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions run_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ set -ex
# use envs as local overwrites for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
# DRY_RUN=1 ./run_train.sh # for config validation without GPU
# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU
# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
DRY_RUN=${DRY_RUN:-0}
# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training
COMM_MODE=${COMM_MODE:-""}

TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

if [ "$DRY_RUN" = "1" ]; then
# Dry run mode: validate configuration without GPU/distributed setup
echo "Running in DRY RUN mode - configuration validation only"
python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
if [ -n "$COMM_MODE" ]; then
# Communication mode specified: validate configuration or run in debug mode
echo "Running with comm_mode=${COMM_MODE}"
NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.comm_mode=${COMM_MODE} --training.steps=1
else
# Normal training with torchrun
PYTORCH_ALLOC_CONF="expandable_segments:True" \
Expand Down
159 changes: 0 additions & 159 deletions scripts/dry_run.py

This file was deleted.

16 changes: 16 additions & 0 deletions torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,22 @@ class Comm:
save_traces_file_prefix: str = "rank_"
"""Flight recorder trace files prefix"""

comm_mode: Literal["default", "fake_backend", "local_tensor"] = "default"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: comm.comm_mode sounds redundant, I would propose

Suggested change
comm_mode: Literal["default", "fake_backend", "local_tensor"] = "default"
backend: Literal["default", "fake", "local"] = "default"

WDYT?

"""
Communication mode for distributed training.

Options:
- "default": Normal distributed training with real communication
- "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU)
- "local_tensor": Local tensor mode for debugging purposes. There will be only one process
regardless of the number of GPUs. LocalTensor will simulate the computation by running one
rank after another. While the performance will be slow, the numerics should be the same.
This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D
parallelisms within a single node to reduce the combinations we need to use in integration tests.

NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally.
"""


@dataclass
class MemoryEstimation:
Expand Down
48 changes: 47 additions & 1 deletion torchtitan/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,56 @@ def maybe_enable_amp(
)


def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"):
"""Initialize fake backend

Args:
world_size: The number of GPUs to simulate
comm_mode: Communication mode ("fake_backend" or "local_tensor")

Returns:
The world size
"""
torch.distributed.init_process_group(
"fake",
rank=0,
world_size=world_size,
)

# If local_tensor mode is enabled, initialize LocalTensorMode context
if comm_mode == "local_tensor":
from torch.distributed import _local_tensor

lm = _local_tensor.LocalTensorMode(world_size)
lm.__enter__()

# TODO: remove this once the root cause is figured out
torch.manual_seed(42)

return world_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to return world_size any more?



def init_distributed(
comm_config: CommConfig,
enable_cpu_backend: bool = False,
base_folder: str = "",
ranks: list[int] | None = None,
):
) -> int:
if comm_config.comm_mode in ("fake_backend", "local_tensor"):
ngpu_str = os.environ.get("NGPU")
if ngpu_str is None:
raise ValueError(
f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}"
)
try:
world_size = int(ngpu_str)
except ValueError as e:
raise ValueError(
f"NGPU environment variable must be a valid integer, got: {ngpu_str}"
) from e
init_fake_mode(world_size, comm_config.comm_mode)
return world_size

def _warn_overwrite_env(env, val):
if env in os.environ:
logger.warning(
Expand Down Expand Up @@ -309,6 +353,8 @@ def _get_distributed_backend(enable_cpu_backend):
_ranks=ranks if ranks is not None else [],
)

return torch.distributed.get_world_size()


def set_pg_timeouts(timeout, world_mesh):
"""
Expand Down
11 changes: 8 additions & 3 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,13 @@ def __init__(self, job_config: JobConfig):

def init_distributed(self) -> ParallelDims:
job_config = self.job_config
dist_utils.init_distributed(
world_size = dist_utils.init_distributed(
job_config.comm,
enable_cpu_backend=job_config.training.enable_cpu_offload,
base_folder=job_config.job.dump_folder,
)

world_size = int(os.environ["WORLD_SIZE"])
parallelism_config = job_config.parallelism

return ParallelDims(
dp_shard=parallelism_config.data_parallel_shard_degree,
dp_replicate=parallelism_config.data_parallel_replicate_degree,
Expand Down Expand Up @@ -718,6 +716,13 @@ def main(trainer_class: type[Trainer]) -> None:
try:
trainer = trainer_class(config)

# TODO(local_tensor): Remove this special case once LocalTensor supports
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly, can we remove this now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still some gaps. I updated the comment.

# init_weights() and foreach_allgather. In local tensor mode, skip
# training/checkpointing as the # model is not fully initialized
if config.comm.comm_mode == "local_tensor":
logger.info("Local tensor mode enabled - skipping training execution")
return

if config.checkpoint.create_seed_checkpoint:
assert (
int(os.environ["WORLD_SIZE"]) == 1
Expand Down
Loading