
Commit ada9de8

Fixed a few comments. Fixed some rebase-related errors and restructured the code.

Signed-off-by: Meet Patel <[email protected]>

1 parent 8bafc20 commit ada9de8


5 files changed: 175 additions & 126 deletions


QEfficient/cloud/finetune.py

Lines changed: 130 additions & 86 deletions
@@ -7,7 +7,7 @@
 
 import random
 import warnings
-from typing import Optional, Any
+from typing import Any, Dict, Optional, Union
 
 import fire
 import numpy as np
@@ -17,19 +17,15 @@
 import torch.optim as optim
 import torch.utils.data
 from peft import PeftModel, get_peft_model
-from dataclasses import fields
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 
-from QEfficient.finetune.configs.peft_config import LoraConfig
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
     get_dataloader_kwargs,
-    load_config_file,
     update_config,
-    validate_config,
 )
 from QEfficient.finetune.utils.dataset_utils import (
     get_custom_data_collator,
@@ -45,7 +41,7 @@
     print(f"Warning: {e}. Proceeding without QAIC modules.")
 
 
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -91,56 +87,103 @@ def setup_seeds(seed: int) -> None:
     np.random.seed(seed)
 
 
-def load_model_and_tokenizer(config: TrainConfig) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+def load_model_and_tokenizer(
+    train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
     """Load the pre-trained model and tokenizer from Hugging Face.
 
     Args:
         config (TrainConfig): Training configuration object containing model and tokenizer names.
+        dataset_config (Any): A dataclass object representing dataset configuration.
+        peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
+        kwargs: Additional arguments to override PEFT config.
 
     Returns:
-        tuple: A tuple containing the loaded model (AutoModelForCausalLM) and tokenizer (AutoTokenizer).
+        tuple: A tuple of two values.
+            - Model with pretrained weights loaded.
+            - Model's tokenizer (AutoTokenizer).
 
     Notes:
         - Downloads the model if not already cached using login_and_download_hf_lm.
         - Configures the model with FP16 precision and disables caching for training.
         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
         - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
     """
-    pretrained_model_path = login_and_download_hf_lm(config.model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_path,
-        use_cache=False,
-        attn_implementation="sdpa",
-        torch_dtype=torch.float16,
-    )
+    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+    if train_config.task_type == "seq_classification":
+        model = AutoModelForSequenceClassification.from_pretrained(
+            pretrained_model_path,
+            num_labels=dataset_config.num_labels,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
+
+        if not hasattr(model, "base_model_prefix"):
+            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+
+        for param in getattr(model, model.base_model_prefix).parameters():
+            param.requires_grad = False
+
+        for param in model.parameters():
+            if param.requires_grad:
+                param.data = param.data.to(torch.float32)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_path,
+            use_cache=False,
+            attn_implementation="sdpa",
+            torch_dtype=torch.float16,
+        )
 
     tokenizer = AutoTokenizer.from_pretrained(
-        config.model_name if config.tokenizer_name is None else config.tokenizer_name
+        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
     )
     if not tokenizer.pad_token_id:
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
+    # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))
 
+    # FIXME (Meet): Cover below line inside the logger once it is implemented.
+    print_model_size(model, train_config)
+
     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
     # apply gradient checkpointing related hooks to the input embeddings. Without this we will get
     # "No inf checks were recorded for this optimizer." error.
     # Enable gradient checkpointing
-    if config.gradient_checkpointing:
+    if train_config.gradient_checkpointing:
         # Note: below attribute and method is only available in HuggingFace Transformer models.
         if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
         else:
             raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
-
+
+    model = apply_peft(model, train_config, peft_config_file, **kwargs)
+
     return model, tokenizer
 
 
-def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_config: LoraConfig) -> PeftModel:
-    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled."""
+def apply_peft(
+    model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
+) -> Union[AutoModel, PeftModel]:
+    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
+
+    Args:
+        model (AutoModel): Huggingface model.
+        train_config (TrainConfig): Training configuration object.
+        peft_config_file (str, optional): Path to YAML/JSON file containing
+            PEFT (LoRA) config. Defaults to None.
+        kwargs: Additional arguments to override PEFT config params.
+
+    Returns:
+        Union[AutoModel, PeftModel]: If the use_peft in train_config is True
+            then PeftModel object is returned else original model object
+            (AutoModel) is returned.
+    """
     if not train_config.use_peft:
         return model
 
@@ -150,27 +193,31 @@ def apply_peft(model: AutoModelForCausalLM, train_config: TrainConfig, lora_conf
         peft_config = model.peft_config
     # Generate the peft config and start fine-tuning from original model
     else:
-        peft_config = generate_peft_config(train_config, lora_config)
+        peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
 
     return model
 
 
 def setup_dataloaders(
-    train_config: TrainConfig, dataset_config, tokenizer: AutoTokenizer, dataset_train, dataset_val
-) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader]]:
+    train_config: TrainConfig,
+    dataset_config: Any,
+    tokenizer: AutoTokenizer,
+) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader], int]:
     """Set up training and validation DataLoaders.
 
     Args:
         train_config (TrainConfig): Training configuration object.
-        dataset_config: Configuration for the dataset (generated from train_config).
+        dataset_config (Any): Configuration for the dataset (generated from train_config).
         tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
-        dataset_train: Preprocessed training dataset.
-        dataset_val: Preprocessed validation dataset.
 
     Returns:
-        tuple: A tuple of (train_dataloader, eval_dataloader), where eval_dataloader is None if validation is disabled.
+        tuple: A tuple of three values.
+            - First value represents train_dataloader
+            - Second value represents eval_dataloader. It is None if
+              validation is disabled.
+            - Length of longest sequence in the dataset.
 
     Raises:
         ValueError: If validation is enabled but the validation set is too small.
@@ -179,11 +226,33 @@ def setup_dataloaders(
         - Applies a custom data collator if provided by get_custom_data_collator.
         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """
