[QEff Finetune]: Refactor the finetune main __call__ #289
Conversation
Force-pushed from 2f19722 to 48061ee.
Force-pushed from 3ff66eb to c0d2315.
Force-pushed from 7f2d367 to b2ee39a.
tests/finetune/test_finetune.py (Outdated):

    - finetune(**kwargs)
    + results = finetune(**kwargs)

      assert np.allclose(results["avg_train_prep"], 1.002326, atol=1e-5), "Train perplexity is not matching."
avg_train_prep should be changed to avg_train_metric as per the changes in PR 292.
Updated in latest.
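For reference, a minimal sketch of what the assertion could look like after the rename (the `avg_train_metric` key is an assumption based on PR 292; the expected value is the one from the snippet above):

```python
import numpy as np


def check_train_metric(results: dict) -> None:
    # Hypothetical post-rename assertion; the key name is assumed from PR 292.
    assert np.allclose(results["avg_train_metric"], 1.002326, atol=1e-5), "Train metric is not matching."
```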
      gradient_accumulation_steps,
    - train_config: TRAIN_CONFIG,
    + train_config: TrainConfig,
      device,
There is no need to pass all three of train_config.gradient_accumulation_steps, train_config, and train_config.device; passing only train_config is enough.
Updated in latest.
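For illustration, a minimal sketch of the suggested signature, assuming a stand-in TrainConfig with only the two fields discussed here:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # Stand-in for the real TrainConfig; fields and defaults are illustrative.
    gradient_accumulation_steps: int = 4
    device: str = "qaic"


def train_step(loss: float, train_config: TrainConfig) -> tuple[float, str]:
    # Only the config object is passed in; the helper reads the fields it needs.
    scaled_loss = loss / train_config.gradient_accumulation_steps
    return scaled_loss, train_config.device
```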
Force-pushed from d0fff22 to e27deeb.
    - Ensures types match expected values (int, float, list, etc.).
    """
    if config_type.lower() != "lora":
        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
Since we are not doing LoRA fine-tuning in the case of BERT, this will raise an error.
No, this is used only when peft_config_file is provided to the main() function of finetune.py. Currently there won't be any peft_config_file for BERT, but if one is provided, PEFT training will happen for BERT.
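For reference, a hedged sketch of the validation being discussed, assuming it simply rejects keys that do not exist on `peft.LoraConfig`; the helper name and exact checks are illustrative, not the actual code in finetune.py:

```python
from dataclasses import fields
from typing import Any, Dict

from peft import LoraConfig


def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
    # Only the LoRA config type is supported, matching the snippet above.
    if config_type.lower() != "lora":
        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
    # Reject keys that are not fields of peft.LoraConfig (illustrative check).
    valid_keys = {f.name for f in fields(LoraConfig)}
    for key in config_data:
        if key not in valid_keys:
            raise ValueError(f"Unknown LoraConfig field: {key}")
```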
    Args:
        config_data (Dict[str, Any]): The configuration dictionary loaded from YAML/JSON.
        config_type (str): Type of config to validate ("lora" for LoraConfig, default: "lora").
We need to add a field in config_type corresponding to BERT, since we don't do LoRA fine-tuning for it.
LoRA can work with BERT, but there is no point, since BERT for sequence classification has randomly initialized weights in its classifier head.
QEfficient/cloud/finetune.py (Outdated):

    # local_args = {k: v for k, v in locals().items() if v is not None and k != "peft_config_file" and k != "kwargs"}
    update_config(train_config, **kwargs)

    lora_config = LoraConfig()
This line is not required.
Removed in latest.
QEfficient/cloud/finetune.py (Outdated):

    longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
    lora_config = LoraConfig()

    update_config(lora_config, **kwargs)
Why do we need to update lora_config here with kwargs?
Moved this code inside the generate_peft_config function. It was required in order to update the config params based on the CLI arguments.
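For illustration, a minimal sketch of this flow, assuming a simple update_config helper that overrides attributes of a default LoraConfig with CLI kwargs; the actual generate_peft_config in finetune.py may differ:

```python
from peft import LoraConfig


def update_config(config, **kwargs):
    # Override only the attributes that exist on the config object.
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
    return config


# Example: CLI arguments override the LoraConfig defaults.
lora_config = update_config(LoraConfig(), r=16, lora_alpha=64, lora_dropout=0.1)
```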
Force-pushed from 0bb2a51 to ada9de8.
QEfficient/cloud/finetune.py (Outdated):

    - def main(**kwargs):
    -     """
    -     Helper function to finetune the model on QAic.
    + def setup_distributed_training(config: TrainConfig) -> None:
Please move the functions setup_distributed_training, setup_seeds, load_model_and_tokenizer, apply_peft, and setup_dataloaders to a separate utils file and import them into this file.
We can take that up in the next PR. Keeping the refactored code in the same file for now makes the comparison easier. A lot of refactoring is still required; we will do it in an incremental fashion.
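For reference, a minimal sketch of one of these helpers (setup_seeds), assuming the usual seeding of torch, random, and numpy; the actual body in QEfficient/cloud/finetune.py may differ:

```python
import random

import numpy as np
import torch


def setup_seeds(seed: int) -> None:
    # Seed all RNGs used during fine-tuning for reproducibility (illustrative).
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
```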
Force-pushed from 1709b85 to 0aead82.
Force-pushed from 8a5b6b6 to a4c8c50.
Force-pushed from a4c8c50 to 777c7ca.
- Refactor the finetune main API
- Add support to override the PEFT config (YAML/JSON)
- Add support to validate the correctness of the PEFT config
- Some nit changes

#### Using a Custom LoRA Config (`lora_config.yaml`):
```yaml
r: 16
lora_alpha: 64
target_modules:
- q_proj
- v_proj
- k_proj
bias: none
task_type: CAUSAL_LM
lora_dropout: 0.1
```
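For illustration, a hedged sketch of how an override file like the one above could be mapped onto a `peft.LoraConfig`; the actual loading and validation in finetune.py may differ:

```python
import yaml
from peft import LoraConfig

# Illustrative only: read the YAML overrides and construct a LoraConfig from them.
# The PR also adds validation of these keys before they are applied.
with open("lora_config.yaml") as f:
    overrides = yaml.safe_load(f)

lora_config = LoraConfig(**overrides)
```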
Command:
```bash
python -m QEfficient.cloud.finetune \
--model_name "meta-llama/Llama-3.2-1B" \
--lr 5e-4 \
--peft_config_file "lora_config.yaml"
```
#### Using Default LoRA Config:
```bash
python -m QEfficient.cloud.finetune \
--model_name "meta-llama/Llama-3.2-1B" \
--lr 5e-4
```
---------
Signed-off-by: vbaddi <[email protected]>
Signed-off-by: Meet Patel <[email protected]>
Co-authored-by: Meet Patel <[email protected]>
Signed-off-by: Mohit Soni <[email protected]>