
Commit 2549a61

vbaddi authored and quic-meetkuma committed
refactor the finetune main __call__
Signed-off-by: vbaddi <[email protected]>
1 parent 4014e70 · commit 2549a61

File tree: 8 files changed (+392 / -138 lines)


QEfficient/cloud/finetune.py

Lines changed: 182 additions & 103 deletions
Large diffs are not rendered by default.

QEfficient/finetune/configs/peft_config.py

Lines changed: 14 additions & 5 deletions
@@ -9,15 +9,24 @@
 from typing import List


-# Currently, the support is for Lora Configs only
-# In future, we can expand to llama_adapters and prefix tuning
-# TODO: vbaddi: Check back once FSDP is enabled
 @dataclass
-class lora_config:
+class LoraConfig:
+    """LoRA-specific configuration for parameter-efficient fine-tuning.
+
+    Attributes:
+        r (int): LoRA rank (default: 8).
+        lora_alpha (int): LoRA scaling factor (default: 32).
+        target_modules (List[str]): Modules to apply LoRA to (default: ["q_proj", "v_proj"]).
+        bias (str): Bias handling in LoRA (default: "none").
+        task_type (str): Task type for LoRA (default: "CAUSAL_LM").
+        lora_dropout (float): Dropout rate for LoRA (default: 0.0).
+        inference_mode (bool): Whether model is in inference mode (default: False).
+    """
+
     r: int = 8
     lora_alpha: int = 32
     target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
-    bias = "none"
+    bias: str = "none"
     task_type: str = "CAUSAL_LM"
     lora_dropout: float = 0.05
     inference_mode: bool = False  # should be False for finetuning
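
For orientation, a minimal usage sketch of the renamed dataclass. The override values are illustrative, and dataclasses.replace is standard-library behavior, not a helper added by this commit.

from dataclasses import asdict, replace

from QEfficient.finetune.configs.peft_config import LoraConfig

lora_cfg = LoraConfig()                               # defaults: r=8, lora_alpha=32, lora_dropout=0.05
lora_cfg = replace(lora_cfg, r=16, lora_dropout=0.1)  # copy with illustrative overrides
print(asdict(lora_cfg)["target_modules"])             # ['q_proj', 'v_proj']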

QEfficient/finetune/configs/training.py

Lines changed: 45 additions & 1 deletion
@@ -7,8 +7,52 @@
 from dataclasses import dataclass


+# Configuration Classes
 @dataclass
-class train_config:
+class TrainConfig:
+    """Training configuration for model fine-tuning.
+
+    Attributes:
+        model_name (str): Name of the pre-trained model to fine-tune (default: "meta-llama/Llama-3.2-1B").
+        tokenizer_name (str): Name of the tokenizer (defaults to model_name if None).
+        run_validation (bool): Whether to run validation during training (default: True).
+        batch_size_training (int): Batch size for training (default: 1).
+        context_length (Optional[int]): Maximum sequence length for inputs (default: None).
+        gradient_accumulation_steps (int): Steps for gradient accumulation (default: 4).
+        num_epochs (int): Number of training epochs (default: 1).
+        max_train_step (int): Maximum training steps (default: 0, unlimited if 0).
+        max_eval_step (int): Maximum evaluation steps (default: 0, unlimited if 0).
+        device (str): Device to train on (default: "qaic").
+        num_workers_dataloader (int): Number of workers for data loading (default: 1).
+        lr (float): Learning rate (default: 3e-4).
+        weight_decay (float): Weight decay for optimizer (default: 0.0).
+        gamma (float): Learning rate decay factor (default: 0.85).
+        seed (int): Random seed for reproducibility (default: 42).
+        use_fp16 (bool): Use mixed precision training (default: True).
+        use_autocast (bool): Use autocast for mixed precision (default: True).
+        val_batch_size (int): Batch size for validation (default: 1).
+        dataset (str): Dataset name for training (default: "samsum_dataset").
+        peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
+        use_peft (bool): Whether to use PEFT (default: True).
+        from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
+        output_dir (str): Directory to save outputs (default: "meta-llama-samsum").
+        num_freeze_layers (int): Number of layers to freeze (default: 1).
+        one_qaic (bool): Use single QAIC device (default: False).
+        save_model (bool): Save the trained model (default: True).
+        save_metrics (bool): Save training metrics (default: True).
+        intermediate_step_save (int): Steps between intermediate saves (default: 1000).
+        batching_strategy (str): Batching strategy (default: "packing").
+        enable_sorting_for_ddp (bool): Sort data for DDP (default: True).
+        convergence_counter (int): Steps to check convergence (default: 5).
+        convergence_loss (float): Loss threshold for convergence (default: 1e-4).
+        use_profiler (bool): Enable profiling (default: False).
+        enable_ddp (bool): Enable distributed data parallel (default: False).
+        dist_backend (str): Backend for distributed training (default: "cpu:gloo,qaic:qccl,cuda:gloo").
+        grad_scaler (bool): Use gradient scaler (default: True).
+        dump_root_dir (str): Directory for mismatch dumps (default: "meta-llama-samsum-mismatches/step_").
+        opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
+    """
+
     model_name: str = "meta-llama/Llama-3.2-1B"
     tokenizer_name: str = None  # if not passed as an argument, it uses the value of model_name
     run_validation: bool = True
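
A short usage sketch, mirroring the TrainConfig()/update_config(**kwargs) pattern this commit introduces in eval.py; the override values are illustrative.

from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import update_config

train_config = TrainConfig()                        # dataclass defaults listed above
update_config(train_config, num_epochs=3, lr=1e-4)  # keyword overrides, as done in main(**kwargs)
assert train_config.num_epochs == 3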

QEfficient/finetune/eval.py

Lines changed: 3 additions & 2 deletions
@@ -11,7 +11,6 @@
 import fire
 import numpy as np
 import torch
-from configs.training import train_config as TRAIN_CONFIG
 from peft import AutoPeftModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils.config_utils import (
@@ -25,6 +24,8 @@
 )
 from utils.train_utils import evaluation, print_model_size

+from QEfficient.finetune.configs.training import TrainConfig
+
 try:
     import torch_qaic  # noqa: F401

@@ -39,7 +40,7 @@

 def main(**kwargs):
     # update the configuration for the training process
-    train_config = TRAIN_CONFIG()
+    train_config = TrainConfig()
     update_config(train_config, **kwargs)

     # Set the seeds for reproducibility

QEfficient/finetune/utils/config_utils.py

Lines changed: 143 additions & 22 deletions
@@ -4,28 +4,40 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
 import inspect
+import json
+import os
 from dataclasses import asdict
+from typing import Any, Dict

 import torch.distributed as dist
 import torch.utils.data as data_utils
+import yaml
 from peft import (
     AdaptionPromptConfig,
-    LoraConfig,
     PrefixTuningConfig,
 )
+from peft import LoraConfig as PeftLoraConfig
 from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq

 import QEfficient.finetune.configs.dataset_config as datasets
-from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
-from QEfficient.finetune.configs.training import train_config
+from QEfficient.finetune.configs.peft_config import LoraConfig
+from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC


 def update_config(config, **kwargs):
+    """Update the attributes of a config object based on provided keyword arguments.
+
+    Args:
+        config: The configuration object (e.g., TrainConfig, LoraConfig) or a list/tuple of such objects.
+        **kwargs: Keyword arguments representing attributes to update.
+
+    Raises:
+        ValueError: If an unknown parameter is provided and the config type doesn't support nested updates.
+    """
     if isinstance(config, (tuple, list)):
         for c in config:
             update_config(c, **kwargs)
@@ -34,40 +46,68 @@ def update_config(config, **kwargs):
             if hasattr(config, k):
                 setattr(config, k, v)
             elif "." in k:
-                # allow --some_config.some_param=True
-                config_name, param_name = k.split(".")
-                if type(config).__name__ == config_name:
+                config_name, param_name = k.split(".", 1)
+                if type(config).__name__.lower() == config_name.lower():
                     if hasattr(config, param_name):
                         setattr(config, param_name, v)
                     else:
-                        # In case of specialized config we can warn user
-                        assert False, f"Warning: {config_name} does not accept parameter: {k}"
-            elif isinstance(config, train_config):
-                assert False, f"Warning: unknown parameter {k}"
+                        raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
+            else:
+                config_type = type(config).__name__
+                print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'")


-def generate_peft_config(train_config, kwargs):
-    configs = (lora_config, prefix_config)
-    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
-    names = tuple(c.__name__.rstrip("_config") for c in configs)
+def generate_peft_config(train_config: TrainConfig, custom_config: Any) -> Any:
+    """Generate a PEFT-compatible configuration from a custom config based on peft_method.

-    if train_config.peft_method not in names:
-        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+    Args:
+        train_config (TrainConfig): Training configuration with peft_method.
+        custom_config: Custom configuration object (e.g., LoraConfig).

-    config = configs[names.index(train_config.peft_method)]()
+    Returns:
+        Any: A PEFT-specific configuration object (e.g., PeftLoraConfig).

-    update_config(config, **kwargs)
+    Raises:
+        RuntimeError: If the peft_method is not supported.
+    """
+    # Define supported PEFT methods and their corresponding configs
+    method_to_configs = {
+        "lora": (LoraConfig, PeftLoraConfig),
+        "adaption_prompt": (None, AdaptionPromptConfig),  # Placeholder; add custom config if needed
+        "prefix_tuning": (None, PrefixTuningConfig),  # Placeholder; add custom config if needed
+    }
+
+    peft_method = train_config.peft_method.lower()
+    if peft_method not in method_to_configs:
+        raise RuntimeError(f"PEFT config not found for method: {train_config.peft_method}")
+
+    custom_config_class, peft_config_class = method_to_configs[peft_method]
+
+    # Use the provided custom_config (e.g., LoraConfig instance)
+    config = custom_config
     params = asdict(config)
-    peft_config = peft_configs[names.index(train_config.peft_method)](**params)

+    # Create the PEFT-compatible config
+    peft_config = peft_config_class(**params)
     return peft_config


-def generate_dataset_config(train_config, kwargs):
+def generate_dataset_config(train_config: TrainConfig, kwargs: Dict[str, Any] = None) -> Any:
+    """Generate a dataset configuration based on the specified dataset in train_config.
+
+    Args:
+        train_config (TrainConfig): Training configuration with dataset name.
+        kwargs (Dict[str, Any], optional): Additional arguments (currently unused).
+
+    Returns:
+        Any: A dataset configuration object.
+
+    Raises:
+        AssertionError: If the dataset name is not recognized.
+    """
     names = tuple(DATASET_PREPROC.keys())
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-    update_config(dataset_config, **kwargs)
     return dataset_config

@@ -101,3 +141,84 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
         kwargs["drop_last"] = True
         kwargs["collate_fn"] = default_data_collator
     return kwargs
+
+
+def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
+    """Validate the provided YAML/JSON configuration for required fields and types.
+
+    Args:
+        config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON.
+        config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora").
+
+    Raises:
+        ValueError: If required fields are missing or have incorrect types.
+        FileNotFoundError: If the config file path is invalid (handled upstream).
+
+    Notes:
+        - Validates required fields for LoraConfig: r, lora_alpha, target_modules.
+        - Ensures types match expected values (int, float, list, etc.).
+    """
+    if config_type.lower() != "lora":
+        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
+
+    required_fields = {
+        "r": int,
+        "lora_alpha": int,
+        "target_modules": list,
+    }
+    optional_fields = {
+        "bias": str,
+        "task_type": str,
+        "lora_dropout": float,
+        "inference_mode": bool,
+    }
+
+    # Check for missing required fields
+    missing_fields = [field for field in required_fields if field not in config_data]
+    if missing_fields:
+        raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}")
+
+    # Validate types of required fields
+    for field, expected_type in required_fields.items():
+        if not isinstance(config_data[field], expected_type):
+            raise ValueError(
+                f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
+                f"got {type(config_data[field]).__name__}"
+            )
+
+    # Validate target_modules contains strings
+    if not all(isinstance(mod, str) for mod in config_data["target_modules"]):
+        raise ValueError("All elements in 'target_modules' must be strings")
+
+    # Validate types of optional fields if present
+    for field, expected_type in optional_fields.items():
+        if field in config_data and not isinstance(config_data[field], expected_type):
+            raise ValueError(
+                f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
+                f"got {type(config_data[field]).__name__}"
+            )
+
+
+def load_config_file(config_path: str) -> Dict[str, Any]:
+    """Load a configuration from a YAML or JSON file.
+
+    Args:
+        config_path (str): Path to the YAML or JSON file.
+
+    Returns:
+        Dict[str, Any]: The loaded configuration as a dictionary.
+
+    Raises:
+        FileNotFoundError: If the file does not exist.
+        ValueError: If the file format is unsupported.
+    """
+    if not os.path.exists(config_path):
+        raise FileNotFoundError(f"Config file not found: {config_path}")
+
+    with open(config_path, "r") as f:
+        if config_path.endswith(".yaml") or config_path.endswith(".yml"):
+            return yaml.safe_load(f)
+        elif config_path.endswith(".json"):
+            return json.load(f)
+        else:
+            raise ValueError("Unsupported config file format. Use .yaml, .yml, or .json")

QEfficient/finetune/utils/train_utils.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

-from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.finetune.configs.training import TrainConfig

 try:
     import torch_qaic  # noqa: F401
@@ -39,7 +39,7 @@ def train(
     optimizer,
     lr_scheduler,
     gradient_accumulation_steps,
-    train_config: TRAIN_CONFIG,
+    train_config: TrainConfig,
     device,
     local_rank=None,
     rank=None,

scripts/finetune/run_ft_model.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from peft import AutoPeftModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.finetune.configs.training import TrainConfig

 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -25,7 +25,7 @@
     print(f"Warning: {e}. Moving ahead without these qaic modules.")
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-train_config = TRAIN_CONFIG()
+train_config = TrainConfig()
 model = AutoModelForCausalLM.from_pretrained(
     train_config.model_name,
     use_cache=False,

tests/finetune/test_finetune.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def test_finetune(
     device,
     mocker,
 ):
-    train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TRAIN_CONFIG")
+    train_config_spy = mocker.spy(QEfficient.cloud.finetune, "TrainConfig")
     generate_dataset_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_dataset_config")
     generate_peft_config_spy = mocker.spy(QEfficient.cloud.finetune, "generate_peft_config")
     get_dataloader_kwargs_spy = mocker.spy(QEfficient.cloud.finetune, "get_dataloader_kwargs")
