import random
import warnings
+from typing import Optional

import fire
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
+from transformers import AutoModelForCausalLM, AutoTokenizer

-from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.finetune.configs.peft_config import LoraConfig
+from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
+    load_config_file,
    update_config,
+    validate_config,
)
from QEfficient.finetune.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm

+# Try importing QAIC-specific module, proceed without it if unavailable
try:
    import torch_qaic  # noqa: F401
except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    print(f"Warning: {e}. Proceeding without QAIC modules.")

+# Suppress all warnings for cleaner output
+warnings.filterwarnings("ignore")

-from transformers import AutoModelForCausalLM, AutoTokenizer

-# Suppress all warnings
-warnings.filterwarnings("ignore")
+def setup_distributed_training(config: TrainConfig) -> None:
+    """Initialize distributed training environment if enabled.

+    Args:
+        config (TrainConfig): Training configuration object.

-def main(**kwargs):
+    Notes:
+        - If distributed data parallel (DDP) is disabled, this function does nothing.
+        - Ensures the device is not CPU and does not specify an index for DDP compatibility.
+        - Initializes the process group using the specified distributed backend.
+
+    Raises:
+        AssertionError: If device is CPU or includes an index with DDP enabled.
    """
-    Helper function to finetune the model on QAic.
+    if not config.enable_ddp:
+        return
+
+    torch_device = torch.device(config.device)
+    assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+    assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+
+    dist.init_process_group(backend=config.dist_backend)
+    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+    getattr(torch, torch_device.type).set_device(dist.get_rank())
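+    # Note: config.device is expected as a bare device type ("qaic", "cuda"); an indexed
+    # form such as "qaic:0" fails the assertion above. Each process then binds to the
+    # device matching its rank, so a multi-process launcher (torchrun is assumed) should
+    # start one process per device.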

-    .. code-block:: bash

-        python -m QEfficient.cloud.finetune OPTIONS
+def setup_seeds(seed: int) -> None:
+    """Set random seeds across libraries for reproducibility.

+    Args:
+        seed (int): Seed value to set for random number generators.
+
+    Notes:
+        - Sets seeds for PyTorch, Python's random module, and NumPy.
    """
-    # update the configuration for the training process
-    train_config = TRAIN_CONFIG()
-    update_config(train_config, **kwargs)
-    device = train_config.device
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
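+    # Note: this seeds PyTorch, Python's random module, and NumPy; fully deterministic
+    # runs may additionally require backend/dataloader determinism settings not set here.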

-    # dist init
-    if train_config.enable_ddp:
-        # TODO: may have to init qccl backend, next try run with torchrun command
-        torch_device = torch.device(device)
-        assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-        assert torch_device.index is None, (
-            f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-        )
-        dist.init_process_group(backend=train_config.dist_backend)
-        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-        getattr(torch, torch_device.type).set_device(dist.get_rank())
-
-    # Set the seeds for reproducibility
-    torch.manual_seed(train_config.seed)
-    random.seed(train_config.seed)
-    np.random.seed(train_config.seed)
-
-    # Load the pre-trained model and setup its configuration
-    # config = AutoConfig.from_pretrained(train_config.model_name)
-    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+
+def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+    """Load the pre-trained model and tokenizer from Hugging Face.
+
+    Args:
+        config (TrainConfig): Training configuration object containing model and tokenizer names.
+
+    Returns:
+        tuple: A tuple containing the loaded model (AutoModelForCausalLM) and tokenizer (AutoTokenizer).
+
+    Notes:
+        - Downloads the model if not already cached using login_and_download_hf_lm.
+        - Configures the model with FP16 precision and disables caching for training.
+        - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
+        - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
+    """
+    pretrained_model_path = login_and_download_hf_lm(config.model_name)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_path,
        use_cache=False,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
    )

-    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
-        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
+        config.model_name if config.tokenizer_name is None else config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

-    # If there is a mismatch between tokenizer vocab size and embedding matrix,
-    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
        model.resize_token_embeddings(len(tokenizer))

-    print_model_size(model, train_config)
+    return model, tokenizer

-    # print the datatype of the model parameters
-    # print(get_parameter_dtypes(model))
-
-    if train_config.use_peft:
-        # Load the pre-trained peft model checkpoint and setup its configuration
-        if train_config.from_peft_checkpoint:
-            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
-            peft_config = model.peft_config
-        # Generate the peft config and start fine-tuning from original model
-        else:
-            peft_config = generate_peft_config(train_config, kwargs)
-            model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
-
-    # Get the dataset utils
-    dataset_config = generate_dataset_config(train_config, kwargs)
-    dataset_processer = tokenizer

-    # Load and preprocess the dataset for training and validation
-    dataset_train = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
-    )
+def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_config: LoraConfig) -> PeftModel:
+    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled."""
+    if not train_config.use_peft:
+        return model

-    dataset_val = get_preprocessed_dataset(
-        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
-    )
+    if train_config.from_peft_checkpoint:
+        return PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+
+    # Generate PEFT-compatible config from custom LoraConfig
+    peft_config = generate_peft_config(train_config, lora_config)
+    model = get_peft_model(model, peft_config)
+    model.print_trainable_parameters()
+    return model
+
+
+def setup_dataloaders(
+    train_config: TrainConfig, dataset_config, tokenizer: AutoTokenizer, dataset_train, dataset_val
+) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
+    """Set up training and validation DataLoaders.
+
+    Args:
+        train_config (TrainConfig): Training configuration object.
+        dataset_config: Configuration for the dataset (generated from train_config).
+        tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
+        dataset_train: Preprocessed training dataset.
+        dataset_val: Preprocessed validation dataset.

-    # TODO: vbaddi, check if its necessary to do this?
-    # dataset_train = ConcatDataset(
-    #     dataset_train, chunk_size=train_config.context_length
-    # )
-    ##
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
-    print("length of dataset_train", len(dataset_train))
-    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
+    Returns:
+        tuple: A tuple of (train_dataloader, eval_dataloader), where eval_dataloader is None if validation is disabled.
+
+    Raises:
+        ValueError: If validation is enabled but the validation set is too small.
+
+    Notes:
+        - Applies a custom data collator if provided by get_custom_data_collator.
+        - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
+    """
+    custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
    if custom_data_collator:
-        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

-    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
@@ -150,12 +175,7 @@ def main(**kwargs):

    eval_dataloader = None
    if train_config.run_validation:
-        # if train_config.batching_strategy == "packing":
-        #     dataset_val = ConcatDataset(
-        #         dataset_val, chunk_size=train_config.context_length
-        #     )
-
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

@@ -165,37 +185,90 @@ def main(**kwargs):
            pin_memory=True,
            **val_dl_kwargs,
        )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
-            raise ValueError(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
-            )
-        else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
-
-        longest_seq_length, _ = get_longest_seq_length(
-            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
-        )
-    else:
-        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
+            raise ValueError("Eval set too small to load even one batch.")
+
+    return train_dataloader, eval_dataloader
+

+def main(
+    model_name: str = None,
+    tokenizer_name: str = None,
+    batch_size_training: int = None,
+    lr: float = None,
+    peft_config_file: str = None,
+    **kwargs,
+) -> None:
+    """
+    Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
+
+    Args:
+        model_name (str, optional): Override default model name.
+        tokenizer_name (str, optional): Override default tokenizer name.
+        batch_size_training (int, optional): Override default training batch size.
+        lr (float, optional): Override default learning rate.
+        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config.
+        **kwargs: Additional arguments to override TrainConfig.
+
+    Example:
+        .. code-block:: bash
+
+            # Using a YAML config file for PEFT
+            python -m QEfficient.cloud.finetune \\
+                --model_name "meta-llama/Llama-3.2-1B" \\
+                --lr 5e-4 \\
+                --peft_config_file "lora_config.yaml"
+
+            # Using default LoRA config
+            python -m QEfficient.cloud.finetune \\
+                --model_name "meta-llama/Llama-3.2-1B" \\
+                --lr 5e-4
+    """
+    train_config = TrainConfig()
+    # local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
+    update_config(train_config, **kwargs)
+
+    lora_config = LoraConfig()
+    if peft_config_file:
+        peft_config_data = load_config_file(peft_config_file)
+        validate_config(peft_config_data, config_type="lora")
+        lora_config = LoraConfig(**peft_config_data)
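+        # Illustrative sketch of the expected file content; the field names below are
+        # assumptions based on common LoRA hyperparameters, and the authoritative schema
+        # is whatever LoraConfig and validate_config accept. For example, a YAML mapping:
+        #   r: 8
+        #   lora_alpha: 32
+        #   lora_dropout: 0.05
+        #   target_modules: ["q_proj", "v_proj"]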
+
+    setup_distributed_training(train_config)
+    setup_seeds(train_config.seed)
+    model, tokenizer = load_model_and_tokenizer(train_config)
+    print_model_size(model, train_config)
+    model = apply_peft(model, train_config, lora_config)
+
+    # Build the dataset config from train_config (remaining CLI kwargs are passed through to generate_dataset_config)
+    dataset_config = generate_dataset_config(train_config, kwargs)
+    dataset_train = get_preprocessed_dataset(
+        tokenizer, dataset_config, split="train", context_length=train_config.context_length
+    )
+    dataset_val = get_preprocessed_dataset(
+        tokenizer, dataset_config, split="test", context_length=train_config.context_length
+    )
+    train_dataloader, eval_dataloader = setup_dataloaders(
+        train_config, dataset_config, tokenizer, dataset_train, dataset_val
+    )
+    dataset_for_seq_length = (
+        torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
+        if train_config.run_validation
+        else train_dataloader.dataset
+    )
+    longest_seq_length, _ = get_longest_seq_length(dataset_for_seq_length)
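+    # The longest sample is reported so it can be compared against the configured
+    # context_length and the model's max_position_embeddings (printed below).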
    print(
-        f"The longest sequence length in the train data is {longest_seq_length}, "
-        f"passed context length is {train_config.context_length} and overall model's context length is "
-        f"{model.config.max_position_embeddings}"
+        f"Longest sequence length: {longest_seq_length}, "
+        f"Context length: {train_config.context_length}, "
+        f"Model max context: {model.config.max_position_embeddings}"
    )
    model.to(train_config.device)
-    optimizer = optim.AdamW(
-        model.parameters(),
-        lr=train_config.lr,
-        weight_decay=train_config.weight_decay,
-    )
+    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
-    # wrap model with DDP
    if train_config.enable_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-
-    _ = train(
+    train(
        model,
        train_dataloader,
        eval_dataloader,
@@ -208,8 +281,6 @@ def main(**kwargs):
        dist.get_rank() if train_config.enable_ddp else None,
        None,
    )
-
-    # finalize torch distributed
    if train_config.enable_ddp:
        dist.destroy_process_group()
