2 changes: 1 addition & 1 deletion .ci/docker/requirements-transformers-backend.txt
@@ -1 +1 @@
transformers==4.55.4
transformers==4.57.1
11 changes: 5 additions & 6 deletions torchtitan/experiments/transformers_backend/README.md
@@ -2,22 +2,21 @@

## Quick start

- Requirements `transformers==4.55.4`
- Requirements `transformers==4.57.1`

- Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml`
```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
+ name = "Qwen/Qwen3-4B-Instruct-2507"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```
- Train: `LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable`
**Note:** Any model name containing "/" is automatically recognized as a HuggingFace model ID and will use the `transformers_backend`; a minimal sketch of this heuristic is shown below.

- Train: `LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml ./run_train.sh --compile.enable`
- Make sure you have created the tokenizers beforehand
<img width="1334" height="453" alt="image" src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c" />
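
The "/" heuristic mentioned in the note above is not spelled out in this diff. Below is a minimal sketch of how such a check could route a model name to the HuggingFace-backed spec; the function name and its placement are assumptions for illustration only.

```python
# Hypothetical helper illustrating the "/" heuristic described in the note above;
# the real dispatch lives elsewhere in torchtitan and is not part of this diff.
def resolve_train_spec_name(model_name: str) -> str:
    # Names like "Qwen/Qwen3-4B-Instruct-2507" look like HuggingFace repo IDs.
    if "/" in model_name:
        return "transformers_backend"
    # Otherwise treat it as a built-in TorchTitan model, e.g. "llama3".
    return model_name


assert resolve_train_spec_name("Qwen/Qwen3-4B-Instruct-2507") == "transformers_backend"
assert resolve_train_spec_name("llama3") == "llama3"
```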

102 changes: 99 additions & 3 deletions torchtitan/experiments/transformers_backend/__init__.py
@@ -5,11 +5,22 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.components.ft import FTManager
from torchtitan.models.moe import MoEArgs
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.optimizer import (
build_optimizers,
build_optimizers_with_moe_load_balancing,
OptimizersContainer,
)
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.distributed import ParallelDims
from torchtitan.hf_datasets.text_datasets import build_text_dataloader
from torchtitan.models.moe import MoEArgs
from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_hf_transformers
@@ -24,7 +35,6 @@
"HFTransformerModel",
]


@dataclass
class TitanDenseModelArgs:
"""Arguments for the base TorchTitan model."""
@@ -43,6 +53,30 @@ class TitanDenseModelArgs:
use_flex_attn: bool = False
attn_mask_type: str = "causal"

@dataclass
class TitanMoeModelArgs:
"""Arguments specific to DeepSeekV3 models."""

moe_args: MoEArgs | None = None
n_group: int | None = None
topk_group: int | None = None
inter_dim: int | None = None
moe_inter_dim: int | None = None
n_dense_layers: int | None = None
n_expert_groups: int | None = None
n_limited_groups: int | None = None
q_lora_rank: int | None = None
kv_lora_rank: int | None = None
qk_nope_head_dim: int | None = None
qk_rope_head_dim: int | None = None
v_head_dim: int | None = None
original_seq_len: int | None = None
rope_factor: float | None = None
beta_fast: int | None = None
beta_slow: int | None = None
mscale: float | None = None
partial_rotary_factor: float | None = None
rope_interleave: bool = True

flavors = {
"debugmodel": HFTransformerModelArgs(
@@ -53,19 +87,81 @@ class TitanDenseModelArgs:
n_kv_heads=16,
),
),
"debugmodel_moe": HFTransformerModelArgs(
titan_dense_args=TitanDenseModelArgs(
dim=256,
n_layers=3,
n_heads=16,
n_kv_heads=16,
),
titan_moe_args=TitanMoeModelArgs(
partial_rotary_factor=4.0,
inter_dim=1024,
moe_inter_dim=256,
n_dense_layers=1,
n_group=2,
topk_group=1,
kv_lora_rank=512,
q_lora_rank=0,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
moe_args=MoEArgs(
num_experts=8,
num_shared_experts=2,
top_k=3,
score_func="softmax",
route_norm=True,
score_before_experts=False,
load_balance_coeff=1e-3,
),
),
),
"full": HFTransformerModelArgs(
titan_dense_args=TitanDenseModelArgs(),
),
}

def build_optimizers_auto_detect_moe(
model_parts: list[nn.Module],
optimizer_config: OptimizerConfig,
parallel_dims: ParallelDims,
ft_manager: FTManager | None = None,
) -> OptimizersContainer:

# Check if any model part has MoE enabled
has_moe = False
for model_part in model_parts:
if hasattr(model_part, "layers"):
for layer in model_part.layers:
if hasattr(layer, "moe_enabled") and layer.moe_enabled:
has_moe = True
break
if has_moe:
break

if has_moe:
# NOTE(3outeille): temporary monkey-patch for compatibility. The MoE load-balancing
# optimizer iterates `model.layers.values()` (dict-style), while HF models expose
# `layers` as an nn.ModuleList, so we attach a `values()` shim instead of forking optimizer.py.
for model_part in model_parts:
if hasattr(model_part, "layers") and not hasattr(model_part.layers, "values"):
model_part.layers.values = lambda self=model_part.layers: iter(self)

builder = build_optimizers_with_moe_load_balancing if has_moe else build_optimizers
return builder(
model_parts=model_parts,
optimizer_config=optimizer_config,
parallel_dims=parallel_dims,
ft_manager=ft_manager,
)

def get_train_spec() -> TrainSpec:
return TrainSpec(
model_cls=HFTransformerModel,
model_args=flavors,
parallelize_fn=parallelize_hf_transformers,
pipelining_fn=pipeline_hf_transformers,
build_optimizers_fn=build_optimizers,
build_optimizers_fn=build_optimizers_auto_detect_moe,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_text_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
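For context on the `values()` monkey-patch in `build_optimizers_auto_detect_moe` above: `nn.ModuleDict` exposes `.values()`, but `nn.ModuleList` (which HF models use for `layers`) does not, so the patch bolts a dict-style iterator onto the list. A standalone sketch of the same trick, illustrative only and not part of this PR:

```python
import torch.nn as nn

# nn.ModuleList has no .values(); downstream code written for dict-like
# containers (e.g. iterating model.layers.values()) would fail on it.
layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

# Same trick as the monkey-patch above: attach a values() that just iterates the list.
layers.values = lambda self=layers: iter(self)

assert list(layers.values()) == list(layers)  # dict-style access now works
```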
@@ -20,15 +20,12 @@ save_tb_folder = "tb"
enable_wandb = false

[model]
name = "transformers_backend"
name = "Qwen/Qwen3-4B-Instruct-2507"
flavor = "debugmodel"
# test folder with tokenizer.json, for debugging purposes only
hf_assets_path = "./tests/assets/tokenizer"
# converters = ["float8"]

[hf_transformers]
model = "Qwen/Qwen3-4B-Instruct-2507"

[optimizer]
name = "AdamW"
lr = 8e-4
@@ -17,12 +17,12 @@
SequenceParallel,
)
from torchtitan.config import TORCH_DTYPE_MAP
from torchtitan.config.job_config import JobConfig
from torchtitan.distributed import NoParallel, ParallelDims

from torchtitan.distributed.activation_checkpoint import apply_ac

from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.experiments.transformers_backend.job_config import JobConfig
from torchtitan.models.llama3.infra.parallelize import apply_compile, apply_ddp
from torchtitan.tools.logging import logger

@@ -14,13 +14,13 @@
)

from torchtitan.components.loss import LossFunction
from torchtitan.config.job_config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline_parallel import (
build_pipeline_schedule,
generate_llm_fqn_per_model_part,
pipeline_module_split,
)
from torchtitan.experiments.transformers_backend.job_config import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
from torchtitan.tools.logging import logger

18 changes: 0 additions & 18 deletions torchtitan/experiments/transformers_backend/job_config.py

This file was deleted.

82 changes: 77 additions & 5 deletions torchtitan/experiments/transformers_backend/model/args.py
@@ -7,8 +7,11 @@
from dataclasses import dataclass

from torch import nn
from torchtitan.config.job_config import JobConfig
from torchtitan.models.utils import get_dense_model_nparams_and_flops
from torchtitan.config import JobConfig
from torchtitan.models.utils import (
get_dense_model_nparams_and_flops,
get_moe_model_nparams_and_flops,
)
from torchtitan.protocols import BaseModelArgs
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
@@ -36,7 +39,12 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
"norm_eps": "rms_norm_eps",
"max_seq_len": "max_position_embeddings",
"eos_id": "eos_token_id",
}
},
"moe": {
# TorchTitan MoE-specific mappings (only used when titan_moe_args is provided)
"inter_dim": "intermediate_size",
"n_dense_layers": "first_k_dense_replace",
},
}

# Declarative list of TorchTitan-only attributes (no HF equivalent)
@@ -48,9 +56,23 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
"attn_mask_type",
]

# MoE attributes that should be copied directly
_MOE_SHARED_ATTRIBUTES = [
"rope_interleave",
"partial_rotary_factor",
"n_group",
"topk_group",
"kv_lora_rank",
"q_lora_rank",
"qk_nope_head_dim",
"qk_rope_head_dim",
"v_head_dim",
]

def __init__(
self,
titan_dense_args,
titan_moe_args=None,
# HuggingFace specific args
attn_implementation: str = "sdpa_torchtitan",
**kwargs,
@@ -59,14 +81,17 @@ def __init__(
assert titan_dense_args is not None, "titan_dense_args is required"

# Create getter/setter dynamically for TT <-> HF attribute mappings
self._create_getter_setter_dynamically(has_moe=False)
self._create_getter_setter_dynamically(titan_moe_args is not None)

self._titan_injected_model_args = {}
self._titan_injected_model_args.update(kwargs)
self._configure_hf_attention(attn_implementation)

self._initialize_dense_attributes(titan_dense_args)

if titan_moe_args is not None:
self._initialize_moe_attributes(titan_moe_args)

def _initialize_dense_attributes(self, titan_dense_args):
"""Initialize all dense model attributes."""
# Set mapped attributes (TorchTitan <-> HuggingFace)
@@ -83,6 +108,45 @@ def _initialize_dense_attributes(self, titan_dense_args):
# Update passed_args
self._titan_injected_model_args.update(titan_dense_args.__dict__)

def _initialize_moe_attributes(self, titan_moe_args):
"""Initialize all MoE-specific attributes."""
if titan_moe_args.moe_args is None:
self._titan_injected_model_args.update(titan_moe_args.__dict__)
return

moe_args = titan_moe_args.moe_args

# Convert q_lora_rank (0 -> None for HuggingFace compatibility)
self.q_lora_rank = (
None if titan_moe_args.q_lora_rank == 0 else titan_moe_args.q_lora_rank
)

# Set core MoE attributes
self.moe_args = moe_args
self.num_experts_per_tok = moe_args.top_k
self.n_routed_experts = moe_args.num_experts
self.n_shared_experts = moe_args.num_shared_experts
self.moe_intermediate_size = titan_moe_args.moe_inter_dim

# Set remaining architecture-specific MoE attributes
for attr in self._MOE_SHARED_ATTRIBUTES:
if attr == "q_lora_rank":
continue # Already set above
if hasattr(titan_moe_args, attr):
setattr(self, attr, getattr(titan_moe_args, attr))

# Track all MoE arguments
self._titan_injected_model_args.update(titan_moe_args.__dict__)
self._titan_injected_model_args.update(
{
"num_experts_per_tok": moe_args.top_k,
"n_routed_experts": moe_args.num_experts,
"n_shared_experts": moe_args.num_shared_experts,
"moe_intermediate_size": titan_moe_args.moe_inter_dim,
"q_lora_rank": self.q_lora_rank,
}
)

def _configure_hf_attention(self, attn_implementation: str):
"""Configure HuggingFace attention settings."""
self._titan_injected_model_args["attn_implementation"] = attn_implementation
@@ -151,6 +215,9 @@ def update_from_config(self, job_config: JobConfig):
if hasattr(self, key) and value is not None:
setattr(self, key, value)

# MoE (MLA) models: the combined qk head dim is the sum of the non-rotary and rotary parts
if hasattr(self, "qk_nope_head_dim") and hasattr(self, "qk_rope_head_dim"):
self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
self.max_seq_len = job_config.training.seq_len

# Configure HF-specific settings to match TorchTitan settings
@@ -174,4 +241,9 @@ def update_from_config(self, job_config: JobConfig):
return self

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
return get_dense_model_nparams_and_flops(self, model, head_dims=self.head_dim, seq_len=seq_len)
is_moe = hasattr(self, "n_routed_experts")

if is_moe:
return get_moe_model_nparams_and_flops(self, model, head_dims=self.head_dim, seq_len=seq_len)
else:
return get_dense_model_nparams_and_flops(self, model, head_dims=self.head_dim, seq_len=seq_len)
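
The dynamic getters/setters created in `__init__` (via `_create_getter_setter_dynamically`, whose body is not shown in this diff) alias TorchTitan attribute names to their HuggingFace counterparts from the mapping table above (e.g. `norm_eps` ↔ `rms_norm_eps`, `inter_dim` ↔ `intermediate_size`). Below is a minimal sketch of one way such aliases can be wired up with properties; the class and helper names are assumptions for illustration, not the actual implementation.

```python
# Illustrative sketch only: alias TorchTitan attribute names to their HuggingFace
# counterparts via properties. The real _create_getter_setter_dynamically is not
# shown in this diff; the mapping below mirrors entries from the table above.
_TT_TO_HF = {
    "norm_eps": "rms_norm_eps",
    "inter_dim": "intermediate_size",
}


class MappedConfigSketch:
    def __init__(self, **hf_kwargs):
        for name, value in hf_kwargs.items():
            setattr(self, name, value)


def _alias(hf_name):
    # Reading/writing the TorchTitan name transparently touches the HF attribute.
    return property(
        lambda self: getattr(self, hf_name),
        lambda self, value: setattr(self, hf_name, value),
    )


for tt_name, hf_name in _TT_TO_HF.items():
    setattr(MappedConfigSketch, tt_name, _alias(hf_name))

cfg = MappedConfigSketch(rms_norm_eps=1e-6, intermediate_size=1024)
assert cfg.norm_eps == 1e-6  # TorchTitan name resolves to the HF field
cfg.inter_dim = 2048
assert cfg.intermediate_size == 2048
```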