diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index c2f316b04b..d3cb20e7a3 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -43,6 +43,7 @@ install_pip_dependencies() { pip_install -r /opt/conda/requirements.txt pip_install -r /opt/conda/requirements-flux.txt pip_install -r /opt/conda/requirements-vlm.txt + pip_install -r /opt/conda/requirements-transformers-backend.txt popd } diff --git a/.ci/docker/requirements-transformers-backend.txt b/.ci/docker/requirements-transformers-backend.txt new file mode 100644 index 0000000000..76e8886ed0 --- /dev/null +++ b/.ci/docker/requirements-transformers-backend.txt @@ -0,0 +1 @@ +transformers==4.57.1 diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index baaca85824..b8123099b9 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -33,6 +33,7 @@ COPY requirements-dev.txt /opt/conda/ COPY requirements.txt /opt/conda/ COPY requirements-flux.txt /opt/conda/ COPY requirements-vlm.txt /opt/conda/ +COPY requirements-transformers-backend.txt /opt/conda/ COPY conda-env-ci.txt /opt/conda/ COPY ./common/install_conda.sh install_conda.sh COPY ./common/utils.sh utils.sh diff --git a/.github/workflows/integration_test_8gpu_transformers_backend.yaml b/.github/workflows/integration_test_8gpu_transformers_backend.yaml new file mode 100644 index 0000000000..aea5189d81 --- /dev/null +++ b/.github/workflows/integration_test_8gpu_transformers_backend.yaml @@ -0,0 +1,53 @@ +name: Transformers Backend 8 GPU Integration Tests + +on: + push: + branches: [ main ] + paths: + - 'torchtitan/experiments/transformers_backend/**' + pull_request: + paths: + - 'torchtitan/experiments/transformers_backend/**' + schedule: + # Runs every 12 hours + - cron: '0 */12 * * *' + +concurrency: + group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l -eo pipefail {0} + +jobs: + build-test: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + runner: linux.g5.48xlarge.nvidia.gpu + gpu-arch-type: cuda + gpu-arch-version: "12.6" + # This image is faster to clone than the default, but it lacks CC needed by triton + # (1m25s vs 2m37s). + docker-image: torchtitan-ubuntu-20.04-clang12 + repository: pytorch/torchtitan + upload-artifact: outputs + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Log CUDA driver version for debugging. 
+ DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true) + echo "CUDA driver version: ${DRIVER_VERSION}" + + pip config --user set global.progress_bar off + + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + + USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 + + mkdir artifacts-to-be-uploaded + python -m torchtitan.experiments.transformers_backend.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 14f8ba6544..08dc692bf9 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -31,3 +31,4 @@ We provide this `experiments/` folder to host experiments that add significant v | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | +| [transformers_backend](./transformers_backend/) | [![Transformers backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index f6f813bfae..db3a44a824 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -12,5 +12,6 @@ "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", + "transformers_backend", ] ) diff --git a/torchtitan/experiments/transformers_backend/README.md b/torchtitan/experiments/transformers_backend/README.md new file mode 100644 index 0000000000..805afb9ab9 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/README.md @@ -0,0 +1,52 @@ +# Huggingface Transformers backend + +## Quick start + +- Requirements `transformers==4.57.1` + +- Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml` +```diff +... +[model] +- name = "llama3" ++ name = "transformers_backend" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + ++[hf_transformers] ++model = "Qwen/Qwen3-4B-Instruct-2507" +... 
+```
+- Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable`
+  - Make sure you have created the tokenizers beforehand
+
+## Supported Features
+
+- The following models were tested:
+  - Dense (FSDP/CP/TP/PP/`torch.compile`)
+    - `meta-llama/Llama-3.2-1B`
+    - `microsoft/phi-2`
+    - `Qwen/Qwen2.5-7B`
+    - `mistralai/Mistral-7B-v0.1`
+    - `ByteDance-Seed/Seed-Coder-8B-Instruct`
+    - `Qwen/Qwen3-4B-Instruct-2507`
+    - `arcee-ai/AFM-4.5B`
+    - `ibm-granite/granite-3b-code-base-2k`
+    - `baidu/ERNIE-4.5-0.3B-Base-PT`
+    - `kyutai/helium-1-preview-2b`
+    - `allenai/OLMo-7B-hf`
+    - `mistralai/Ministral-8B-Instruct-2410`
+  - MoE (upcoming)
+
+## Known issues to address later
+
+- With the HF modeling code, the `FSDP=2` vs `FSDP=2 + PP=2` test produces `loss` and `grad_norm` values that are not bitwise identical (though both runs converge), whereas they do match with the TorchTitan modeling code. This will be addressed in another PR; the likely culprit is `register_buffer` when loading the `seed_checkpoint`.
+- The HF modeling code reaches a lower MFU than the TorchTitan modeling code.
+
+## Further work
+
+- Missing `build_optimizers_with_moe_load_balancing` support for MoE
+- Missing TP/PP/EP support for MoE
+- Load HF weights
+- Add LoRA support
diff --git a/torchtitan/experiments/transformers_backend/__init__.py b/torchtitan/experiments/transformers_backend/__init__.py
new file mode 100644
index 0000000000..aec28a0bdd
--- /dev/null
+++ b/torchtitan/experiments/transformers_backend/__init__.py
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
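+
+# This package registers the Hugging Face transformers backend as a torchtitan
+# TrainSpec: `flavors` maps flavor names ("debugmodel", "full") to
+# HFTransformerModelArgs, and `get_train_spec()` wires HFTransformerModel into
+# torchtitan's parallelize/pipeline/optimizer/dataloader builders.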
+from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize import parallelize_hf_transformers + +from .infra.pipeline import pipeline_hf_transformers +from .model.args import HFTransformerModelArgs, TitanDenseModelArgs +from .model.model import HFTransformerModel + +__all__ = [ + "HFTransformerModelArgs", + "HFTransformerModel", +] + + +flavors = { + "debugmodel": HFTransformerModelArgs( + titan_dense_args=TitanDenseModelArgs( + dim=256, + n_layers=2, + n_heads=16, + n_kv_heads=16, + ), + ), + "full": HFTransformerModelArgs( + titan_dense_args=TitanDenseModelArgs(), + ), +} + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=HFTransformerModel, + model_args=flavors, + parallelize_fn=parallelize_hf_transformers, + pipelining_fn=pipeline_hf_transformers, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/transformers_backend/configs/debug_model.toml b/torchtitan/experiments/transformers_backend/configs/debug_model.toml new file mode 100644 index 0000000000..7b3de04b87 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/configs/debug_model.toml @@ -0,0 +1,88 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Qwen 3 debug training" +print_config = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "transformers_backend" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) +dataset_path = "./tests/assets/c4_test" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac 
based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/transformers_backend/configs/full.toml b/torchtitan/experiments/transformers_backend/configs/full.toml new file mode 100644 index 0000000000..45eaa785de --- /dev/null +++ b/torchtitan/experiments/transformers_backend/configs/full.toml @@ -0,0 +1,87 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Qwen 3 full training" +print_config = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "transformers_backend" +flavor = "full" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/transformers_backend/infra/parallelize.py b/torchtitan/experiments/transformers_backend/infra/parallelize.py new file mode 100644 index 0000000000..b2ae3f02a1 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/infra/parallelize.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
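+
+# Parallelization entry point for the HF transformers backend. The ordering
+# below mirrors torchtitan's llama3 recipe: tensor parallelism is applied
+# first, then activation checkpointing, then per-TransformerBlock
+# torch.compile, and finally FSDP/HSDP (or DDP) wrapping.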
+ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) +from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.distributed import NoParallel, ParallelDims + +from torchtitan.distributed.activation_checkpoint import apply_ac + +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.experiments.transformers_backend.job_config import JobConfig +from torchtitan.models.llama3.infra.parallelize import apply_compile, apply_ddp +from torchtitan.tools.logging import logger + + +def parallelize_hf_transformers( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. 
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if model_compile_enabled: + apply_compile(model, job_config.compile) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + model.set_cp_mesh(world_mesh["cp"]) + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=model_compile_enabled, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + + # skipping nn.Identity modules (which are added by pipeline parallelism for unused modules) + root_plan = {} + + if hasattr(model, "tok_embeddings"): + if isinstance(model.tok_embeddings, nn.Identity): + root_plan["tok_embeddings"] = NoParallel() + else: + root_plan["tok_embeddings"] = RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ) + + if hasattr(model, "norm"): + if isinstance(model.norm, nn.Identity): + root_plan["norm"] = NoParallel() + else: + root_plan["norm"] = SequenceParallel() + + if hasattr(model, "output"): + if isinstance(model.output, nn.Identity): + root_plan["output"] = NoParallel() + else: + root_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ) + if root_plan: # Only call if there's something to parallelize + parallelize_module(model, tp_mesh, root_plan) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. 
+ if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers: + layer_plan = { + "input_layernorm": SequenceParallel(), + "self_attn": prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "post_attention_layernorm": SequenceParallel(), + } + + if getattr(transformer_block.self_attn, "q_lora_rank", None) is None: + layer_plan.update( + { + "self_attn.q_proj": colwise_parallel(), + "self_attn.k_proj": colwise_parallel(), + "self_attn.v_proj": colwise_parallel(), + } + ) + else: + layer_plan.update( + { + "self_attn.q_a_proj": NoParallel(), + "self_attn.q_a_layernorm": NoParallel(), + "self_attn.q_b_proj": colwise_parallel(), + "self_attn.kv_a_proj_with_mqa": NoParallel(), + "self_attn.kv_a_layernorm": NoParallel(), + "self_attn.kv_b_proj": colwise_parallel(), + } + ) + + # Handle different names for the output projection layer, e.g. o_proj vs dense + o_proj_name = ( + "o_proj" if hasattr(transformer_block.self_attn, "o_proj") else "dense" + ) + layer_plan[f"self_attn.{o_proj_name}"] = rowwise_parallel( + output_layouts=Shard(1) + ) + # For model that uses RMSNorm on Q and K (i.e. Qwen3) + if hasattr(transformer_block.self_attn, "q_norm") and hasattr( + transformer_block.self_attn, "k_norm" + ): + layer_plan["self_attn.q_norm"] = SequenceParallel( + sequence_dim=2, use_local_output=True + ) + layer_plan["self_attn.k_norm"] = SequenceParallel( + sequence_dim=2, use_local_output=True + ) + + if not transformer_block.moe_enabled: + mlp_plan = { + "mlp": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + } + # Handle different names for MLP layers, e.g. 
gate_proj vs fc1 + gate_proj_name = ( + "gate_proj" if hasattr(transformer_block.mlp, "gate_proj") else "fc1" + ) + mlp_plan[f"mlp.{gate_proj_name}"] = colwise_parallel() + + if hasattr(transformer_block.mlp, "up_proj"): + mlp_plan["mlp.up_proj"] = colwise_parallel() + + down_proj_name = ( + "down_proj" if hasattr(transformer_block.mlp, "down_proj") else "fc2" + ) + mlp_plan[f"mlp.{down_proj_name}"] = rowwise_parallel( + output_layouts=Shard(1) + ) + layer_plan.update(mlp_plan) + + # Some models like Phi-2 don't have post_attention_layernorm + if not hasattr(transformer_block, "post_attention_layernorm"): + layer_plan.pop("post_attention_layernorm") + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + ep_degree: int = 1, + dp_mod_ep_mesh: DeviceMesh | None = None, + gradient_divide_factor: int | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + for transformer_block in model.layers: + # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping + # - the router and the shared experts are sharded together with the TransformerBlock + # - the routed experts are sharded with the remaining dp_mod_ep_mesh + if ( + hasattr(transformer_block, "moe_enabled") + and transformer_block.moe_enabled + and ep_degree > 1 + ): + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + moe_block = transformer_block.mlp + # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). 
+ # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding + # causes inefficiency, so we choose to do FSDP sharding on dim-1. + # Even when EP is not used, we may still want to shard the experts + # on non-0 dim. For now it may not be worth the complexity to support + # shard_placement_fn on the outer TransformerBlock-level FSDP. + _experts_shard_placement_fn = None + assert dp_mod_ep_mesh is not None + if dp_mod_ep_mesh.size() * ep_degree > moe_block.experts.num_experts: + _experts_shard_placement_fn = lambda param: Shard(1) + + fully_shard( + moe_block.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + shard_placement_fn=_experts_shard_placement_fn, + ) + + # NOTE: # Although the FSDP sharding of experts is done on a mesh of + # a different size than other parameters, the gradient division + # factor should be consistent with data. + moe_block.experts.set_gradient_divide_factor( + gradient_divide_factor, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + + fully_shard(model, **fsdp_config) + + # NOTE: set up explicit prefetching when EP is enabled, as D2H syncs + # in EP could interfere with implicit prefetching in FSDP + if ep_degree == 1: + return + + # forward + transformer_blocks = list(model.layers.values()) + next_transformer_blocks = transformer_blocks[1:] + [None] + + if model.tok_embeddings is not None and model.layers is not None: + model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) + + for transformer_block, next_transformer_block in zip( + transformer_blocks, next_transformer_blocks + ): + if next_transformer_block is not None: + if next_transformer_block.moe_enabled: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block, next_transformer_block.mlp.experts] + ) + else: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block] + ) + elif model.norm is not None and model.output is not None: + transformer_block.set_modules_to_forward_prefetch( + [model.norm, model.output] + ) + + # backward + reversed_transformer_blocks = list(reversed(model.layers.values())) + prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + + if model.norm is not None and model.output is not None and model.layers is not None: + model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) + + for transformer_block, prev_transformer_block in zip( + reversed_transformer_blocks, prev_transformer_blocks + ): + if prev_transformer_block is not None: + if prev_transformer_block.moe_enabled: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block, prev_transformer_block.mlp.experts] + ) + else: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block] + ) + elif model.tok_embeddings is not None: + transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) diff --git a/torchtitan/experiments/transformers_backend/infra/pipeline.py b/torchtitan/experiments/transformers_backend/infra/pipeline.py new file mode 100644 index 0000000000..04452c5ede --- /dev/null +++ 
b/torchtitan/experiments/transformers_backend/infra/pipeline.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import copy +import math + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, +) + +from torchtitan.components.loss import LossFunction +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.pipeline_parallel import build_pipeline_schedule +from torchtitan.experiments.transformers_backend.job_config import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction +from torchtitan.tools.logging import logger + +# NOTE(3outeille): the only modifications comes from replacing None to nn.Identity and adding rotary_emb per model_part + + +def generate_llm_fqn_per_model_part( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names model part, focused on LLMs models. + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (embed_tokens) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + Returns: + List of lists containing module names for each model part + Example: + generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output", "rotary_emb"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Feasibility check: Ensure at least 1 layer in each PP stage + if layers_per_stage == 0: + raise ValueError( + f"Configuration would result in empty stages. " + f"With {num_stages} stages and {num_effective_layers} effective layers " + f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), " + f"each stage would get {layers_per_stage} layers on average. " + f"Reduce num_stages or increase num_layers/weights." + ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." 
+ ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + stage_modules.append("rotary_emb") + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. 
Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_degree = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with Identity + setattr(model, module_name, nn.Identity()) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = ( + "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + ) + + def _get_stage_indices() -> tuple[int]: + """ + Compute the stage ids for the stages that will run on this pp rank + for either a looped or V style schedule + """ + assert ( + num_stages % pp_degree == 0 + ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}" + stages_per_rank = num_stages // pp_degree + if style == "loop": + return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank)) + elif style == "v": + assert ( + stages_per_rank == 2 + ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" + stage_v_pairs = list( + zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) + ) + return stage_v_pairs[pp_rank] + + for stage_idx in _get_stage_indices(): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models + + 
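+# Worked example of the default module split used below when
+# `module_fqns_per_model_part` is not set: with input_weight=1 and
+# output_weight=1, generate_llm_fqn_per_model_part(num_stages=4, num_layers=6)
+# has 8 effective layers, i.e. 2 per stage, and every stage also receives
+# "rotary_emb":
+#   [["tok_embeddings", "layers.0", "rotary_emb"],
+#    ["layers.1", "layers.2", "rotary_emb"],
+#    ["layers.3", "layers.4", "rotary_emb"],
+#    ["layers.5", "norm", "output", "rotary_emb"]]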
+def pipeline_hf_transformers( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: torch.device, + model_args: BaseModelArgs, + parallelize_fn: ParallelizeFunction, + loss_fn: LossFunction, +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + pp_mesh = parallel_dims.world_mesh["pp"] + + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage + if hasattr(model_args, "n_layers"): + num_layers = model_args.n_layers + else: + raise ValueError("Model does not have n_layers attribute.") + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers + output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers + + # Calculate number of virtual stages + if layers_per_stage is not None: + + # Calculate number of virtual stages needed (using ceiling division) + # This allows for unequal distribution where stages can differ by at most 1 layer + num_virtual_stages = math.ceil( + (num_layers + input_weight + output_weight) / layers_per_stage + ) + + # Validation: check stages per rank based on schedule type + model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}" + stage_distribution_info = ( + f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks" + ) + + if num_virtual_stages % parallel_dims.pp != 0: + raise ValueError( + f"Number of virtual stages ({num_virtual_stages}) must be divisible by " + f"pipeline parallel size ({parallel_dims.pp}). " + f"{model_config_info}. " + f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages " + f"divisible by {parallel_dims.pp}." + ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." + ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." 
+ ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, + ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for i, m in enumerate(model_parts): + # apply SPMD-style PT-D techniques + m = parallelize_fn(m, parallel_dims, job_config) + model_parts[i] = m + # NOTE: this is to update the model in the stage + # in case the model is modified e.g. by torch.compile + stages[i].submod = m + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, model_parts, has_first_stage, has_last_stage diff --git a/torchtitan/experiments/transformers_backend/job_config.py b/torchtitan/experiments/transformers_backend/job_config.py new file mode 100644 index 0000000000..f3b1667798 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/job_config.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class HFTransformers: + model: str = "" + """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')""" + + +@dataclass +class JobConfig: + hf_transformers: HFTransformers = field(default_factory=HFTransformers) diff --git a/torchtitan/experiments/transformers_backend/model/args.py b/torchtitan/experiments/transformers_backend/model/args.py new file mode 100644 index 0000000000..25ab328f15 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/model/args.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
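+
+# HFTransformerModelArgs bridges torchtitan-style argument names (dim,
+# n_layers, n_heads, ...) and Hugging Face PretrainedConfig names
+# (hidden_size, num_hidden_layers, num_attention_heads, ...): a property is
+# generated for each mapped name, so both spellings read and write the same
+# underlying HF field.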
+ +from dataclasses import dataclass + +from torch import nn +from torchtitan.config.job_config import JobConfig +from torchtitan.models.utils import get_dense_model_nparams_and_flops +from torchtitan.protocols import BaseModelArgs +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.integrations.sdpa_attention import sdpa_attention_forward +from transformers.modeling_utils import AttentionInterface + + +@dataclass +class TitanDenseModelArgs: + """Arguments for the base TorchTitan model.""" + + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int | None = None + multiple_of: int = 256 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + max_seq_len: int = 2048 + depth_init: bool = True + use_flex_attn: bool = False + attn_mask_type: str = "causal" + + +@dataclass +class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs): + """ + Configuration class that bridges TorchTitan and HuggingFace Transformers naming conventions. + + Uses properties to provide TorchTitan-style access while maintaining HuggingFace compatibility. + Properties are created dynamically based on which arguments are provided. + """ + + # Define all possible mappings organized by argument type + _TT_TO_HF_MAPPINGS = { + "dense": { + # TorchTitan dense model mappings (always available) + "dim": "hidden_size", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "n_kv_heads": "num_key_value_heads", + "norm_eps": "rms_norm_eps", + "max_seq_len": "max_position_embeddings", + "eos_id": "eos_token_id", + } + } + + # Declarative list of TorchTitan-only attributes (no HF equivalent) + _TT_SPECIFIC_ATTRIBUTES = [ + "multiple_of", + "ffn_dim_multiplier", + "depth_init", + "use_flex_attn", + "attn_mask_type", + ] + + def __init__( + self, + titan_dense_args, + # HuggingFace specific args + attn_implementation: str = "sdpa_torchtitan", + **kwargs, + ): + super().__init__(attn_implementation=attn_implementation, **kwargs) + assert titan_dense_args is not None, "titan_dense_args is required" + + # Create getter/setter dynamically for TT <-> HF attribute mappings + self._create_getter_setter_dynamically(has_moe=False) + + self._titan_injected_model_args = {} + self._configure_hf_attention(attn_implementation) + + self._initialize_dense_attributes(titan_dense_args) + + def _initialize_dense_attributes(self, titan_dense_args): + """Initialize all dense model attributes.""" + # Set mapped attributes (TorchTitan <-> HuggingFace) + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + if hasattr(titan_dense_args, titan_name): + value = getattr(titan_dense_args, titan_name) + setattr(self, hf_name, value) + + # Set TorchTitan-only attributes + for attr_name in self._TT_SPECIFIC_ATTRIBUTES: + if hasattr(titan_dense_args, attr_name): + setattr(self, attr_name, getattr(titan_dense_args, attr_name)) + + # Update passed_args + self._titan_injected_model_args.update(titan_dense_args.__dict__) + + def _configure_hf_attention(self, attn_implementation: str): + """Configure HuggingFace attention settings.""" + self._titan_injected_model_args["attn_implementation"] = attn_implementation + self.attn_implementation = attn_implementation + # NOTE:(3outeille):This will force create_causal_mask to return None + AttentionInterface._global_mapping[attn_implementation] = sdpa_attention_forward + + def _create_getter_setter_dynamically(self, has_moe: bool): + """ + 
Create properties dynamically based on tt and hf attribute mappings. + For example, creates a property 'dim' that reads/writes to 'hidden_size'. + """ + + def _create_property(hf_name: str) -> property: + def getter(self): + return getattr(self, hf_name) + + def setter(self, value): + setattr(self, hf_name, value) + + return property(getter, setter) + + # Setup attribute mappings + self._tt_to_hf_attribute_map = dict(self._TT_TO_HF_MAPPINGS["dense"]) + if has_moe: + self._tt_to_hf_attribute_map.update(self._TT_TO_HF_MAPPINGS["moe"]) + + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + # Create getter/setter for attribute that don't already exist + if not hasattr(self.__class__, titan_name): + setattr(self.__class__, titan_name, _create_property(hf_name)) + + def __repr__(self) -> str: + # HFTransformerModelArgs is a dataclass that also inherits from PretrainedConfig. + # PretrainedConfig has a __repr__ that serializes the object to JSON, but it + # doesn't work well with how HFTransformerModelArgs is initialized. + # This custom __repr__ provides a dataclass-like representation that correctly + # displays the arguments passed during initialization. + args_lines = [ + f"{k}={getattr(self, k)!r}" + for k in sorted(self._titan_injected_model_args.keys()) + if hasattr(self, k) + ] + args_str = "\n".join(args_lines) + return f"{self.__class__.__name__}(\n{args_str}\n)" + + def update_from_config(self, job_config: JobConfig): + # Load HF config (overwrites our HF attributes) + hf_model_config = AutoConfig.from_pretrained( + job_config.hf_transformers.model, + attn_implementation=self.attn_implementation, + trust_remote_code=True, + ) + + # Explicitly update attributes based on mappings + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + if hasattr(hf_model_config, hf_name): + setattr(self, titan_name, getattr(hf_model_config, hf_name)) + + # Copy any other attributes that might not be in the mapping + for key, value in hf_model_config.to_dict().items(): + setattr(self, key, value) + + # Update our attributes with the passed args from flavors + for key, value in self._titan_injected_model_args.items(): + if hasattr(self, key) and value is not None: + setattr(self, key, value) + + self.max_seq_len = job_config.training.seq_len + + self.deterministic = job_config.debug.deterministic + + # Configure HF-specific settings to match TorchTitan settings + # TODO: false ? + self.attention_bias = False + self.mlp_bias = False + self.use_cache = False + self.initializer_range = 1.0 # use as std for normal init in embedding + + if not hasattr(self, "inter_dim"): # Only for llama model + ffn_hidden_size = 4 * self.dim + ffn_hidden_size = int(2 * ffn_hidden_size / 3) + if self.ffn_dim_multiplier is not None: + ffn_hidden_size = int(self.ffn_dim_multiplier * ffn_hidden_size) + self.intermediate_size = self.multiple_of * ( + (ffn_hidden_size + self.multiple_of - 1) // self.multiple_of + ) + + self.head_dim = self.dim // self.num_attention_heads + + return self + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + return get_dense_model_nparams_and_flops( + self, model, head_dims=self.head_dim, seq_len=seq_len + ) diff --git a/torchtitan/experiments/transformers_backend/model/model.py b/torchtitan/experiments/transformers_backend/model/model.py new file mode 100644 index 0000000000..b88fffc54b --- /dev/null +++ b/torchtitan/experiments/transformers_backend/model/model.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import math + +import torch +from torch import nn +from torch.nn import init +from torchtitan.tools.logging import logger +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel + +from .args import HFTransformerModelArgs + + +class SliceableModuleDict(nn.ModuleDict): + """ + A ModuleDict that supports slicing like ModuleList. + Keys are expected to be string representations of integers (e.g., "0", "1", "2"). + """ + + def __getitem__(self, key): + if isinstance(key, slice): + # Handle slicing: convert slice to list of keys + keys = sorted( + self.keys(), key=lambda x: int(x) if x.isdigit() else float("inf") + ) + sliced_keys = keys[key] + # Return a new SliceableModuleDict with the sliced items + return SliceableModuleDict({k: self[k] for k in sliced_keys}) + return super().__getitem__(key) + + def __iter__(self): + # Iterate over values in sorted order by key (as integers) + keys = sorted( + self.keys(), key=lambda x: int(x) if x.isdigit() else float("inf") + ) + for key in keys: + yield self[key] + + def __len__(self): + return len(self._modules) + + +class HFTransformerModel(nn.Module): + def __init__(self, model_args: HFTransformerModelArgs): + super().__init__() + + # NOTE(3outeille): This prevents Hugging Face modeling from initializing ROPE (inv_freq) buffers to NaN. + # Needed when loading from seed checkpoint. + if hasattr(model_args, "deterministic") and model_args.deterministic: + torch.utils.deterministic.fill_uninitialized_memory = False + + # Try to import the model class dynamically from the transformers library if not found in globals + model_class_name = model_args.architectures[0] + model_cls = globals().get(model_class_name, None) + if model_cls is None: + try: + transformers_mod = importlib.import_module("transformers") + model_cls = getattr(transformers_mod, model_class_name) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Could not find model class '{model_class_name}' in globals or transformers. " + f"Make sure the class is available. Original error: {e}" + ) from e + + # Attempt to patch model weight initialization based on architecture type + try: + model_name_prefix = model_class_name.replace("ForCausalLM", "") + model_module = importlib.import_module(model_cls.__module__) + + attention_cls = getattr(model_module, f"{model_name_prefix}Attention", None) + mlp_cls = getattr(model_module, f"{model_name_prefix}MLP", None) + decoder_layer_cls = getattr( + model_module, f"{model_name_prefix}DecoderLayer", None + ) + + required_classes = { + "Attention": attention_cls, + "DecoderLayer": decoder_layer_cls, + } + + if all(required_classes.values()): + logger.info(f"Applying Llama-like patch for {model_name_prefix}") + self._patch_hf_llama_like( + decoder_layer_cls=decoder_layer_cls, + attention_cls=attention_cls, + mlp_cls=mlp_cls, # mlp_cls can be None + ) + else: + missing = [name for name, cls in required_classes.items() if not cls] + logger.warning( + f"Could not find required classes ({', '.join(missing)}) for {model_name_prefix}. " + "Skipping Llama-like patch." + ) + + except Exception as e: + logger.warning( + f"Failed to apply agnostic patch for {model_class_name} due to: {e}. " + "Weight initialization might not match TorchTitan." 
+ ) + + self.model = model_cls(config=model_args) + self.max_seq_len = model_args.max_seq_len + self.cp_mesh = None + + # Convert ModuleList to ModuleDict to preserve original indices + # This ensures state dict keys match checkpoint keys + if isinstance(self.model.model.layers, nn.ModuleList): + self.model.model.layers = SliceableModuleDict( + {str(i): layer for i, layer in enumerate(self.model.model.layers)} + ) + + for layer in self.model.model.layers.values(): + layer.moe_enabled = False + + def set_cp_mesh(self, mesh): + self.cp_mesh = mesh + + def _patch_hf_llama_like(self, decoder_layer_cls, attention_cls, mlp_cls=None): + """ + This patch modifies a Hugging Face Llama-like model's weight initialization to match + the initialization scheme used in TorchTitan. This is crucial for ensuring + bit-for-bit reproducibility when converting checkpoints between the native + TorchTitan format and the Hugging Face format. + + The patch targets the following aspects of the model: + - `PreTrainedModel._initialize_weights`: Handles meta device initialization correctly. + - `PreTrainedModel._init_weights`: Implements TorchTitan's specific initialization + for attention, MLP, embedding, and layer norm layers. This includes depth-dependent + initialization for attention and MLP layers. + - `DecoderLayer.__init__`: Adds `layer_idx` to attention and MLP modules within + each decoder layer, which is required for the depth-dependent initialization. + """ + + _original_decoder_layer_init = decoder_layer_cls.__init__ + + def _decoder_layer_init_patched(self, config: PretrainedConfig, layer_idx: int): + _original_decoder_layer_init(self, config, layer_idx) + self.layer_idx = layer_idx + # Ensure both attention and mlp modules have layer_idx for depth-based init + if hasattr(self, "self_attn"): + self.self_attn.layer_idx = layer_idx + # some models might not have mlp in each layer + if hasattr(self, "mlp") and self.mlp is not None: + self.mlp.layer_idx = layer_idx + + def _initialize_weights_patched(self, module): + # NOTE(3outeille): monkey-patch PreTrainedModel to handle meta device initialization correctly + # The default _initialize_weights sets _is_hf_initialized = True even on a meta device, + # which prevents subsequent proper initialization. + if getattr(module, "_is_hf_initialized", False): + return + + for param in module.parameters(recurse=True): + if param.device.type == "meta": + return + + # If not on a meta device, call the original weight initialization + self._init_weights(module) + module._is_hf_initialized = True + + def _init_weights_patched(self, module): + """ + Patched version of _init_weights to match TorchTitan's initialization for Llama-like models. + `self` is a PreTrainedModel instance. 
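+
+            The per-layer std mirrors torchtitan: with `config.depth_init`,
+            init_std = 0.02 / sqrt(2 * (layer_idx + 1)); otherwise
+            init_std = 0.02 / sqrt(2 * num_hidden_layers). q/k/v and gate
+            projections use std=0.02, while the attention output projection
+            and the MLP up/down projections use init_std.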
+ """ + config = self.config + # Build tuple of classes to check for layer_idx-based init_std calculation + layer_idx_classes = [attention_cls] + if mlp_cls: + layer_idx_classes.append(mlp_cls) + layer_idx_classes = tuple(layer_idx_classes) + + if isinstance(module, layer_idx_classes): + if not hasattr(module, "layer_idx"): + raise ValueError( + f"Module {module} does not have a layer_idx attribute" + ) + + layer_idx = module.layer_idx + + if hasattr(config, "depth_init") and config.depth_init: + init_std = 0.02 / (2 * (layer_idx + 1)) ** 0.5 + else: + init_std = 0.02 / (2 * config.num_hidden_layers) ** 0.5 + + if isinstance(module, attention_cls): + # Initialize weights and biases for q, k, v projections + for proj_name in ["q_proj", "k_proj", "v_proj"]: + proj = getattr(module, proj_name) + nn.init.trunc_normal_(proj.weight, mean=0.0, std=0.02) + if proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(proj.bias, -bound, bound) + + # Handle different names for the output projection layer + o_proj = getattr(module, "o_proj", getattr(module, "dense", None)) + if o_proj is not None: + nn.init.trunc_normal_(o_proj.weight, mean=0.0, std=init_std) + if o_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(o_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(o_proj.bias, -bound, bound) + + elif mlp_cls and isinstance(module, mlp_cls): + # Handle different names for MLP layers + gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None)) + up_proj = getattr(module, "up_proj", None) + down_proj = getattr(module, "down_proj", getattr(module, "fc2", None)) + + # gate_proj (or fc1) should always use std=0.02 for numerical stability. + if gate_proj is not None: + nn.init.trunc_normal_(gate_proj.weight, mean=0.0, std=0.02) + if gate_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(gate_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(gate_proj.bias, -bound, bound) + # up_proj and down_proj (or fc2) use the depth-dependent init_std. 
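+                # (gate_proj/up_proj/down_proj are the Llama-style names; fc1/fc2 are the Phi-style fallbacks.)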
+ if up_proj is not None: + nn.init.trunc_normal_(up_proj.weight, mean=0.0, std=init_std) + if up_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(up_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(up_proj.bias, -bound, bound) + if down_proj is not None: + nn.init.trunc_normal_(down_proj.weight, mean=0.0, std=init_std) + if down_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(down_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(down_proj.bias, -bound, bound) + + elif module is getattr( + self, "lm_head", None + ): # TODO(3outeille): find a better way to detect lm_head + final_out_std = config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + # When tie_word_embeddings is True, use lm_head initialization + if ( + hasattr(config, "tie_word_embeddings") + and config.tie_word_embeddings + ): + final_out_std = config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + else: + std = config.initializer_range + module.weight.data.normal_(mean=0.0, std=std) + + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + elif ( + isinstance( + module, + (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d), + ) + or "LayerNorm" in module.__class__.__name__ + or "RMSNorm" in module.__class__.__name__ + ): + # Norms can exist without weights (in which case they are None from torch primitives) + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + decoder_layer_cls.__init__ = _decoder_layer_init_patched + PreTrainedModel._init_weights = _init_weights_patched + PreTrainedModel._initialize_weights = _initialize_weights_patched + + @property + def tok_embeddings(self): + """Returns the model's embed_tokens, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "embed_tokens" + ): # Llama-like + return self.model.model.embed_tokens + else: + raise AttributeError( + "Could not find embed_tokens in the model. Please check the model structure." + ) + + @tok_embeddings.setter + def tok_embeddings(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "embed_tokens" + ): # Llama-like + self.model.model.embed_tokens = value + else: + raise AttributeError( + "Could not find embed_tokens in the model. Please check the model structure." + ) + + @property + def layers(self): + """Returns the model's layers, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "layers" + ): # Llama-like + return self.model.model.layers + else: + # Add more cases here if needed for other model architectures + raise AttributeError( + "Could not find layers in the model. Please check the model structure." 
+ ) + + @layers.setter + def layers(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "layers" + ): # Llama-like + self.model.model.layers = value + else: + raise AttributeError( + "Could not find layers in the model. Please check the model structure." + ) + + @property + def norm(self): + """Returns the model's norm, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "norm" + ): # Llama-like + return self.model.model.norm + elif hasattr(self.model, "model") and hasattr( + self.model.model, "final_layernorm" + ): # Phi-like + return self.model.model.final_layernorm + else: + raise AttributeError( + "Could not find norm in the model. Please check the model structure." + ) + + @norm.setter + def norm(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "norm" + ): # Llama-like + self.model.model.norm = value + elif hasattr(self.model, "model") and hasattr( + self.model.model, "final_layernorm" + ): # Phi-like + self.model.model.final_layernorm = value + else: + raise AttributeError( + "Could not find norm in the model. Please check the model structure." + ) + + @property + def output(self): + """Returns the model's output layer, handling different Hugging Face model structures.""" + if hasattr(self.model, "lm_head"): # For models like LlamaForCausalLM + return self.model.lm_head + else: + # Add more cases here if needed for other model architectures + raise AttributeError( + "Could not find output (lm_head) in the model. Please check the model structure." + ) + + @output.setter + def output(self, value): + if hasattr(self.model, "lm_head"): # For models like LlamaForCausalLM + self.model.lm_head = value + else: + raise AttributeError( + "Could not find output (lm_head) in the model. Please check the model structure." + ) + + @property + def rotary_emb(self): + """Returns the model's rotary_emb, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "rotary_emb" + ): # Llama-like + return self.model.model.rotary_emb + else: + raise AttributeError( + "Could not find rotary_emb in the model. Please check the model structure." + ) + + @rotary_emb.setter + def rotary_emb(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "rotary_emb" + ): # Llama-like + self.model.model.rotary_emb = value + else: + raise AttributeError( + "Could not find rotary_emb in the model. Please check the model structure." 
+ ) + + def forward(self, *args, **kwargs): + local_seq_len = self.max_seq_len + local_seq_len //= ( + self.cp_mesh.size() + if self.cp_mesh is not None and self.cp_mesh.size() > 1 + else 1 + ) + kwargs["position_ids"] = torch.arange( + local_seq_len, device=args[0].device + ).unsqueeze(0) + output = self.model.model(*args, **kwargs) + output = self.model.lm_head(output.last_hidden_state) + return output + + def init_weights(self, *args, **kwargs): + # This method replicates the behavior of the original PreTrainedModel.init_weights, + # but with a custom weight initialization function that skips nn.Identity modules (when PP is enabled) + + if self.model.config.pruned_heads: + logger.info("Pruning heads as per model configuration.") + self.model.prune_heads(self.model.config.pruned_heads) + + original_init_weights_fn = self.model._init_weights + + def selective_init(module): + # For pipeline parallel, we need to skip nn.Identity modules + if not isinstance(module, nn.Identity): + original_init_weights_fn(module) + else: + logger.info("Skipping nn.Identity module during weight initialization.") + + self.model.apply(selective_init) + + # TODO(3outeille): For pipeline parallel, only tie weights if both input and output embeddings are on the same device + # Maybe better way of handling this? + if not isinstance(self.tok_embeddings, nn.Identity) and not isinstance( + self.output, nn.Identity + ): + self.model.tie_weights() + + def named_children(self): + """ + Provides a flattened view of the model's main components, + making it compatible with TorchTitan's expectations. + """ + yield "tok_embeddings", self.tok_embeddings + yield "layers", self.layers + yield "norm", self.norm + yield "output", self.output + yield "rotary_emb", self.rotary_emb + + def __setattr__(self, name, value): + # If a property with a setter exists for this name, use it. + # This is to bypass the nn.Module.__setattr__ logic that + # directly registers modules and skips property setters. + cls = self.__class__ + if hasattr(cls, name): + prop = getattr(cls, name) + if isinstance(prop, property) and prop.fset is not None: + prop.fset(self, value) + return + + # Otherwise, fall back to the default nn.Module behavior. + super().__setattr__(name, value) diff --git a/torchtitan/experiments/transformers_backend/tests/integration_tests.py b/torchtitan/experiments/transformers_backend/tests/integration_tests.py new file mode 100644 index 0000000000..35d09d6a94 --- /dev/null +++ b/torchtitan/experiments/transformers_backend/tests/integration_tests.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +def build_transformers_backend_test_list() -> list[OverrideDefinitions]: + """ + key is the config file name and value is a list of OverrideDefinitions + that is used to generate variations of integration tests based on the + same root config file. 
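+
+    Here the suite is a flat list; each OverrideDefinitions entry bundles the CLI
+    override flags, a human-readable description, a short test name, and the
+    number of GPUs required (ngpu).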
+ """ + integration_tests_flavors = [ + OverrideDefinitions( + [ + [ + "--model.name transformers_backend", + "--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config", + "--hf_transformers.model Qwen/Qwen2.5-7B", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule 1F1B", + ], + ], + "Transformers Backend FSDP+TP+PP", + "transformers_backend_fsdp+tp+pp", + ngpu=8, + ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "transformers_backend": build_transformers_backend_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests. This is the config that will be used as a base for all tests.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["transformers_backend"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..d157a3a307 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -474,11 +474,17 @@ def forward_backward_step( ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage + cp_buffers = [inputs, labels] + cp_seq_dims = [1, 1] + if hasattr(model_parts[0], "freqs_cis"): + cp_buffers += [m.freqs_cis for m in model_parts] + cp_seq_dims += [0 for _ in model_parts] + optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], + cp_buffers=cp_buffers, + cp_seq_dims=cp_seq_dims, cp_no_restore_buffers={inputs, labels}, cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, )
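The `ModuleList` → `SliceableModuleDict` conversion in `HFTransformerModel.__init__` is what keeps Hugging Face checkpoint keys (`layers.0.*`, `layers.1.*`, ...) stable while still allowing list-style slicing, which stage splitting for pipeline parallelism presumably relies on. A minimal, self-contained sketch of that behavior (not part of the patch), reusing the class from the diff in condensed form with a hypothetical toy container of `nn.Linear` layers:

```python
import torch.nn as nn


class SliceableModuleDict(nn.ModuleDict):
    """ModuleDict keyed by "0", "1", ... that also supports ModuleList-style slicing."""

    def __getitem__(self, key):
        if isinstance(key, slice):
            keys = sorted(self.keys(), key=lambda k: int(k) if k.isdigit() else float("inf"))
            return SliceableModuleDict({k: self[k] for k in keys[key]})
        return super().__getitem__(key)

    def __iter__(self):
        keys = sorted(self.keys(), key=lambda k: int(k) if k.isdigit() else float("inf"))
        for key in keys:
            yield self[key]

    def __len__(self):
        return len(self._modules)


# Hypothetical 4-layer stand-in for `model.model.layers`.
layers = SliceableModuleDict({str(i): nn.Linear(8, 8) for i in range(4)})

# Slicing returns a new SliceableModuleDict and keeps the original indices as keys,
# so a stage holding the last two layers still emits "2.weight", "3.weight", ...
second_half = layers[2:]
print(list(second_half.keys()))                 # ['2', '3']
print(sorted(second_half.state_dict().keys()))  # ['2.bias', '2.weight', '3.bias', '3.weight']

# Iteration yields the modules themselves in integer-key order, like a ModuleList.
print(sum(1 for _ in layers))                   # 4
```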