diff --git a/run_train.sh b/run_train.sh index 83319816fe..b24e9047d3 100755 --- a/run_train.sh +++ b/run_train.sh @@ -10,19 +10,21 @@ set -ex # use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh -# DRY_RUN=1 ./run_train.sh # for config validation without GPU +# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU +# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"} -DRY_RUN=${DRY_RUN:-0} +# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training +COMM_MODE=${COMM_MODE:-""} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -if [ "$DRY_RUN" = "1" ]; then - # Dry run mode: validate configuration without GPU/distributed setup - echo "Running in DRY RUN mode - configuration validation only" - python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" +if [ -n "$COMM_MODE" ]; then + # Communication mode specified: validate configuration or run in debug mode + echo "Running with comm_mode=${COMM_MODE}" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.comm_mode=${COMM_MODE} --training.steps=1 else # Normal training with torchrun PYTORCH_ALLOC_CONF="expandable_segments:True" \ diff --git a/scripts/dry_run.py b/scripts/dry_run.py deleted file mode 100644 index fa8e1b4c17..0000000000 --- a/scripts/dry_run.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Dry run trainer for fast configuration validation without GPU/distributed setup. - -This module provides a lightweight trainer that validates job configurations, -model architecture, and dataloader setup without requiring GPU resources or -distributed initialization. Useful for rapid iteration on configuration files -and CI/CD validation pipelines. -""" - -import os -import sys - -# Add parent directory to path to import torchtitan -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import torch - -import torchtitan.protocols.train_spec as train_spec_module -from torchtitan.config import JobConfig, TORCH_DTYPE_MAP -from torchtitan.tools import utils -from torchtitan.tools.logging import logger -from torchtitan.train import main, Trainer - - -class DryRunTrainer(Trainer): - """ - A lightweight trainer that validates configurations without GPU allocation. - - This trainer performs comprehensive validation of the training configuration - without allocating GPU resources or initializing distributed setup. It validates: - - - Configuration file parsing and structure - - Model architecture (constructed on meta device) - - Tokenizer initialization - - Dataloader configuration - - Parallelism settings - - Model converters (if specified) - - Unlike the regular Trainer, this does not: - - Allocate GPU memory - - Initialize distributed process groups - - Create optimizers or learning rate schedulers - - Set up checkpointing or metrics - - Run any actual training - - Args: - job_config: JobConfig containing all training configuration parameters - - Note: - Validation completes immediately after initialization. No training loop is executed. 
- All operations use CPU and meta devices for zero-cost validation. - """ - - def __init__(self, job_config: JobConfig): - torch._C._log_api_usage_once("torchtitan.dry_run") - - self.job_config = job_config - - logger.info(f"Starting job: {job_config.job.description}") - logger.info("DRY RUN MODE - Configuration validation only") - - # Use CPU device (no GPU required) - self.device = torch.device("cpu") - - # Log and validate config - job_config.maybe_log() - logger.info("Configuration parsed successfully") - - # Get train spec - self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - logger.info(f"Train spec loaded for model: {job_config.model.name}") - - # Build tokenizer - self.tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None - ) - if self.tokenizer: - logger.info("Tokenizer built successfully") - - # Validate model configuration - model_args = self.train_spec.model_args[job_config.model.flavor] - model_args.update_from_config(job_config) - self.model_args = model_args - - logger.info( - f"Model args validated: {job_config.model.name} {job_config.model.flavor}" - ) - - # Build model on meta device (validates architecture without memory allocation) - logger.info("Validating model architecture...") - with ( - torch.device("meta"), - utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), - ): - model = self.train_spec.model_cls(model_args) - - # Calculate and log model size - model_param_count, _ = model_args.get_nparams_and_flops( - model, job_config.training.seq_len - ) - logger.info( - f"Model architecture validated: {job_config.model.name} " - f"with {model_param_count:,} parameters" - ) - - # Validate dataloader configuration (build with minimal params) - logger.info("Validating dataloader configuration...") - try: - # Use dp_world_size=1 and dp_rank=0 for dry run - dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=1, - dp_rank=0, - tokenizer=self.tokenizer, - job_config=job_config, - ) - logger.info("Dataloader configuration validated successfully") - except Exception as e: - logger.warning(f"Dataloader validation encountered issue: {e}") - logger.info( - "Note: Some dataloader issues may only appear with actual data paths" - ) - - # Validate model converters if specified - if job_config.model.converters: - logger.info(f"Model converters specified: {job_config.model.converters}") - - # Validate parallelism configuration - parallelism_config = job_config.parallelism - logger.info( - f"Parallelism config: " - f"DP-shard={parallelism_config.data_parallel_shard_degree}, " - f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, " - f"TP={parallelism_config.tensor_parallel_degree}, " - f"PP={parallelism_config.pipeline_parallel_degree}, " - f"CP={parallelism_config.context_parallel_degree}" - ) - - # Summary - logger.info("=" * 80) - logger.info("DRY RUN VALIDATION COMPLETE") - logger.info("=" * 80) - logger.info("All configurations validated successfully!") - logger.info("Configuration is ready for training execution.") - logger.info("=" * 80) - - def train(self): - return - - -if __name__ == "__main__": - main(DryRunTrainer) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..b9d600b164 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -791,6 +791,22 @@ class Comm: save_traces_file_prefix: str = "rank_" """Flight recorder trace files prefix""" + comm_mode: 
Literal["default", "fake_backend", "local_tensor"] = "default" + """ + Communication mode for distributed training. + + Options: + - "default": Normal distributed training with real communication + - "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU) + - "local_tensor": Local tensor mode for debugging purposes. There will be only one process + regardless of the number of GPUs. LocalTensor will simulate the computation by running one + rank after another. While the performance will be slow, the numerics should be the same. + This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D + parallelisms within a single node to reduce the combinations we need to use in integration tests. + + NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally. + """ + @dataclass class MemoryEstimation: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index b209ddfd68..664d23aed2 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -258,12 +258,56 @@ def maybe_enable_amp( ) +def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"): + """Initialize fake backend + + Args: + world_size: The number of GPUs to simulate + comm_mode: Communication mode ("fake_backend" or "local_tensor") + + Returns: + The world size + """ + torch.distributed.init_process_group( + "fake", + rank=0, + world_size=world_size, + ) + + # If local_tensor mode is enabled, initialize LocalTensorMode context + if comm_mode == "local_tensor": + from torch.distributed import _local_tensor + + lm = _local_tensor.LocalTensorMode(world_size) + lm.__enter__() + + # TODO: remove this once the root cause is figured out + torch.manual_seed(42) + + return world_size + + def init_distributed( comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "", ranks: list[int] | None = None, -): +) -> int: + if comm_config.comm_mode in ("fake_backend", "local_tensor"): + ngpu_str = os.environ.get("NGPU") + if ngpu_str is None: + raise ValueError( + f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}" + ) + try: + world_size = int(ngpu_str) + except ValueError as e: + raise ValueError( + f"NGPU environment variable must be a valid integer, got: {ngpu_str}" + ) from e + init_fake_mode(world_size, comm_config.comm_mode) + return world_size + def _warn_overwrite_env(env, val): if env in os.environ: logger.warning( @@ -309,6 +353,8 @@ def _get_distributed_backend(enable_cpu_backend): _ranks=ranks if ranks is not None else [], ) + return torch.distributed.get_world_size() + def set_pg_timeouts(timeout, world_mesh): """ diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..1914429398 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -360,15 +360,13 @@ def __init__(self, job_config: JobConfig): def init_distributed(self) -> ParallelDims: job_config = self.job_config - dist_utils.init_distributed( + world_size = dist_utils.init_distributed( job_config.comm, enable_cpu_backend=job_config.training.enable_cpu_offload, base_folder=job_config.job.dump_folder, ) - world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism - return ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, @@ -718,6 +716,13 @@ def main(trainer_class: type[Trainer]) -> None: try: trainer = trainer_class(config) + # 
TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights() and foreach_allgather. In local tensor mode, skip
+        # training/checkpointing as the model is not fully initialized.
+        if config.comm.comm_mode == "local_tensor":
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         if config.checkpoint.create_seed_checkpoint:
             assert (
                 int(os.environ["WORLD_SIZE"]) == 1
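
For reference, below is a minimal standalone sketch (not part of the patch) of what the fake_backend/local_tensor initialization path boils down to. The helper name init_simulated_world is hypothetical; the "fake" process group call, the NGPU validation, and the LocalTensorMode usage mirror init_fake_mode and init_distributed in this diff, and assume a PyTorch build that ships the experimental torch.distributed._local_tensor module.

import os

import torch.distributed as dist


def init_simulated_world(comm_mode: str) -> int:
    # NGPU is required, mirroring the validation added to init_distributed() above.
    ngpu_str = os.environ.get("NGPU")
    if ngpu_str is None:
        raise ValueError(f"NGPU must be set when using comm_mode={comm_mode}")
    world_size = int(ngpu_str)

    # Single process, no real communication: rank 0 of a fake process group.
    dist.init_process_group("fake", rank=0, world_size=world_size)

    if comm_mode == "local_tensor":
        # Experimental: LocalTensorMode replays each rank's computation
        # sequentially in one process, so numerics should match a real
        # multi-GPU run, just slower.
        from torch.distributed import _local_tensor

        _local_tensor.LocalTensorMode(world_size).__enter__()

    return world_size


# Example: COMM_MODE="local_tensor" NGPU=8 ./run_train.sh exports NGPU=8 and
# reaches this path with comm_mode="local_tensor".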