
Commit fbd389b

Fixed FT test case on qaic

Signed-off-by: Meet Patel <[email protected]>
1 parent: a802eac

File tree: 2 files changed, +15 −6 lines


QEfficient/cloud/finetune.py

Lines changed: 2 additions & 1 deletion
@@ -291,7 +291,7 @@ def main(
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-    train(
+    results = train(
         model,
         train_dataloader,
         eval_dataloader,

@@ -306,6 +306,7 @@ def main(
     )
     if train_config.enable_ddp:
         dist.destroy_process_group()
+    return results


 if __name__ == "__main__":
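
For context on the change above, here is a minimal, self-contained sketch of the results dict that train() now propagates up through main(). The loss values and the perplexity-equals-exp(loss) relationship are assumptions inferred from the test expectations below, not the actual QEfficient implementation:

import math
import time

def train_sketch():
    # Hypothetical stand-in for the real train(): it only illustrates the
    # shape of the metrics dict that main() now returns to its caller.
    start = time.time()
    avg_train_loss = 0.00232327  # placeholder values copied from the
    avg_eval_loss = 0.0192067    # expected metrics in the test below
    return {
        "avg_train_loss": avg_train_loss,
        "avg_train_prep": math.exp(avg_train_loss),  # perplexity = exp(loss)
        "avg_eval_loss": avg_eval_loss,
        "avg_eval_prep": math.exp(avg_eval_loss),
        "avg_epoch_time": time.time() - start,  # wall-clock seconds per epoch
    }

results = train_sketch()
assert abs(results["avg_train_prep"] - 1.002326) < 1e-6  # exp(0.00232327)

Returning the metrics, rather than only logging them, is what lets the re-enabled test below assert on concrete numbers.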

tests/finetune/test_finetune.py

Lines changed: 13 additions & 5 deletions
@@ -8,6 +8,7 @@
 import os
 import shutil

+import numpy as np
 import pytest
 import torch.optim as optim
 from torch.utils.data import DataLoader

@@ -22,12 +23,10 @@ def clean_up(path):
     shutil.rmtree(path)


-configs = [pytest.param("meta-llama/Llama-3.2-1B", 1, 1, 1, None, True, True, "cpu", id="llama_config")]
+configs = [pytest.param("meta-llama/Llama-3.2-1B", 10, 20, 1, None, True, True, "qaic", id="llama_config")]


-# TODO:enable this once docker is available
 @pytest.mark.on_qaic
-@pytest.mark.skip(reason="eager docker not available in sdk")
 @pytest.mark.parametrize(
     "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
     configs,

@@ -65,7 +64,13 @@ def test_finetune(
         "device": device,
     }

-    finetune(**kwargs)
+    results = finetune(**kwargs)
+
+    assert np.allclose(results["avg_train_prep"], 1.002326), "Train perplexity is not matching."
+    assert np.allclose(results["avg_train_loss"], 0.00232327), "Train loss is not matching."
+    assert np.allclose(results["avg_eval_prep"], 1.0193923), "Eval perplexity is not matching."
+    assert np.allclose(results["avg_eval_loss"], 0.0192067), "Eval loss is not matching."
+    assert results["avg_epoch_time"] < 30, "Training should complete within 30 seconds."

     train_config_spy.assert_called_once()
     generate_dataset_config_spy.assert_called_once()

@@ -99,8 +104,11 @@ def test_finetune(

     args, kwargs = update_config_spy.call_args
     train_config = args[0]
+    assert max_train_step >= train_config.gradient_accumulation_steps, (
+        "Total training step should be more than 4 which is gradient accumulation steps."
+    )

-    saved_file = os.path.join(train_config.output_dir, "adapter_model.safetensors")
+    saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
     assert os.path.isfile(saved_file)

     clean_up(train_config.output_dir)
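
Two notes on the new assertions. First, np.allclose compares floats with tolerance rather than exact equality, so small run-to-run numeric drift still passes. A quick self-contained demo of NumPy's default tolerances, with made-up numbers:

import numpy as np

# np.allclose(a, b) checks |a - b| <= atol + rtol * |b|,
# with defaults rtol=1e-05 and atol=1e-08.
print(np.allclose(1.002326, 1.0023265))  # True: drift of 5e-7 is within tolerance
print(np.allclose(1.002326, 1.0024000))  # False: drift of ~7.4e-5 is not

Second, the max_train_step check keeps the config self-consistent: with max_train_step=20 and (per the assertion message) gradient_accumulation_steps=4, at least one optimizer step runs before the test looks for the saved adapter under complete_epoch_1/. Assuming the on_qaic marker is registered in the project's pytest configuration, the re-enabled test can be selected with something like: pytest -m on_qaic tests/finetune/test_finetune.py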
