
Commit 1645f4a

Utils Refactor (#180)
Co-authored-by: Philip Bontrager <[email protected]>
1 parent 0e8b04b commit 1645f4a

33 files changed: +979 −668 lines

recipes/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 
-_RECIPE_LIST = ["finetune_llm"]
-_CONFIG_LISTS = {"finetune_llm": ["alpaca_llama2_finetune"]}
+_RECIPE_LIST = ["finetune_llm", "llama_generate"]
+_CONFIG_LISTS = {"finetune_llm": ["alpaca_llama2_finetune"], "llama_generate": []}
 
 
 def list_recipes():
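
For context, a minimal sketch of how this registry is typically consumed. The list_configs() helper and its error handling are assumptions for illustration; only _RECIPE_LIST and _CONFIG_LISTS come from the diff above.

    # Sketch only: list_configs() is a hypothetical helper, not part of the commit.
    _RECIPE_LIST = ["finetune_llm", "llama_generate"]
    _CONFIG_LISTS = {"finetune_llm": ["alpaca_llama2_finetune"], "llama_generate": []}


    def list_recipes():
        # Return the names of all registered recipes.
        return _RECIPE_LIST


    def list_configs(recipe: str):
        # Return the configs bundled with a recipe (hypothetical helper).
        if recipe not in _CONFIG_LISTS:
            raise ValueError(f"Unknown recipe: {recipe}")
        return _CONFIG_LISTS[recipe]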

recipes/configs/alpaca_llama2_finetune.yaml

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@ shuffle: True
 
 # Model Arguments
 model: llama2_7b
-model_checkpoint: /tmp/llama2-7b-01112024
+model_checkpoint: /tmp/llama2-7b
 tokenizer: llama2_tokenizer
 tokenizer_checkpoint: /tmp/tokenizer.model
 
@@ -17,5 +17,6 @@ optimizer: SGD
 loss: CrossEntropyLoss
 output_dir: /tmp/alpaca-llama2-finetune
 device: cuda
+dtype: fp32
 fsdp: False
 activation_checkpointing: False
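
Each top-level key in this config maps onto a keyword argument of the recipe. The snippet below loads the file with plain yaml.safe_load purely to make that mapping concrete; in the repo the values are merged through utils.TuneArgumentParser instead.

    # Illustrative only: the recipe reads these values via utils.TuneArgumentParser,
    # not via yaml.safe_load.
    import yaml

    with open("recipes/configs/alpaca_llama2_finetune.yaml") as f:
        cfg = yaml.safe_load(f)

    assert cfg["dtype"] == "fp32"  # key added in this commit
    assert cfg["model_checkpoint"] == "/tmp/llama2-7b"
    # recipe(**cfg)  # each top-level key becomes a recipe() keyword argument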

recipes/finetune_llm.py

Lines changed: 95 additions & 138 deletions
@@ -4,143 +4,111 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
 import os
-import sys
 from functools import partial
-from typing import Callable
 
 import torch
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    apply_activation_checkpointing,
-)
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import ModuleWrapPolicy
-from torch.optim.optimizer import Optimizer
+from torch.cuda.amp import GradScaler
 from torch.utils.data import DataLoader, DistributedSampler
 
-from torchtune.datasets import get_dataset, list_datasets
-from torchtune.models import get_model, get_tokenizer, list_models, list_tokenizers
-from torchtune.modules import TransformerDecoderLayer
-from torchtune.utils import TuneArgumentParser
-from torchtune.utils.batch_pad_sequence import batch_pad_to_longest_seq
-from torchtune.utils.env import get_world_size_and_rank, init_from_env, seed
+from torchtune import datasets, losses, models, modules, optim, utils
 from torchtune.utils.generation import generate_from_prompt
-from torchtune.utils.precision import (
-    get_autocast_manager,
-    get_grad_scaler,
-    get_supported_dtypes,
-)
 from tqdm import tqdm
 
 
-def get_optimizer(model: torch.nn.Module, optimizer: str, lr: float) -> Optimizer:
-    return getattr(torch.optim, optimizer)(model.parameters(), lr=lr)
-
-
-def get_loss(loss_fn: str) -> Callable:
-    return getattr(torch.nn, loss_fn)()
-
-
-def get_logger():
-    import logging
-
-    logger = logging.getLogger(__name__)
-    logger.addHandler(logging.StreamHandler())
-    logger.setLevel(logging.DEBUG)
-    return logger.info
-
-
-def recipe(kwargs):
+def recipe(
+    device,
+    dtype,
+    seed,
+    model,
+    model_checkpoint,
+    tokenizer,
+    tokenizer_checkpoint,
+    dataset,
+    shuffle,
+    batch_size,
+    fsdp,
+    epochs,
+    optimizer,
+    loss,
+    lr,
+    activation_checkpointing,
+    output_dir,
+    run_generation,
+    max_steps_per_epoch,
+):
     # ---- Initialize components ---- #
-    logger = get_logger()
-
-    # ---- Initialize distributed process group ---- #
-    device = init_from_env(device_type=kwargs["device"])
-    # TODO: only supporting devices specified as "cpu", "cuda", or "cuda:n" currently
-    device_type = (
-        kwargs["device"]
-        if kwargs["device"] in ("cpu", "cuda")
-        else kwargs["device"].split(":")[0]
-    )
-
-    # ---- Initialize seed ---- #
-    # Fetch world size and rank after distributed process group initialization
-    world_size, rank = get_world_size_and_rank()
-    if kwargs["seed"] is not None:
-        # Ensure that seed is different per rank (and its dataloader workers)
-        seed(kwargs["seed"] + rank)
-
-    tokenizer = get_tokenizer(kwargs["tokenizer"], path=kwargs["tokenizer_checkpoint"])
-    logger(msg=f"Loaded tokenizer from {kwargs['tokenizer_checkpoint']}")
-
-    autocast_precision = kwargs.get("autocast_precision", None)
-    autocast_mgr = get_autocast_manager(
-        device_type=device_type, precision=autocast_precision
-    )
-    grad_scaler = get_grad_scaler(autocast_precision, fsdp=kwargs["fsdp"])
-
-    model = get_model(
-        kwargs["model"],
-        device,
-    )
-
-    if kwargs["fsdp"] or kwargs["activation_checkpointing"]:
-        auto_wrap_policy = ModuleWrapPolicy(
-            {TransformerDecoderLayer}
-        )  # TODO: remove model specific components
-    if kwargs["fsdp"]:
-        model = FSDP(
-            model,
-            auto_wrap_policy=auto_wrap_policy,
-            device_id=device,
-            param_init_fn=lambda m: m.to_empty(device=device, recurse=False),
+    utils.init_distributed(fsdp)
+
+    # logger = logging.getLogger()
+    # logger.setLevel(logging.DEBUG) # test
+    logger = utils.get_logger("DEBUG")
+
+    device = utils.get_device(device)
+    dtype = utils.get_dtype(dtype)
+    seed = utils.set_seed(seed)
+
+    # ---- Setup model and load checkpoint ---- #
+    tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint)
+    logger.info(msg=f"Loaded tokenizer from {tokenizer_checkpoint}")
+
+    model = models.get_model(model, device=device)
+    if fsdp:
+        # TODO: initialize models for distributed on meta or cpu device to avoid OOMs
+        model = utils.get_fsdp(
+            model=model,
+            device=device,
+            dtype=dtype,
+            strategy="FULL_SHARD",
+            auto_wrap_policy={modules.TransformerDecoderLayer},
         )
-    if kwargs["activation_checkpointing"]:
-        apply_activation_checkpointing(
-            model,
-            check_fn=lambda mod: isinstance(
-                mod, TransformerDecoderLayer
-            ),  # TODO: remove model specific components
-            auto_wrap_policy=auto_wrap_policy,
+    if activation_checkpointing:
+        utils.set_activation_checkpointing(
+            model, auto_wrap_policy={modules.TransformerDecoderLayer}
        )
 
-    loaded_ckpt = torch.load(
-        kwargs["model_checkpoint"], map_location="cpu", weights_only=True
-    )
+    loaded_ckpt = torch.load(model_checkpoint, map_location="cpu", weights_only=True)
     model.load_state_dict(loaded_ckpt)
-    logger(msg=f"Loaded model from {kwargs['model_checkpoint']}")
+    logger.info(msg=f"Loaded model from {model_checkpoint}")
 
-    opt = get_optimizer(model, kwargs["optimizer"], kwargs["lr"])
+    # ---- Setup optimization functions ---- #
+    opt = optim.get_optimizer(optimizer, model, lr)
     # TODO add lr schedule option
-    loss_fn = get_loss(kwargs["loss"])
+    loss_fn = losses.get_loss(loss)
+
+    autocast = utils.get_autocast(dtype, device)
+    if dtype == torch.float16:
+        grad_scaler = utils.get_gradient_scaler(fsdp=fsdp)
+    else:
+        grad_scaler = GradScaler(enabled=False)
 
     # ---- Load dataset, set up sampler, and dataloader ---- #
-    dataset = get_dataset(kwargs["dataset"], split="train", tokenizer=tokenizer)
+    world_size, rank = utils.get_world_size_and_rank()
+    ds = datasets.get_dataset(dataset, split="train", tokenizer=tokenizer)
     sampler = DistributedSampler(
-        dataset,
+        ds,
         num_replicas=world_size,
         rank=rank,
-        shuffle=kwargs["shuffle"],
+        shuffle=shuffle,
         seed=0,
     )
     dataloader = DataLoader(
-        dataset=dataset,
-        batch_size=kwargs["batch_size"],
+        dataset=ds,
+        batch_size=batch_size,
         sampler=sampler,
         collate_fn=partial(
-            batch_pad_to_longest_seq,
-            input_padding_idx=tokenizer.pad_id,
-            label_padding_idx=loss_fn.ignore_index,  # TODO support loss without ignore_index
+            utils.padded_collate,
+            padding_idx=tokenizer.pad_id,
+            ignore_idx=loss_fn.ignore_index,  # TODO support loss without ignore_index
        ),
    )
-    logger(msg=f"Loaded dataset {kwargs['dataset']}")
+    logger.info(msg=f"Loaded dataset {dataset}")
 
     # ---- Train loop ---- #
-    for epoch in range(kwargs["epochs"]):
-        # Need to set the epoch for changing sample ordering in each epoch
-        sampler.set_epoch(epoch)
+    for epoch in range(epochs):
+        sampler.set_epoch(epoch)  # distributed sampler requires set_epoch
         for idx, batch in enumerate(pbar := tqdm(dataloader)):
-            max_steps_per_epoch = kwargs.get("max_steps_per_epoch", None)
             if max_steps_per_epoch is not None and idx == max_steps_per_epoch:
                 break
             opt.zero_grad()
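
The dataloader above now pads batches through utils.padded_collate, using the tokenizer's pad id for inputs and the loss ignore_index for labels. Below is a hedged sketch of what such a collate can look like; the batch layout (a list of (tokens, labels) pairs) and the body are assumptions, only the padding_idx and ignore_idx keyword names come from the diff.

    # Hedged sketch of a padded collate; not the torchtune implementation.
    from typing import List, Tuple

    import torch
    from torch.nn.utils.rnn import pad_sequence


    def padded_collate(
        batch: List[Tuple[List[int], List[int]]],
        padding_idx: int = 0,
        ignore_idx: int = -100,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Pad inputs with the tokenizer pad id and labels with the loss ignore
        # index so CrossEntropyLoss skips the padded label positions.
        tokens = pad_sequence(
            [torch.tensor(x) for x, _ in batch],
            batch_first=True,
            padding_value=padding_idx,
        )
        labels = pad_sequence(
            [torch.tensor(y) for _, y in batch],
            batch_first=True,
            padding_value=ignore_idx,
        )
        return tokens, labels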
@@ -149,10 +117,7 @@ def recipe(kwargs):
             input_ids = input_ids.to(device)
             labels = labels.to(device)
 
-            # Note: context manager for autocast is only applied in forward pass.
-            # see https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-torch-autocast
-            # for more details.
-            with autocast_mgr:
+            with autocast:
                 logits = model(input_ids)
                 # Shift so that tokens < n predict n
                 logits = logits[..., :-1, :].contiguous()
@@ -168,15 +133,11 @@
                 f"{epoch+1}|{idx+1}|Loss: {loss.item()}"
             )  # TODO: add terminal logger
 
-            if grad_scaler:
-                grad_scaler.scale(loss).backward()
-                grad_scaler.step(opt)
-                grad_scaler.update()
-            else:
-                loss.backward()
-                opt.step()
+            grad_scaler.scale(loss).backward()
+            grad_scaler.step(opt)
+            grad_scaler.update()
 
-            run_generation = kwargs.get("run_generation", None)
+            # --- TODO TEMPORARY EVAL Code ---- #
             if run_generation and idx % run_generation == 0:
                 # Log a sample generation for the instruction.
                 # Just using a hardcoded prompt for now
@@ -189,16 +150,14 @@
                 generation_str, decoded_tokens = generate_from_prompt(
                     prompt=prompt, tokenizer=tokenizer, decoder=model
                 )
-                if (
-                    not torch.distributed.is_initialized()
-                    or torch.distributed.get_rank() == 0
-                ):
-                    logger(f"Generation tokens: {decoded_tokens}")
-                    logger(f"Generation: {generation_str}")
-
-        # Save checkpoint at end of each epoch (to be changed later)
-        os.makedirs(kwargs["output_dir"], exist_ok=True)
-        output_loc = f"{kwargs['output_dir']}/model_{epoch}.ckpt"
+                if rank == 0:
+                    logger.info(f"Generation tokens: {decoded_tokens}")
+                    logger.info(f"Generation: {generation_str}")
+            # --- TODO TEMPORARY EVAL Code Ends ---- #
+
+        # ---- Save checkpoint at end of each epoch (to be changed later) ---- #
+        os.makedirs(output_dir, exist_ok=True)
+        output_loc = f"{output_dir}/model_{epoch}.ckpt"
         torch.save(
             {
                 "epoch": epoch,
@@ -208,19 +167,19 @@
             },
             output_loc,
         )
-        logger(
+        logger.info(
             msg=f"Model checkpoint of size {os.path.getsize(output_loc) >> 20}MB saved to {output_loc}"
         )
 
 
 if __name__ == "__main__":
-    parser = TuneArgumentParser(description="Fine-tune an LLM.")
+    parser = utils.TuneArgumentParser(description="Fine-tune an LLM.")
 
     # Dataset and DataLoader arguments
     parser.add_argument(
         "--dataset",
         type=str,
-        choices=list_datasets(),
+        choices=datasets.list_datasets(),
         help="Dataset name.",
     )
     parser.add_argument(
@@ -238,7 +197,7 @@ def recipe(kwargs):
     parser.add_argument(
         "--model",
         type=str,
-        choices=list_models(),
+        choices=models.list_models(),
         help="Model to finetune.",
     )
     parser.add_argument(
@@ -249,7 +208,7 @@ def recipe(kwargs):
     parser.add_argument(
         "--tokenizer",
         type=str,
-        choices=list_tokenizers(),
+        choices=models.list_tokenizers(),
         help="Model tokenizer.",
     )
     parser.add_argument(
@@ -318,14 +277,12 @@ def recipe(kwargs):
         help="Max number of steps per epoch for faster dev/testing. Default is to finetune through the full dataset.",
     )
     parser.add_argument(
-        "--autocast-precision",
+        "--dtype",
         type=str,
-        choices=get_supported_dtypes(),
+        choices=utils.list_dtypes(),
         default=None,
-        help=f"""Low precision used for CUDA automatic mixed precision.
-        If specified, must be one of {get_supported_dtypes()}.
-        """,
+        help="Tensor dtype used for finetuning, lower precision types result in mixed precision training.",
     )
 
     kwargs = vars(parser.parse_args())
-    sys.exit(recipe(kwargs))
+    recipe(**kwargs)
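
The training step now always drives the backward pass through a GradScaler, with GradScaler(enabled=False) acting as a no-op outside fp16, so there is a single code path for every dtype. Below is a self-contained sketch of that pattern using plain torch.autocast and torch.cuda.amp.GradScaler as stand-ins for utils.get_autocast and utils.get_gradient_scaler, assuming a CUDA device is available.

    # Sketch of the unified mixed-precision step; requires a CUDA device.
    import torch
    from torch.cuda.amp import GradScaler

    dtype = torch.float16  # e.g. parsed from --dtype fp16
    autocast = torch.autocast("cuda", dtype=dtype, enabled=dtype != torch.float32)
    grad_scaler = GradScaler(enabled=dtype == torch.float16)  # no-op unless fp16

    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()
    x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 8, device="cuda")

    opt.zero_grad()
    with autocast:
        loss = loss_fn(model(x), y)
    # The same three calls work whether or not scaling is enabled.
    grad_scaler.scale(loss).backward()
    grad_scaler.step(opt)
    grad_scaler.update()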
