# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+
import os
- import sys

from functools import partial
- from typing import Callable

import torch
- from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-     apply_activation_checkpointing,
- )
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
- from torch.optim.optimizer import Optimizer
+ from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, DistributedSampler

- from torchtune.datasets import get_dataset, list_datasets
- from torchtune.models import get_model, get_tokenizer, list_models, list_tokenizers
- from torchtune.modules import TransformerDecoderLayer
- from torchtune.utils import TuneArgumentParser
- from torchtune.utils.batch_pad_sequence import batch_pad_to_longest_seq
- from torchtune.utils.env import get_world_size_and_rank, init_from_env, seed
+ from torchtune import datasets, losses, models, modules, optim, utils
from torchtune.utils.generation import generate_from_prompt
- from torchtune.utils.precision import (
-     get_autocast_manager,
-     get_grad_scaler,
-     get_supported_dtypes,
- )
from tqdm import tqdm


- def get_optimizer(model: torch.nn.Module, optimizer: str, lr: float) -> Optimizer:
-     return getattr(torch.optim, optimizer)(model.parameters(), lr=lr)
-
-
- def get_loss(loss_fn: str) -> Callable:
-     return getattr(torch.nn, loss_fn)()
-
-
- def get_logger():
-     import logging
-
-     logger = logging.getLogger(__name__)
-     logger.addHandler(logging.StreamHandler())
-     logger.setLevel(logging.DEBUG)
-     return logger.info
-
-
- def recipe(kwargs):
+ def recipe(
+     device,
+     dtype,
+     seed,
+     model,
+     model_checkpoint,
+     tokenizer,
+     tokenizer_checkpoint,
+     dataset,
+     shuffle,
+     batch_size,
+     fsdp,
+     epochs,
+     optimizer,
+     loss,
+     lr,
+     activation_checkpointing,
+     output_dir,
+     run_generation,
+     max_steps_per_epoch,
+ ):
    # ---- Initialize components ---- #
-     logger = get_logger()
-
-     # ---- Initialize distributed process group ---- #
-     device = init_from_env(device_type=kwargs["device"])
-     # TODO: only supporting devices specified as "cpu", "cuda", or "cuda:n" currently
-     device_type = (
-         kwargs["device"]
-         if kwargs["device"] in ("cpu", "cuda")
-         else kwargs["device"].split(":")[0]
-     )
-
-     # ---- Initialize seed ---- #
-     # Fetch world size and rank after distributed process group initialization
-     world_size, rank = get_world_size_and_rank()
-     if kwargs["seed"] is not None:
-         # Ensure that seed is different per rank (and its dataloader workers)
-         seed(kwargs["seed"] + rank)
-
-     tokenizer = get_tokenizer(kwargs["tokenizer"], path=kwargs["tokenizer_checkpoint"])
-     logger(msg=f"Loaded tokenizer from {kwargs['tokenizer_checkpoint']}")
-
-     autocast_precision = kwargs.get("autocast_precision", None)
-     autocast_mgr = get_autocast_manager(
-         device_type=device_type, precision=autocast_precision
-     )
-     grad_scaler = get_grad_scaler(autocast_precision, fsdp=kwargs["fsdp"])
-
-     model = get_model(
-         kwargs["model"],
-         device,
-     )
-
-     if kwargs["fsdp"] or kwargs["activation_checkpointing"]:
-         auto_wrap_policy = ModuleWrapPolicy(
-             {TransformerDecoderLayer}
-         )  # TODO: remove model specific components
-     if kwargs["fsdp"]:
-         model = FSDP(
-             model,
-             auto_wrap_policy=auto_wrap_policy,
-             device_id=device,
-             param_init_fn=lambda m: m.to_empty(device=device, recurse=False),
+     utils.init_distributed(fsdp)
+
+     # logger = logging.getLogger()
+     # logger.setLevel(logging.DEBUG)  # test
+     logger = utils.get_logger("DEBUG")
+
+     device = utils.get_device(device)
+     dtype = utils.get_dtype(dtype)
+     seed = utils.set_seed(seed)
+
+     # ---- Setup model and load checkpoint ---- #
+     tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint)
+     logger.info(msg=f"Loaded tokenizer from {tokenizer_checkpoint}")
+
+     model = models.get_model(model, device=device)
+     if fsdp:
+         # TODO: initialize models for distributed on meta or cpu device to avoid OOMs
+         model = utils.get_fsdp(
+             model=model,
+             device=device,
+             dtype=dtype,
+             strategy="FULL_SHARD",
+             auto_wrap_policy={modules.TransformerDecoderLayer},
        )
-     if kwargs["activation_checkpointing"]:
-         apply_activation_checkpointing(
-             model,
-             check_fn=lambda mod: isinstance(
-                 mod, TransformerDecoderLayer
-             ),  # TODO: remove model specific components
-             auto_wrap_policy=auto_wrap_policy,
+     if activation_checkpointing:
+         utils.set_activation_checkpointing(
+             model, auto_wrap_policy={modules.TransformerDecoderLayer}
        )

-     loaded_ckpt = torch.load(
-         kwargs["model_checkpoint"], map_location="cpu", weights_only=True
-     )
+     loaded_ckpt = torch.load(model_checkpoint, map_location="cpu", weights_only=True)
    model.load_state_dict(loaded_ckpt)
-     logger(msg=f"Loaded model from {kwargs['model_checkpoint']}")
+     logger.info(msg=f"Loaded model from {model_checkpoint}")

-     opt = get_optimizer(model, kwargs["optimizer"], kwargs["lr"])
+     # ---- Setup optimization functions ---- #
+     opt = optim.get_optimizer(optimizer, model, lr)
    # TODO add lr schedule option
-     loss_fn = get_loss(kwargs["loss"])
+     loss_fn = losses.get_loss(loss)
+
+     autocast = utils.get_autocast(dtype, device)
+     if dtype == torch.float16:
+         grad_scaler = utils.get_gradient_scaler(fsdp=fsdp)
+     else:
+         grad_scaler = GradScaler(enabled=False)

    # ---- Load dataset, set up sampler, and dataloader ---- #
-     dataset = get_dataset(kwargs["dataset"], split="train", tokenizer=tokenizer)
+     world_size, rank = utils.get_world_size_and_rank()
+     ds = datasets.get_dataset(dataset, split="train", tokenizer=tokenizer)
    sampler = DistributedSampler(
-         dataset,
+         ds,
        num_replicas=world_size,
        rank=rank,
-         shuffle=kwargs["shuffle"],
+         shuffle=shuffle,
        seed=0,
    )
    dataloader = DataLoader(
-         dataset=dataset,
-         batch_size=kwargs["batch_size"],
+         dataset=ds,
+         batch_size=batch_size,
        sampler=sampler,
        collate_fn=partial(
-             batch_pad_to_longest_seq,
-             input_padding_idx=tokenizer.pad_id,
-             label_padding_idx=loss_fn.ignore_index,  # TODO support loss without ignore_index
+             utils.padded_collate,
+             padding_idx=tokenizer.pad_id,
+             ignore_idx=loss_fn.ignore_index,  # TODO support loss without ignore_index
        ),
    )
-     logger(msg=f"Loaded dataset {kwargs['dataset']}")
+     logger.info(msg=f"Loaded dataset {dataset}")

    # ---- Train loop ---- #
-     for epoch in range(kwargs["epochs"]):
-         # Need to set the epoch for changing sample ordering in each epoch
-         sampler.set_epoch(epoch)
+     for epoch in range(epochs):
+         sampler.set_epoch(epoch)  # distributed sampler requires set_epoch
        for idx, batch in enumerate(pbar := tqdm(dataloader)):
-             max_steps_per_epoch = kwargs.get("max_steps_per_epoch", None)
            if max_steps_per_epoch is not None and idx == max_steps_per_epoch:
                break
            opt.zero_grad()
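Earlier in this hunk the dataloader's collate hook moves from `batch_pad_to_longest_seq` to a `utils.padded_collate` partial taking `padding_idx` and `ignore_idx`. A minimal sketch of what a collate of that shape does is below; it is an illustrative stand-in, not the torchtune implementation, and it assumes each sample is an `(input_ids, labels)` pair of token lists.

import torch
from torch.nn.utils.rnn import pad_sequence

def padded_collate_sketch(batch, padding_idx=0, ignore_idx=-100):
    # Inputs are padded with the tokenizer pad id, labels with the loss
    # ignore index, so padded positions contribute nothing to the loss.
    input_ids = pad_sequence(
        [torch.tensor(ids) for ids, _ in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(lbl) for _, lbl in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    return input_ids, labels

# e.g. padded_collate_sketch([([1, 2, 3], [2, 3, 4]), ([5], [6])])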
@@ -149,10 +117,7 @@ def recipe(kwargs):
            input_ids = input_ids.to(device)
            labels = labels.to(device)

-             # Note: context manager for autocast is only applied in forward pass.
-             # see https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-torch-autocast
-             # for more details.
-             with autocast_mgr:
+             with autocast:
                logits = model(input_ids)
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
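The shift comment above is the standard causal-LM setup: the logits at position i are scored against the token at position i+1, with padded label positions skipped via the loss's ignore index (-100 is the `nn.CrossEntropyLoss` default). A small self-contained sketch of that pairing, not taken from the recipe itself:

import torch
import torch.nn.functional as F

def next_token_loss(logits, labels, ignore_idx=-100):
    # logits: [batch, seq, vocab]; labels: [batch, seq]
    shift_logits = logits[..., :-1, :].contiguous()  # positions 0..n-2
    shift_labels = labels[..., 1:].contiguous()      # predict tokens 1..n-1
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_idx,
    )

# Example: a random batch of two length-8 sequences over a 32-token vocabulary.
print(next_token_loss(torch.randn(2, 8, 32), torch.randint(0, 32, (2, 8))))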
@@ -168,15 +133,11 @@ def recipe(kwargs):
                f"{epoch + 1}|{idx + 1}|Loss: {loss.item()}"
            )  # TODO: add terminal logger

-             if grad_scaler:
-                 grad_scaler.scale(loss).backward()
-                 grad_scaler.step(opt)
-                 grad_scaler.update()
-             else:
-                 loss.backward()
-                 opt.step()
+             grad_scaler.scale(loss).backward()
+             grad_scaler.step(opt)
+             grad_scaler.update()

-             run_generation = kwargs.get("run_generation", None)
+             # --- TODO TEMPORARY EVAL Code ---- #
            if run_generation and idx % run_generation == 0:
                # Log a sample generation for the instruction.
                # Just using a hardcoded prompt for now
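The hunk above drops the if/else around the backward pass: the new code always routes the update through a GradScaler, relying on `GradScaler(enabled=False)` to act as a no-op wrapper when the run is not float16. A standalone sketch of that pattern, with a toy model and optimizer standing in for the recipe's:

import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
use_fp16 = False  # flip on for CUDA float16 runs
scaler = GradScaler(enabled=use_fp16)

loss = model(torch.randn(16, 8)).pow(2).mean()
# With enabled=False, scale() returns the loss unchanged, step() simply calls
# opt.step(), and update() does nothing, so fp32 and fp16 share one code path.
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()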
@@ -189,16 +150,14 @@ def recipe(kwargs):
                generation_str, decoded_tokens = generate_from_prompt(
                    prompt=prompt, tokenizer=tokenizer, decoder=model
                )
-                 if (
-                     not torch.distributed.is_initialized()
-                     or torch.distributed.get_rank() == 0
-                 ):
-                     logger(f"Generation tokens: {decoded_tokens}")
-                     logger(f"Generation: {generation_str}")
-
-         # Save checkpoint at end of each epoch (to be changed later)
-         os.makedirs(kwargs["output_dir"], exist_ok=True)
-         output_loc = f"{kwargs['output_dir']}/model_{epoch}.ckpt"
+                 if rank == 0:
+                     logger.info(f"Generation tokens: {decoded_tokens}")
+                     logger.info(f"Generation: {generation_str}")
+             # --- TODO TEMPORARY EVAL Code Ends ---- #
+
+         # ---- Save checkpoint at end of each epoch (to be changed later) ---- #
+         os.makedirs(output_dir, exist_ok=True)
+         output_loc = f"{output_dir}/model_{epoch}.ckpt"
        torch.save(
            {
                "epoch": epoch,
@@ -208,19 +167,19 @@ def recipe(kwargs):
            },
            output_loc,
        )
-         logger(
+         logger.info(
            msg=f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
        )


if __name__ == "__main__":
-     parser = TuneArgumentParser(description="Fine-tune an LLM.")
+     parser = utils.TuneArgumentParser(description="Fine-tune an LLM.")

    # Dataset and DataLoader arguments
    parser.add_argument(
        "--dataset",
        type=str,
-         choices=list_datasets(),
+         choices=datasets.list_datasets(),
        help="Dataset name.",
    )
    parser.add_argument(
@@ -238,7 +197,7 @@ def recipe(kwargs):
    parser.add_argument(
        "--model",
        type=str,
-         choices=list_models(),
+         choices=models.list_models(),
        help="Model to finetune.",
    )
    parser.add_argument(
@@ -249,7 +208,7 @@ def recipe(kwargs):
    parser.add_argument(
        "--tokenizer",
        type=str,
-         choices=list_tokenizers(),
+         choices=models.list_tokenizers(),
        help="Model tokenizer.",
    )
    parser.add_argument(
@@ -318,14 +277,12 @@ def recipe(kwargs):
        help="Max number of steps per epoch for faster dev/testing. Default is to finetune through the full dataset.",
    )
    parser.add_argument(
-         "--autocast-precision",
+         "--dtype",
        type=str,
-         choices=get_supported_dtypes(),
+         choices=utils.list_dtypes(),
        default=None,
-         help=f"""Low precision used for CUDA automatic mixed precision.
-             If specified, must be one of {get_supported_dtypes()}.
-             """,
+         help="Tensor dtype used for finetuning, lower precision types result in mixed precision training.",
    )

    kwargs = vars(parser.parse_args())
-     sys.exit(recipe(kwargs))
+     recipe(**kwargs)
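The entry point now unpacks the parsed namespace straight into recipe's keyword parameters instead of passing a raw dict. A stripped-down sketch of that dispatch with a plain argparse parser; the flag names and defaults here are placeholders, and TuneArgumentParser's config handling is not reproduced:

import argparse

def recipe(model, lr, epochs):
    print(f"would finetune {model} for {epochs} epochs at lr={lr}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune an LLM.")
    parser.add_argument("--model", type=str, default="toy-model")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--epochs", type=int, default=3)
    # vars() turns the Namespace into a dict keyed by argument name, so
    # **kwargs maps each flag onto the matching keyword parameter of recipe().
    kwargs = vars(parser.parse_args())
    recipe(**kwargs)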