Commit 58fa181
authored
3outeille/transformers backend (Dense model only) (#2048)
# Context
Reference PR: huggingface#1
This PR enables:
- Llama-like HF models to work with 4D parallelism: FSDP, CP, TP, PP
(and the combinations between them). The following models were tested:
- `meta-llama/Llama-3.2-1B`
- `microsoft/phi-2`
- `Qwen/Qwen2.5-7B`
- `mistralai/Mistral-7B-v0.1`
- `ByteDance-Seed/Seed-Coder-8B-Instruct`
- `Qwen/Qwen3-4B-Instruct-2507`
- `arcee-ai/AFM-4.5B`
- `ibm-granite/granite-3b-code-base-2k`
- `baidu/ERNIE-4.5-0.3B-Base-PT`
- `kyutai/helium-1-preview-2b`
- `allenai/OLMo-7B-hf`
- `mistralai/Ministral-8B-Instruct-2410`
- Patching HF models weights initialisation. Without this, the the
`loss` and `grad_norm` starts very high
# Usage
- Requirements `transformers==4.57.1`
- Config:
`torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3.toml`
```diff
...
[model]
- name = "llama3"
+ name = "transformers_backend"
flavor = "debugmodel"
hf_assets_path = "./tests/assets/tokenizer"
+[hf_transformers]
+model = "Qwen/Qwen3-4B-Instruct-2507"
...
```
- Train: `LOG_RANK=7
CONFIG_FILE=<YOUR_PATH>/torchtitan/experiments/transformers_backend/configs/qwen3.toml
./run_train.sh
--job.custom_config_module=torchtitan.experiments.transformers_backend.job_config
--compile.enable`
<img width="1334" height="453" alt="image"
src="https://github.com/user-attachments/assets/da459448-027b-4af9-8176-6a3e433a272c"
/>
# Testing methodology
<img width="2672" height="2018" alt="image"
src="https://github.com/user-attachments/assets/66d8689d-7ede-47e3-b389-d4fc1bdd70f7"
/>
- Following the
[converging.md](https://github.com/pytorch/torchtitan/blob/main/docs/converging.md)
guidelines, I am comparing the baseline `FSDP=2` vs `FSDP=2 & <other
//-ism>`
- More precisely, the `test_hf_integration.py`is going to do:
```bash
results/
|_ meta-llama
|_ Llama-3.2-1B
|_ debugmodel/
|_ seed_checkpoint/
|_ config.toml
|_ seed.slurm
|_ step-0/
|_ ....
|_ fsdp2_tp1_cp1_pp1/
|_ config.toml
|_ nd_parallelism.slurm
|_ nd_parallelism.log
|_ fsdp2_tp2_cp1_pp1/
|_ config.toml
|_ nd_parallelism.slurm
|_ nd_parallelism.log
|_ diff_baseline_vs_nd_parallelism.log
|_ fsdp2_tp1_cp1_pp2/
|_ config.toml
|_ nd_parallelism.slurm
|_ nd_parallelism.log
|_ diff_baseline_vs_nd_parallelism.log
|_ fsdp2_tp1_cp2_pp1/
|_ config.toml
|_ nd_parallelism.slurm
|_ nd_parallelism.log
|_ diff_baseline_vs_nd_parallelism.log
|_ fsdp2_tp1_cp2_pp2/
|_ config.toml
|_ nd_parallelism.slurm
|_ nd_parallelism.log
|_ diff_baseline_vs_nd_parallelism.log`
|_ full/
...
```
- Here is the grid search to test the HF modelling
```shell
#!/usr/bin/bash
model_names=(
"meta-llama/Llama-3.2-1B"
"microsoft/phi-2"
"Qwen/Qwen2.5-7B"
"mistralai/Mistral-7B-v0.1"
"ByteDance-Seed/Seed-Coder-8B-Instruct"
"Qwen/Qwen3-4B-Instruct-2507"
"arcee-ai/AFM-4.5B"
"ibm-granite/granite-3b-code-base-2k"
"baidu/ERNIE-4.5-0.3B-Base-PT"
"kyutai/helium-1-preview-2b"
"allenai/OLMo-7B-hf"
"mistralai/Ministral-8B-Instruct-2410"
)
for model_name in "${model_names[@]}"; do
rm -rf slurm_results/${model_name}
python test_hf_integration.py create_configs --model_name "$model_name" --out_dir slurm_results --flavor debugmodel
python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel/seed_checkpoint --qos high
while [ ! -f slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt ] || [ "$(cat slurm_results/${model_name}/debugmodel/seed_checkpoint/status.txt)" != "completed" ]; do
echo "Waiting for seed checkpoint from ${model_name} to complete ..."
sleep 1
done
python test_hf_integration.py submit_jobs --inp_dir slurm_results/${model_name}/debugmodel --qos high
echo "================"
done
```
# Further tasks
- Moe (handle in PR huggingface#3)
- Missing `build_optimizers_with_moe_load_balancing` support for MoE
- Missing TP/PP/EP supports for MoE
- When using HF modeling, the test `FSDP=2 vs FSDP=2 + PP=2`, the `loss`
and `grad_norm` not bitwise matching (but converging) while it is the
case with Torchtitan modeling. (issue is tracked in
huggingface#4)
- Add convergence tests to CI by doing tiny model + gloo backend (once
PP is bitwise matching)
- the HF modeling has lower MFU than Torchtitan MFU
- NOTE: `import torch._dynamo.config;
torch._dynamo.config.cache_size_limit = 128` to avoid recomputation for
graph when using `torch.compile` and `activation checkpointing`1 parent d167a20 commit 58fa181
File tree
17 files changed
+1936
-2
lines changed- .ci/docker
- common
- ubuntu
- .github/workflows
- torchtitan
- experiments
- transformers_backend
- configs
- infra
- model
- tests
17 files changed
+1936
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| 46 | + | |
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
| 36 | + | |
36 | 37 | | |
37 | 38 | | |
38 | 39 | | |
| |||
Lines changed: 53 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
| 34 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
15 | 16 | | |
16 | 17 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
Lines changed: 88 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
Lines changed: 87 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
0 commit comments