 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-
 import inspect
+import json
+import os
 from dataclasses import asdict
+from typing import Any, Dict

 import torch.distributed as dist
 import torch.utils.data as data_utils
+import yaml
 from peft import (
     AdaptionPromptConfig,
-    LoraConfig,
     PrefixTuningConfig,
 )
+from peft import LoraConfig as PeftLoraConfig
 from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq

 import QEfficient.finetune.configs.dataset_config as datasets
-from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
-from QEfficient.finetune.configs.training import train_config
+from QEfficient.finetune.configs.peft_config import LoraConfig
+from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC


 def update_config(config, **kwargs):
32+ """Update the attributes of a config object based on provided keyword arguments.
33+
34+ Args:
35+ config: The configuration object (e.g., TrainConfig, LoraConfig) or a list/tuple of such objects.
36+ **kwargs: Keyword arguments representing attributes to update.
37+
38+ Raises:
39+ ValueError: If an unknown parameter is provided and the config type doesn't support nested updates.
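+
+    Example (illustrative; assumes TrainConfig defines an `lr` field):
+        >>> update_config(train_config, lr=1e-4)
+        >>> update_config(train_config, **{"trainconfig.lr": 1e-4})  # dotted form, case-insensitive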
40+ """
2941 if isinstance (config , (tuple , list )):
3042 for c in config :
3143 update_config (c , ** kwargs )
@@ -34,40 +46,68 @@ def update_config(config, **kwargs):
3446 if hasattr (config , k ):
3547 setattr (config , k , v )
3648 elif "." in k :
37- # allow --some_config.some_param=True
38- config_name , param_name = k .split ("." )
39- if type (config ).__name__ == config_name :
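+                # Allow --some_config.some_param=value overrides; split on the
+                # first dot and match the config class name case-insensitively.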
+                config_name, param_name = k.split(".", 1)
+                if type(config).__name__.lower() == config_name.lower():
                     if hasattr(config, param_name):
                         setattr(config, param_name, v)
                     else:
-                        # In case of specialized config we can warn user
-                        assert False, f"Warning: {config_name} does not accept parameter: {k}"
-            elif isinstance(config, train_config):
-                assert False, f"Warning: unknown parameter {k}"
+                        raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
+            else:
+                config_type = type(config).__name__
+                print(f"[WARNING]: Unknown parameter '{k}' for config type '{config_type}'")


-def generate_peft_config(train_config, kwargs):
-    configs = (lora_config, prefix_config)
-    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
-    names = tuple(c.__name__.rstrip("_config") for c in configs)
+def generate_peft_config(train_config: TrainConfig, custom_config: Any) -> Any:
+    """Generate a PEFT-compatible configuration from a custom config based on peft_method.

-    if train_config.peft_method not in names:
-        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+    Args:
+        train_config (TrainConfig): Training configuration with peft_method.
+        custom_config: Custom configuration object (e.g., LoraConfig).

-    config = configs[names.index(train_config.peft_method)]()
+    Returns:
+        Any: A PEFT-specific configuration object (e.g., PeftLoraConfig).

-    update_config(config, **kwargs)
+    Raises:
+        RuntimeError: If the peft_method is not supported.
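+
+    Example (illustrative; assumes peft_method == "lora" and that the custom
+    LoraConfig is a dataclass with usable defaults):
+        >>> peft_config = generate_peft_config(train_config, LoraConfig())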
72+ """
73+ # Define supported PEFT methods and their corresponding configs
74+ method_to_configs = {
75+ "lora" : (LoraConfig , PeftLoraConfig ),
76+ "adaption_prompt" : (None , AdaptionPromptConfig ), # Placeholder; add custom config if needed
77+ "prefix_tuning" : (None , PrefixTuningConfig ), # Placeholder; add custom config if needed
78+ }
79+
80+ peft_method = train_config .peft_method .lower ()
81+ if peft_method not in method_to_configs :
82+ raise RuntimeError (f"PEFT config not found for method: { train_config .peft_method } " )
83+
84+ custom_config_class , peft_config_class = method_to_configs [peft_method ]
85+
86+ # Use the provided custom_config (e.g., LoraConfig instance)
87+ config = custom_config
6088 params = asdict (config )
61- peft_config = peft_configs [names .index (train_config .peft_method )](** params )
6289
90+ # Create the PEFT-compatible config
91+ peft_config = peft_config_class (** params )
6392 return peft_config
6493
6594
-def generate_dataset_config(train_config, kwargs):
+def generate_dataset_config(train_config: TrainConfig, kwargs: Dict[str, Any] = None) -> Any:
+    """Generate a dataset configuration based on the specified dataset in train_config.
+
+    Args:
+        train_config (TrainConfig): Training configuration with dataset name.
+        kwargs (Dict[str, Any], optional): Additional arguments (currently unused).
+
+    Returns:
+        Any: A dataset configuration object.
+
+    Raises:
+        AssertionError: If the dataset name is not recognized.
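+
+    Example (illustrative; assumes train_config.dataset names a key in DATASET_PREPROC):
+        >>> dataset_config = generate_dataset_config(train_config)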
+    """
     names = tuple(DATASET_PREPROC.keys())
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-    update_config(dataset_config, **kwargs)
     return dataset_config


@@ -101,3 +141,84 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
         kwargs["drop_last"] = True
         kwargs["collate_fn"] = default_data_collator
     return kwargs
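+
+# Usage sketch (illustrative; a tokenizer typically serves as dataset_processer):
+#   dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, mode="train")
+#   train_loader = data_utils.DataLoader(dataset, **dl_kwargs)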
+
+
+def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
+    """Validate the provided YAML/JSON configuration for required fields and types.
+
+    Args:
+        config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON.
+        config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora").
+
+    Raises:
+        ValueError: If required fields are missing or have incorrect types.
+        FileNotFoundError: If the config file path is invalid (handled upstream).
+
+    Notes:
+        - Validates required fields for LoraConfig: r, lora_alpha, target_modules.
+        - Ensures types match expected values (int, float, list, etc.).
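+
+    Example:
+        >>> validate_config({"r": 8, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"]})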
160+ """
161+ if config_type .lower () != "lora" :
162+ raise ValueError (f"Unsupported config_type: { config_type } . Only 'lora' is supported." )
163+
164+ required_fields = {
165+ "r" : int ,
166+ "lora_alpha" : int ,
167+ "target_modules" : list ,
168+ }
169+ optional_fields = {
170+ "bias" : str ,
171+ "task_type" : str ,
172+ "lora_dropout" : float ,
173+ "inference_mode" : bool ,
174+ }
175+
176+ # Check for missing required fields
177+ missing_fields = [field for field in required_fields if field not in config_data ]
178+ if missing_fields :
179+ raise ValueError (f"Missing required fields in { config_type } config: { missing_fields } " )
180+
181+ # Validate types of required fields
182+ for field , expected_type in required_fields .items ():
183+ if not isinstance (config_data [field ], expected_type ):
184+ raise ValueError (
185+ f"Field '{ field } ' in { config_type } config must be of type { expected_type .__name__ } , "
186+ f"got { type (config_data [field ]).__name__ } "
187+ )
188+
189+ # Validate target_modules contains strings
190+ if not all (isinstance (mod , str ) for mod in config_data ["target_modules" ]):
191+ raise ValueError ("All elements in 'target_modules' must be strings" )
192+
193+ # Validate types of optional fields if present
194+ for field , expected_type in optional_fields .items ():
195+ if field in config_data and not isinstance (config_data [field ], expected_type ):
196+ raise ValueError (
197+ f"Field '{ field } ' in { config_type } config must be of type { expected_type .__name__ } , "
198+ f"got { type (config_data [field ]).__name__ } "
199+ )
200+
+
+def load_config_file(config_path: str) -> Dict[str, Any]:
+    """Load a configuration from a YAML or JSON file.
+
+    Args:
+        config_path (str): Path to the YAML or JSON file.
+
+    Returns:
+        Dict[str, Any]: The loaded configuration as a dictionary.
+
+    Raises:
+        FileNotFoundError: If the file does not exist.
+        ValueError: If the file format is unsupported.
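+
+    Example (illustrative; the path is hypothetical):
+        >>> config_data = load_config_file("configs/lora_config.yaml")
+        >>> validate_config(config_data, config_type="lora")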
214+ """
215+ if not os .path .exists (config_path ):
216+ raise FileNotFoundError (f"Config file not found: { config_path } " )
217+
218+ with open (config_path , "r" ) as f :
219+ if config_path .endswith (".yaml" ) or config_path .endswith (".yml" ):
220+ return yaml .safe_load (f )
221+ elif config_path .endswith (".json" ):
222+ return json .load (f )
223+ else :
224+ raise ValueError ("Unsupported config file format. Use .yaml, .yml, or .json" )