@@ -7,7 +7,7 @@
 
 import random
 import warnings
-from typing import Optional, Any
+from typing import Any, Dict, Optional, Union
 
 import fire
 import numpy as np
@@ -17,19 +17,15 @@
 import torch.optim as optim
 import torch.utils.data
 from peft import PeftModel, get_peft_model
-from dataclasses import fields
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 
-from QEfficient.finetune.configs.peft_config import LoraConfig
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
     get_dataloader_kwargs,
-    load_config_file,
     update_config,
-    validate_config,
 )
 from QEfficient.finetune.utils.dataset_utils import (
     get_custom_data_collator,
@@ -45,7 +41,7 @@
     print(f"Warning: {e}. Proceeding without QAIC modules.")
 
 
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -91,56 +87,103 @@ def setup_seeds(seed: int) -> None:
     np.random.seed(seed)
 
 
-def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+def load_model_and_tokenizer(
+    train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
     """Load the pre-trained model and tokenizer from Hugging Face.
 
     Args:
         config (TrainConfig): Training configuration object containing model and tokenizer names.
+        dataset_config (Any): A dataclass object representing dataset configuration.
+        peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
+        kwargs: Additional arguments to override PEFT config.
 
     Returns:
-        tuple: A tuple containing the loaded model (AutoModelForCausalLM) and tokenizer (AutoTokenizer).
+        tuple: A tuple of two values.
+            - Model with pretrained weights loaded.
+            - Model's tokenizer (AutoTokenizer).
 
     Notes:
         - Downloads the model if not already cached using login_and_download_hf_lm.
         - Configures the model with FP16 precision and disables caching for training.
         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
        - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
     """
-    pretrained_model_path = login_and_download_hf_lm(config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
 
     tokenizer = AutoTokenizer.from_pretrained(
-        config.model_name if config.tokenizer_name is None else config.tokenizer_name
+        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
     )
     if not tokenizer.pad_token_id:
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
+    # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))
 
+    # FIXME (Meet): Cover below line inside the logger once it is implemented.
+    print_model_size(model, train_config)
+
     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
     # apply gradient checkpointing related hooks to the input embeddings. Without this we will get
     # "No inf checks were recorded for this optimizer." error.
     # Enable gradient checkpointing
-    if config.gradient_checkpointing:
+    if train_config.gradient_checkpointing:
         # Note: below attribute and method is only available in HuggingFace Transformer models.
         if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
         else:
             raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
-
+
+    model = apply_peft(model, train_config, peft_config_file, **kwargs)
+
     return model, tokenizer
 
 
-def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_config: LoraConfig) -> PeftModel:
-    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled."""
+def apply_peft(
+    model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
+) -> Union[AutoModel, PeftModel]:
+    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
+
+    Args:
+        model (AutoModel): Huggingface model.
+        train_config (TrainConfig): Training configuration object.
+        peft_config_file (str, optional): Path to YAML/JSON file containing
+            PEFT (LoRA) config. Defaults to None.
+        kwargs: Additional arguments to override PEFT config params.
+
+    Returns:
+        Union[AutoModel, PeftModel]: If the use_peft in train_config is True
+            then PeftModel object is returned else original model object
+            (AutoModel) is returned.
+    """
     if not train_config.use_peft:
         return model
 
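For reference, the freeze-and-upcast pattern introduced in the new `seq_classification` branch above can be exercised on its own. The sketch below assumes an arbitrary public checkpoint (`distilbert-base-uncased`) and a binary label set; it is an illustration of the pattern, not code from this change:

```python
import torch
from transformers import AutoModelForSequenceClassification

# Load a classification head on top of a frozen backbone, in fp16 as in the patch.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",  # illustrative checkpoint, not tied to this repo
    num_labels=2,
    torch_dtype=torch.float16,
)

# Freeze the backbone; base_model_prefix names the backbone attribute
# ("distilbert", "model", "transformer", ... depending on the architecture).
for param in getattr(model, model.base_model_prefix).parameters():
    param.requires_grad = False

# Keep the remaining trainable parameters (the classification head) in fp32
# so optimizer updates stay numerically stable.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable}")
```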
@@ -150,27 +193,31 @@ def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_conf
         peft_config = model.peft_config
     # Generate the peft config and start fine-tuning from original model
     else:
-        peft_config = generate_peft_config(train_config, lora_config)
+        peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
 
     return model
 
 
 def setup_dataloaders(
-    train_config: TrainConfig, dataset_config, tokenizer: AutoTokenizer, dataset_train, dataset_val
-) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
+    train_config: TrainConfig,
+    dataset_config: Any,
+    tokenizer: AutoTokenizer,
+) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader], int]:
     """Set up training and validation DataLoaders.
 
     Args:
         train_config (TrainConfig): Training configuration object.
-        dataset_config: Configuration for the dataset (generated from train_config).
+        dataset_config (Any): Configuration for the dataset (generated from train_config).
         tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
-        dataset_train: Preprocessed training dataset.
-        dataset_val: Preprocessed validation dataset.
 
     Returns:
-        tuple: A tuple of (train_dataloader, eval_dataloader), where eval_dataloader is None if validation is disabled.
+        tuple: A tuple of three values.
+            - First value represents train_dataloader
+            - Second value represents eval_dataloader. It is None if
+              validation is disabled.
+            - Length of longest sequence in the dataset.
 
     Raises:
         ValueError: If validation is enabled but the validation set is too small.
@@ -179,11 +226,33 @@ def setup_dataloaders(
         - Applies a custom data collator if provided by get_custom_data_collator.
         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """
-    custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    # Get the dataset utils
+    dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
+    dataset_train = get_preprocessed_dataset(
+        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
+    )
+
+    dataset_val = get_preprocessed_dataset(
+        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
+    )
+
+    # TODO: vbaddi, check if its necessary to do this?
+    # dataset_train = ConcatDataset(
+    #     dataset_train, chunk_size=train_config.context_length
+    # )
+    ##
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+
+    # FIXME (Meet): Add custom data collator registration from the outside by the user.
+    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if custom_data_collator:
+        print("custom_data_collator is used")
         train_dl_kwargs["collate_fn"] = custom_data_collator
 
+    # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
         num_workers=train_config.num_workers_dataloader,
@@ -194,7 +263,12 @@ def setup_dataloaders(
 
     eval_dataloader = None
     if train_config.run_validation:
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        # if train_config.batching_strategy == "packing":
+        #     dataset_val = ConcatDataset(
+        #         dataset_val, chunk_size=train_config.context_length
+        #     )
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
         if custom_data_collator:
             val_dl_kwargs["collate_fn"] = custom_data_collator
 
@@ -204,31 +278,29 @@ def setup_dataloaders(
             pin_memory=True,
             **val_dl_kwargs,
         )
-        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
-            raise ValueError("Eval set too small to load even one batch.")
+            raise ValueError(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            )
+        else:
+            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
-    return train_dataloader, eval_dataloader
+        longest_seq_length, _ = get_longest_seq_length(
+            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
+        )
+    else:
+        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
+
+    return train_dataloader, eval_dataloader, longest_seq_length
 
 
-def main(
-    model_name: str = None,
-    tokenizer_name: str = None,
-    batch_size_training: int = None,
-    lr: float = None,
-    peft_config_file: str = None,
-    **kwargs,
-) -> None:
+def main(peft_config_file=None, **kwargs) -> None:
     """
     Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
 
     Args:
-        model_name (str, optional): Override default model name.
-        tokenizer_name (str, optional): Override default tokenizer name.
-        batch_size_training (int, optional): Override default training batch size.
-        lr (float, optional): Override default learning rate.
-        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config.
-        **kwargs: Additional arguments to override TrainConfig.
+        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
+        kwargs: Additional arguments to override TrainConfig.
 
     Example:
         .. code-block:: bash
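The hunk above also moves the longest-sequence computation into setup_dataloaders, taken over a ConcatDataset of the train and eval data when validation runs. A toy illustration of that pattern follows; `get_longest_seq_length` here is a stand-in with an assumed contract (returns the max length and its index), not the QEfficient utility:

```python
import torch.utils.data as data

class ToyDataset(data.Dataset):
    """Tiny dataset of variable-length token id lists (illustrative only)."""
    def __init__(self, lengths):
        self.samples = [{"input_ids": list(range(n))} for n in lengths]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def get_longest_seq_length(dataset):
    # Stand-in: scan every sample and report the longest input_ids length and its index.
    lengths = [len(dataset[i]["input_ids"]) for i in range(len(dataset))]
    longest = max(lengths)
    return longest, lengths.index(longest)

train_ds, val_ds = ToyDataset([8, 16, 4]), ToyDataset([32, 2])
longest_seq_length, _ = get_longest_seq_length(data.ConcatDataset([train_ds, val_ds]))
print(longest_seq_length)  # 32
```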
@@ -245,64 +317,36 @@ def main(
                 --lr 5e-4
     """
     train_config = TrainConfig()
-    # local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
     update_config(train_config, **kwargs)
-
-    lora_config = LoraConfig()
-    if peft_config_file:
-        peft_config_data = load_config_file(peft_config_file)
-        validate_config(peft_config_data, config_type="lora")
-        lora_config = LoraConfig(**peft_config_data)
-    else:
-        lora_config = LoraConfig()
-
-    update_config(lora_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config.dataset)
+    update_config(dataset_config, **kwargs)
 
     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)
-    model, tokenizer = load_model_and_tokenizer(train_config)
-    print_model_size(model, train_config)
-    model = apply_peft(model, train_config, lora_config)
+    model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
 
-    # Pass an empty dict instead of kwargs to avoid irrelevant parameters
-    dataset_config = generate_dataset_config(train_config, kwargs)
-    dataset_train = get_preprocessed_dataset(
-        tokenizer, dataset_config, split="train", context_length=train_config.context_length
-    )
-    dataset_val = get_preprocessed_dataset(
-        tokenizer, dataset_config, split="test", context_length=train_config.context_length
-    )
-    train_dataloader, eval_dataloader = setup_dataloaders(
-        train_config, dataset_config, tokenizer, dataset_train, dataset_val
-    )
-    dataset_for_seq_length = (
-        torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
-        if train_config.run_validation
-        else train_dataloader.dataset
-    )
-    longest_seq_length, _ = get_longest_seq_length(dataset_for_seq_length)
+    # Create DataLoaders for the training and validation dataset
+    train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
     print(
-        f"Longest sequence length: {longest_seq_length}, "
-        f"Context length: {train_config.context_length}, "
-        f"Model max context: {model.config.max_position_embeddings}"
+        f"The longest sequence length in the train data is {longest_seq_length}, "
+        f"passed context length is {train_config.context_length} and overall model's context length is "
+        f"{model.config.max_position_embeddings}"
     )
+
     model.to(train_config.device)
     optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
     results = train(
         model,
+        tokenizer,
         train_dataloader,
         eval_dataloader,
-        tokenizer,
         optimizer,
         scheduler,
-        train_config.gradient_accumulation_steps,
         train_config,
-        train_config.device,
         dist.get_rank() if train_config.enable_ddp else None,
-        None,
     )
     if train_config.enable_ddp:
         dist.destroy_process_group()
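With the slimmed-down main signature, every TrainConfig field is passed as a keyword override and fire maps CLI flags onto it. Beyond the bash example in the docstring, the entry point could also be driven directly from Python; the sketch below is hypothetical — the override names (model_name, lr) mirror fields used elsewhere in this file, and the YAML path is purely illustrative:

```python
# Hypothetical direct call into the refactored entry point; normally this runs
# via `python -m QEfficient.cloud.finetune ...` with fire parsing the flags.
from QEfficient.cloud.finetune import main

main(
    peft_config_file="lora_config.yaml",   # illustrative path to a PEFT (LoRA) YAML/JSON config
    model_name="meta-llama/Llama-3.2-1B",  # overrides TrainConfig.model_name
    lr=5e-4,                               # overrides TrainConfig.lr
)
```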