
Commit c412663

finished training with hf models
1 parent b291ad6 commit c412663

File tree

15 files changed: +1826 additions, 0 deletions


torchtitan/experiments/__init__.py

Lines changed: 5 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

Lines changed: 23 additions & 0 deletions
# Training Llama with HF weights

This directory contains scripts and configs for training Llama with Hugging Face (HF) weights using torchtitan.

## Usage

### Install extra dependencies

```bash
pip install -r extra_requirements.txt
```

### Test loading HF weights

```bash
pytest test_loading_hf_weights.py
```

### Run training

```bash
LOG_RANK=7 bash run_train.sh
```

Lines changed: 10 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 3 is licensed under the LLAMA 3 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import torchtitan.experiments.train_llama_hf.model  # noqa: F401

Lines changed: 127 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.config_manager import JobConfig
from torchtitan.datasets.hf_datasets import _validate_dataset
from torchtitan.tools.logging import logger


class HuggingFaceDataset(IterableDataset, Stateful):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: PreTrainedTokenizerBase,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    # Add position IDs (0 to seq_len-1)
                    position_ids = torch.arange(len(input), dtype=torch.long)
                    yield input, label, position_ids

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )

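For context, a minimal sketch of how this dataloader might be driven on a single data-parallel rank. It is not part of this commit: the module path in the import, the `SimpleNamespace` stand-in for `JobConfig`, the `c4_test` dataset name, and the tokenizer checkpoint are all assumptions for illustration.

```python
# Hypothetical usage sketch (not part of this commit): drive build_hf_dataloader
# with a stand-in config object and an HF tokenizer on a single data-parallel rank.
from types import SimpleNamespace

from transformers import AutoTokenizer

from torchtitan.experiments.train_llama_hf.dataset import (  # assumed module name
    build_hf_dataloader,
)

# Stand-in for JobConfig exposing only the fields build_hf_dataloader reads.
job_config = SimpleNamespace(
    training=SimpleNamespace(
        dataset="c4_test",  # assumed dataset name registered in torchtitan
        dataset_path=None,
        batch_size=8,
        seq_len=2048,
    )
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")  # assumed checkpoint

dataloader = build_hf_dataloader(
    dp_world_size=1,
    dp_rank=0,
    tokenizer=tokenizer,
    job_config=job_config,
    infinite=True,
)

# Each batch stacks (input, label, position_ids) samples of length seq_len.
input_ids, labels, position_ids = next(iter(dataloader))
print(input_ids.shape, labels.shape, position_ids.shape)  # [8, 2048] each
```

Because the dataset yields `(input, label, position_ids)` tuples, the default collate produces three batched tensors per step, which matches what an HF-style forward pass expects.
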
Lines changed: 2 additions & 0 deletions
transformers >=4.49.0
sentencepiece >=0.2.0

Lines changed: 188 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import json
from collections import defaultdict
from pathlib import Path

import torch.nn as nn
from huggingface_hub import repo_exists, snapshot_download
from safetensors import safe_open
from torch.distributed.tensor import distribute_tensor, DTensor
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from torchtitan.tools.logging import logger

INDEX_NAME_MAPPING = {
    "safetensors": SAFE_WEIGHTS_INDEX_NAME,
}

PATTERNS_TO_REMOVE = [
    "._orig_mod",  # torch.compile wrapper
    "._fsdp_wrapped_module",  # FSDP wrapper
    "._checkpoint_wrapped_module",  # activation checkpoint wrapper
    ".module",  # DataParallel/DistributedDataParallel
    "_module.",  # some wrappers add a prefix
]


def normalize_state_dict_key(
    key: str, patterns_to_remove: list[str] = PATTERNS_TO_REMOVE
) -> str:
    """
    Normalize a state dict key by removing the prefixes or suffixes added by various wrappers.
    Args:
        key: The original state dict key.
        patterns_to_remove: Substrings to strip from the key.
    Returns:
        The normalized key.
    """
    normalized_key = key
    for pattern in patterns_to_remove:
        normalized_key = normalized_key.replace(pattern, "")

    return normalized_key


def get_weight_map(pretrained_model_path: Path) -> dict[str, str] | None:
    """
    Get the weight map from the pretrained model.
    Args:
        pretrained_model_path: The path to the pretrained model.
    Returns:
        weight_map: A dictionary mapping each checkpoint state dict key to the safetensors
            file that stores it, or None if the checkpoint has no index file.
    """
    index_file = pretrained_model_path / INDEX_NAME_MAPPING["safetensors"]
    if not index_file.exists():
        return None
    with open(index_file, "r") as f:
        metadata = json.load(f)
    return metadata["weight_map"]


def group_state_dict_keys_and_st_partition_paths(
    pretrained_model_path: Path,
    state_dict_keys,
    weight_map,
    state_dict_map: dict[str, str] = None,
):
    """
    Group state dict keys by the safetensors partition file that stores them.
    Args:
        pretrained_model_path: The path to the pretrained model.
        state_dict_keys: The state dict keys to group.
        weight_map: The weight map.
        state_dict_map: A dictionary mapping from the model state dict key to the checkpoint key.
    Returns:
        st_partition_map: A dictionary mapping from the partition path to the list of state dict keys.
    """
    st_partition_map = defaultdict(list)
    for state_dict_key in state_dict_keys:
        ckpt_state_dict_key = (
            state_dict_map[state_dict_key]
            if state_dict_map is not None
            else state_dict_key
        )
        if weight_map is None:
            partition_path = pretrained_model_path / "model.safetensors"
        else:
            partition_path = pretrained_model_path / weight_map[ckpt_state_dict_key]
        st_partition_map[partition_path].append(state_dict_key)
    return st_partition_map


def load_sharded_state_dict_for_model_from_path(
    pretrained_model_path: Path,
    model: nn.Module,
    mapping_dict: dict[str, str] = None,
    **kwargs,
):
    """
    Load the state dict sharded (via DTensor) from the pretrained model path. It only loads
    the weights needed by the current rank.
    Args:
        pretrained_model_path: The path to the pretrained model; it can be a local path or an s3 path.
        model: The model to load the state dict into.
        mapping_dict: Optional mapping from model state dict keys to checkpoint keys; by default
            keys are normalized with normalize_state_dict_key.
        **kwargs: Other arguments for torch.nn.Module.load_state_dict.
    """
    # check exceptions
    if not pretrained_model_path.exists():
        raise ValueError(
            f"The pretrained model path {pretrained_model_path} does not exist."
        )
    if not pretrained_model_path.is_dir():
        raise ValueError(
            f"The pretrained model path {pretrained_model_path} is not a directory."
        )
    # get the weight map
    weight_map = get_weight_map(pretrained_model_path)
    model_state_dict = model.state_dict()
    model_state_dict_keys = list(model_state_dict.keys())

    # create a mapping_dict between the original state_dict_key and the weight_map_key if not provided
    mapping_dict = (
        mapping_dict
        if mapping_dict is not None
        else {key: normalize_state_dict_key(key) for key in model_state_dict_keys}
    )
    st_partition_map = group_state_dict_keys_and_st_partition_paths(
        pretrained_model_path, model_state_dict_keys, weight_map, mapping_dict
    )

    # get the sharded state dict
    state_dict = {}
    for safetensor_partition_path, state_dict_keys in st_partition_map.items():
        with safe_open(safetensor_partition_path, framework="pt", device="cpu") as f:
            for state_dict_key in state_dict_keys:
                model_tensor = model_state_dict[state_dict_key]
                ckpt_state_dict_key = mapping_dict[state_dict_key]
                if isinstance(model_tensor, DTensor):
                    local_tensor = f.get_tensor(ckpt_state_dict_key)
                    state_dict[state_dict_key] = distribute_tensor(
                        local_tensor,
                        model_tensor.device_mesh,
                        model_tensor.placements,
                    )
                else:
                    state_dict[state_dict_key] = f.get_tensor(ckpt_state_dict_key)
    model.load_state_dict(state_dict, **kwargs)
    del state_dict
    gc.collect()


def load_sharded_state_dict_for_model_from_hf(
    pretrained_model_id_or_path: str,
    model: nn.Module,
    **kwargs,
):
    """
    Load the state dict sharded (via DTensor) from a Hugging Face repo id or a local path.
    It only loads the weights needed by the current rank.
    Args:
        pretrained_model_id_or_path: The id or path of the pretrained model; it can be a repo id
            on the Hugging Face Hub or a local path.
        model: The model to load the state dict into.
        **kwargs: Other arguments for torch.nn.Module.load_state_dict.
    """
    logger.info(f"Loading the state dict from {pretrained_model_id_or_path}")
    pretrained_model_id_or_path = Path(pretrained_model_id_or_path)
    if not pretrained_model_id_or_path.exists():
        if not repo_exists(str(pretrained_model_id_or_path)):
            raise ValueError(
                f"The pretrained model {pretrained_model_id_or_path} does not exist"
            )
        logger.info(
            f"Trying to download the model from the Hugging Face Hub: {pretrained_model_id_or_path}"
        )
        pretrained_model_path = Path(
            snapshot_download(str(pretrained_model_id_or_path))
        )
    elif not pretrained_model_id_or_path.is_dir():
        raise ValueError(
            f"The pretrained model path {pretrained_model_id_or_path} is not a directory."
        )
    else:
        pretrained_model_path = pretrained_model_id_or_path

    load_sharded_state_dict_for_model_from_path(pretrained_model_path, model, **kwargs)

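As a usage illustration, a minimal end-to-end sketch of pulling HF weights into an FSDP2-sharded Llama model. None of it is part of this commit, and several pieces are assumptions: a torchrun launch, a torch version that exposes `torch.distributed.fsdp.fully_shard`, the `meta-llama/Llama-3.1-8B` repo id, the `load_hf_weights` module name, and the HF Llama layout `model.model.layers`.

```python
# Hypothetical sketch (not part of this commit): shard an HF Llama model with FSDP2
# (fully_shard) on the meta device, then load the pretrained safetensors shards.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM

from torchtitan.experiments.train_llama_hf.load_hf_weights import (  # assumed module name
    load_sharded_state_dict_for_model_from_hf,
)

model_id = "meta-llama/Llama-3.1-8B"  # assumed (gated) repo id

# Build the model skeleton on the meta device so the full weights are never
# materialized on a single rank before sharding.
config = AutoConfig.from_pretrained(model_id)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# One-dimensional FSDP mesh over all ranks (launched with torchrun).
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
for layer in model.model.layers:  # assumes the HF Llama layout: model.model.layers
    fully_shard(layer, mesh=mesh)
fully_shard(model, mesh=mesh)

# Parameters are now DTensors; give them real (empty) storage, then let the loader
# read each safetensors partition and distribute it onto the mesh. Non-persistent
# buffers (e.g. rotary frequencies) would still need re-initialization in real use.
model.to_empty(device="cuda")
load_sharded_state_dict_for_model_from_hf(model_id, model)
```

Because FSDP2 keeps the original parameter names, `normalize_state_dict_key` is effectively a no-op here, and the model keys line up with the checkpoint's `weight_map` entries directly.
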
Lines changed: 14 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def cross_entropy_loss_hf(preds, labels):
    loss = torch.nn.functional.cross_entropy(
        preds[0].flatten(0, 1).float(), labels.flatten(0, 1)
    )
    return loss
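Why `preds[0]`: when labels are not passed to the model, an HF causal LM forward returns an output whose first element is the logits tensor of shape `[batch, seq, vocab]`, so indexing `[0]` extracts the logits before batch and sequence are flattened into one dimension for `cross_entropy`. A quick shape check with random tensors (sizes are illustrative only, and the tuple merely mimics the HF output position):

```python
# Illustrative shape check mirroring cross_entropy_loss_hf (sizes are made up).
import torch

batch, seq_len, vocab = 2, 8, 128
logits = torch.randn(batch, seq_len, vocab)  # what an HF causal LM returns first
labels = torch.randint(0, vocab, (batch, seq_len))  # next-token targets from the dataloader

# Mimic the (logits, ...) output position that the loss indexes with preds[0].
preds = (logits,)

loss = torch.nn.functional.cross_entropy(
    preds[0].flatten(0, 1).float(), labels.flatten(0, 1)
)
print(loss.shape)  # torch.Size([]) -- a scalar loss
```
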