-    custom_data_collator = get_custom_data_collator(tokenizer, dataset_config)
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    # Get the dataset utils
+    dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
+    dataset_train = get_preprocessed_dataset(
+        dataset_processer, dataset_config, split="train", context_length=train_config.context_length
+    )
+
+    dataset_val = get_preprocessed_dataset(
+        dataset_processer, dataset_config, split="test", context_length=train_config.context_length
+    )
+
+    # TODO: vbaddi, check if its necessary to do this?
+    # dataset_train = ConcatDataset(
+    #     dataset_train, chunk_size=train_config.context_length
+    # )
+    ##
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+
+    # FIXME (Meet): Add custom data collator registration from the outside by the user.
+    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if custom_data_collator:
+        print("custom_data_collator is used")
         train_dl_kwargs["collate_fn"] = custom_data_collator
 
+    # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
         num_workers=train_config.num_workers_dataloader,
@@ -194,7 +263,12 @@ def setup_dataloaders(
 
     eval_dataloader = None
     if train_config.run_validation:
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        # if train_config.batching_strategy == "packing":
+        #     dataset_val = ConcatDataset(
+        #         dataset_val, chunk_size=train_config.context_length
+        #     )
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
         if custom_data_collator:
             val_dl_kwargs["collate_fn"] = custom_data_collator
 
@@ -204,31 +278,29 @@ def setup_dataloaders(
             pin_memory=True,
             **val_dl_kwargs,
         )
-        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
-            raise ValueError("Eval set too small to load even one batch.")
+            raise ValueError(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            )
+        else:
+            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
-    return train_dataloader, eval_dataloader
+        longest_seq_length, _ = get_longest_seq_length(
+            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
+        )
+    else:
+        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
+
+    return train_dataloader, eval_dataloader, longest_seq_length
 
 
-def main(
-    model_name: str = None,
-    tokenizer_name: str = None,
-    batch_size_training: int = None,
-    lr: float = None,
-    peft_config_file: str = None,
-    **kwargs,
-) -> None:
+def main(peft_config_file=None, **kwargs) -> None:
     """
     Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
 
     Args:
-        model_name (str, optional): Override default model name.
-        tokenizer_name (str, optional): Override default tokenizer name.
-        batch_size_training (int, optional): Override default training batch size.
-        lr (float, optional): Override default learning rate.
-        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config.
-        **kwargs: Additional arguments to override TrainConfig.
+        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
+        kwargs: Additional arguments to override TrainConfig.
 
     Example:
         .. code-block:: bash
245317
--lr 5e-4
246318
"""
247319
train_config = TrainConfig()
248-
# local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
249320
update_config(train_config, **kwargs)
250-
251-
lora_config = LoraConfig()
252-
if peft_config_file:
253-
peft_config_data = load_config_file(peft_config_file)
254-
validate_config(peft_config_data, config_type="lora")
255-
lora_config = LoraConfig(**peft_config_data)
256-
else:
257-
lora_config = LoraConfig()
258-
259-
update_config(lora_config, **kwargs)
321+
dataset_config = generate_dataset_config(train_config.dataset)
322+
update_config(dataset_config, **kwargs)
260323

261324
setup_distributed_training(train_config)
262325
setup_seeds(train_config.seed)
263-
model, tokenizer = load_model_and_tokenizer(train_config)
264-
print_model_size(model, train_config)
265-
model = apply_peft(model, train_config, lora_config)
326+
model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
266327

267-
# Pass an empty dict instead of kwargs to avoid irrelevant parameters
268-
dataset_config = generate_dataset_config(train_config, kwargs)
269-
dataset_train = get_preprocessed_dataset(
270-
tokenizer, dataset_config, split="train", context_length=train_config.context_length
271-
)
272-
dataset_val = get_preprocessed_dataset(
273-
tokenizer, dataset_config, split="test", context_length=train_config.context_length
274-
)
275-
train_dataloader, eval_dataloader = setup_dataloaders(
276-
train_config, dataset_config, tokenizer, dataset_train, dataset_val
277-
)
278-
dataset_for_seq_length = (
279-
torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
280-
if train_config.run_validation
281-
else train_dataloader.dataset
282-
)
283-
longest_seq_length, _ = get_longest_seq_length(dataset_for_seq_length)
328+
# Create DataLoaders for the training and validation dataset
329+
train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
284330
print(
285-
f"Longest sequence length: {longest_seq_length}, "
286-
f"Context length: {train_config.context_length}, "
287-
f"Model max context: {model.config.max_position_embeddings}"
331+
f"The longest sequence length in the train data is {longest_seq_length}, "
332+
f"passed context length is {train_config.context_length} and overall model's context length is "
333+
f"{model.config.max_position_embeddings}"
288334
)
335+
289336
model.to(train_config.device)
290337
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
291338
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
292339
if train_config.enable_ddp:
293340
model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
294341
results = train(
295342
model,
343+
tokenizer,
296344
train_dataloader,
297345
eval_dataloader,
298-
tokenizer,
299346
optimizer,
300347
scheduler,
301-
train_config.gradient_accumulation_steps,
302348
train_config,
303-
train_config.device,
304349
dist.get_rank() if train_config.enable_ddp else None,
305-
None,
306350
)
307351
if train_config.enable_ddp:
308352
dist.destroy_process_group()

QEfficient/finetune/configs/peft_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ class LoraConfig:
3434

3535
# CAUTION prefix tuning is currently not supported
3636
@dataclass
37-
class prefix_config:
37+
class PrefixConfig:
3838
num_virtual_tokens: int = 30
3939
task_type: str = "CAUSAL_LM"
