@3outeille (Member) commented:

- MoE-like HF models now work with 2D parallelism: FSDP, CP (and the combinations between them), though without `build_optimizers_with_moe_load_balancing`. The following models were tested:
  - `deepseek-ai/DeepSeek-V3`
  - `moonshotai/Moonlight-16B-A3B`
  - `moonshotai/Kimi-K2-Instruct`
  - `zai-org/GLM-4.5`

tianyu-l pushed a commit to pytorch/torchtitan that referenced this pull request Nov 20, 2025
# Context
Reference PR: huggingface#1

This PR enables:
- Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP
(and the combinations between them). The following models were tested:
  - `meta-llama/Llama-3.2-1B`
  - `microsoft/phi-2`
  - `Qwen/Qwen2.5-7B`
  - `mistralai/Mistral-7B-v0.1`
  - `ByteDance-Seed/Seed-Coder-8B-Instruct`
  - `Qwen/Qwen3-4B-Instruct-2507`
  - `arcee-ai/AFM-4.5B`
  - `ibm-granite/granite-3b-code-base-2k`
  - `baidu/ERNIE-4.5-0.3B-Base-PT`
  - `kyutai/helium-1-preview-2b`
  - `allenai/OLMo-7B-hf`
  - `mistralai/Ministral-8B-Instruct-2410`
- Patching HF model weight initialization. Without this, the `loss` and
`grad_norm` start very high (see the sketch below)
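
For illustration, a minimal sketch of such an init patch, assuming standard `transformers` and `torch.nn.init` APIs (the truncated-normal choice and the module walk are illustrative, not the exact patch in this PR):

```python
# Hypothetical sketch: re-initialize HF weights so loss/grad_norm start in a
# sane range. The actual patch in this PR may differ.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
model = AutoModelForCausalLM.from_config(config)

std = getattr(config, "initializer_range", 0.02)
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, torch.nn.Embedding):
        torch.nn.init.trunc_normal_(module.weight, mean=0.0, std=std)
```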

# Usage

- Requirements: `transformers==4.57.1`
- Config:
`torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml`
```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"

+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```
- Train:
```shell
LOG_RANK=7 CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_backend.job_config --compile.enable
```

<img width="1334" height="453" alt="image"
src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c"
/>

# Testing methodology

<img width="2672" height="2018" alt="image"
src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7"
/>

- Following the
[converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md)
guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other
parallelism>`
- More precisely, `test_hf_integration.py` generates the following layout
(a sketch of the log comparison follows the tree):

```text
    results/
        |_ meta-llama
            |_ Llama-3.2-1B
                |_ debugmodel/
                    |_ seed_checkpoint/
                        |_ config.toml
                        |_ seed.slurm
                        |_ step-0/
                           |_ ....
                    |_ fsdp2_tp1_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                    |_ fsdp2_tp2_cp1_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp1_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp1/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log
                    |_ fsdp2_tp1_cp2_pp2/
                        |_ config.toml
                        |_ nd_parallelism.slurm
                        |_ nd_parallelism.log
                        |_ diff_baseline_vs_nd_parallelism.log`
                |_ full/
                ...
```
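
For illustration, a minimal sketch of how `diff_baseline_vs_nd_parallelism.log` could be produced; the log format (`step: N loss: X grad_norm: Y`) and the file paths are assumptions, and the real `test_hf_integration.py` may differ:

```python
# Hypothetical helper: compare per-step loss/grad_norm between the baseline
# run and an n-D parallelism run, flagging any non-bitwise-matching step.
import re

PATTERN = re.compile(r"step:\s*(\d+).*?loss:\s*([\d.eE+-]+).*?grad_norm:\s*([\d.eE+-]+)")

def parse_metrics(path):
    metrics = {}
    with open(path) as f:
        for line in f:
            m = PATTERN.search(line)
            if m:
                step, loss, grad_norm = m.groups()
                metrics[int(step)] = (float(loss), float(grad_norm))
    return metrics

def diff_runs(baseline_log, nd_log):
    base, nd = parse_metrics(baseline_log), parse_metrics(nd_log)
    for step in sorted(base.keys() & nd.keys()):
        (b_loss, b_gn), (n_loss, n_gn) = base[step], nd[step]
        status = "OK" if (b_loss, b_gn) == (n_loss, n_gn) else "MISMATCH"
        print(f"step {step}: loss {b_loss} vs {n_loss}, "
              f"grad_norm {b_gn} vs {n_gn} -> {status}")

# Paths follow the results tree above (baseline = fsdp2_tp1_cp1_pp1).
diff_runs("fsdp2_tp1_cp1_pp1/nd_parallelism.log",
          "fsdp2_tp2_cp1_pp1/nd_parallelism.log")
```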
- Here is the grid search used to test the HF modeling:
```shell
#!/usr/bin/bash
# Grid search over every tested HF model: generate configs, build the seed
# checkpoint, then submit all parallelism-combination jobs.
model_names=(
    "meta-llama/Llama-3.2-1B"
    "microsoft/phi-2"
    "Qwen/Qwen2.5-7B"
    "mistralai/Mistral-7B-v0.1"
    "ByteDance-Seed/Seed-Coder-8B-Instruct"
    "Qwen/Qwen3-4B-Instruct-2507"
    "arcee-ai/AFM-4.5B"
    "ibm-granite/granite-3b-code-base-2k"
    "baidu/ERNIE-4.5-0.3B-Base-PT"
    "kyutai/helium-1-preview-2b"
    "allenai/OLMo-7B-hf"
    "mistralai/Ministral-8B-Instruct-2410"
)

for model_name in "${model_names[@]}"; do
    rm -rf slurm_results/${model_name}

    # One config.toml + slurm script per parallelism combination.
    python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
    # Every parallelism job restores from the same seed checkpoint, so that
    # job must complete before the rest are submitted.
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
    while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
        echo "Waiting for seed checkpoint from ${model_name} to complete ..."
        sleep 1
    done
    python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
    echo "================"
done
```

# Further tasks

- MoE (handled in PR huggingface#3)
	- Missing `build_optimizers_with_moe_load_balancing` support for MoE
	- Missing TP/PP/EP support for MoE
- With the HF modeling, in the `FSDP=2 vs FSDP=2 + PP=2` test the `loss`
and `grad_norm` are not bitwise matching (though they converge), whereas
they do match with the Torchtitan modeling (issue tracked in
huggingface#4)
- Add convergence tests to CI using a tiny model + the gloo backend (once
PP is bitwise matching)
- The HF modeling has a lower MFU than the Torchtitan modeling
- NOTE: set `torch._dynamo.config.cache_size_limit = 128` to avoid graph
recompilation when using `torch.compile` together with activation
checkpointing (snippet below)
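
As a runnable snippet (the same setting as the note above; 128 is the value used here, tune as needed):

```python
# Raise TorchDynamo's recompile cache limit so torch.compile does not fall
# back to eager when activation checkpointing multiplies the number of
# distinct graphs to trace.
import torch._dynamo.config

torch._dynamo.config.cache_size_limit = 128
```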