
Commit b2ee39a

Fixed FT test case on qaic
Signed-off-by: Meet Patel <[email protected]>
1 parent 2549a61 commit b2ee39a

File tree

2 files changed (+15, -6 lines)


QEfficient/cloud/finetune.py

Lines changed: 2 additions & 1 deletion
@@ -288,7 +288,7 @@ def main(
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-    train(
+    results = train(
         model,
         train_dataloader,
         eval_dataloader,
@@ -303,6 +303,7 @@ def main(
     )
     if train_config.enable_ddp:
         dist.destroy_process_group()
+    return results


 if __name__ == "__main__":
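
With main() now returning the train() results, a caller can read the aggregated metrics directly instead of scraping logs. A minimal sketch of how the propagated dictionary might be consumed; the import path and call signature here are assumptions inferred from how the test below invokes finetune(**kwargs), while the keys come from the assertions it adds:

# Assumed entry point; the test passes kwargs whose keys mirror the
# parametrize names ("model_name", ..., "device").
from QEfficient.cloud.finetune import main as finetune

results = finetune(model_name="meta-llama/Llama-3.2-1B", device="qaic")

# Keys grounded in the new test assertions: averaged loss/perplexity for
# train and eval, plus the mean wall-clock time per epoch in seconds.
for key in ("avg_train_loss", "avg_train_prep", "avg_eval_loss", "avg_eval_prep", "avg_epoch_time"):
    print(key, results[key])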

tests/finetune/test_finetune.py

Lines changed: 13 additions & 5 deletions
@@ -8,6 +8,7 @@
 import os
 import shutil

+import numpy as np
 import pytest
 import torch.optim as optim
 from torch.utils.data import DataLoader
@@ -22,12 +23,10 @@ def clean_up(path):
     shutil.rmtree(path)


-configs = [pytest.param("meta-llama/Llama-3.2-1B", 1, 1, 1, None, True, True, "cpu", id="llama_config")]
+configs = [pytest.param("meta-llama/Llama-3.2-1B", 10, 20, 1, None, True, True, "qaic", id="llama_config")]


-# TODO:enable this once docker is available
 @pytest.mark.on_qaic
-@pytest.mark.skip(reason="eager docker not available in sdk")
 @pytest.mark.parametrize(
     "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
     configs,
@@ -65,7 +64,13 @@ def test_finetune(
         "device": device,
     }

-    finetune(**kwargs)
+    results = finetune(**kwargs)
+
+    assert np.allclose(results["avg_train_prep"], 1.002326), "Train perplexity is not matching."
+    assert np.allclose(results["avg_train_loss"], 0.00232327), "Train loss is not matching."
+    assert np.allclose(results["avg_eval_prep"], 1.0193923), "Eval perplexity is not matching."
+    assert np.allclose(results["avg_eval_loss"], 0.0192067), "Eval loss is not matching."
+    assert results["avg_epoch_time"] < 30, "Training should complete within 30 seconds."

     train_config_spy.assert_called_once()
     generate_dataset_config_spy.assert_called_once()
@@ -99,8 +104,11 @@ def test_finetune(

     args, kwargs = update_config_spy.call_args
     train_config = args[0]
+    assert max_train_step >= train_config.gradient_accumulation_steps, (
+        "Total training step should be more than 4 which is gradient accumulation steps."
+    )

-    saved_file = os.path.join(train_config.output_dir, "adapter_model.safetensors")
+    saved_file = os.path.join(train_config.output_dir, "complete_epoch_1/adapter_model.safetensors")
     assert os.path.isfile(saved_file)

     clean_up(train_config.output_dir)
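
A note on the tolerance: with no explicit rtol/atol, np.allclose defaults to rtol=1e-05 and atol=1e-08, so each pinned metric must match its reference value to roughly five significant digits. A quick self-contained illustration:

import numpy as np

# Default closeness test: |a - b| <= atol + rtol * |b|
assert np.allclose(1.002326, 1.0023261)   # ~1e-7 apart: within default tolerance
assert not np.allclose(1.002326, 1.0024)  # ~7e-5 apart: outside default tolerance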
