88import os
99import shutil
1010
11+ import numpy as np
1112import pytest
1213import torch .optim as optim
1314from torch .utils .data import DataLoader
@@ -22,12 +23,10 @@ def clean_up(path):
2223 shutil .rmtree (path )
2324
2425
25- configs = [pytest .param ("meta-llama/Llama-3.2-1B" , 1 , 1 , 1 , None , True , True , "cpu " , id = "llama_config" )]
26+ configs = [pytest .param ("meta-llama/Llama-3.2-1B" , 10 , 20 , 1 , None , True , True , "qaic " , id = "llama_config" )]
2627
2728
28- # TODO:enable this once docker is available
2929@pytest .mark .on_qaic
30- @pytest .mark .skip (reason = "eager docker not available in sdk" )
3130@pytest .mark .parametrize (
3231 "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device" ,
3332 configs ,
@@ -65,7 +64,13 @@ def test_finetune(
6564 "device" : device ,
6665 }
6766
68- finetune (** kwargs )
67+ results = finetune (** kwargs )
68+
69+ assert np .allclose (results ["avg_train_prep" ], 1.002326 ), "Train perplexity is not matching."
70+ assert np .allclose (results ["avg_train_loss" ], 0.00232327 ), "Train loss is not matching."
71+ assert np .allclose (results ["avg_eval_prep" ], 1.0193923 ), "Eval perplexity is not matching."
72+ assert np .allclose (results ["avg_eval_loss" ], 0.0192067 ), "Eval loss is not matching."
73+ assert results ["avg_epoch_time" ] < 30 , "Training should complete within 30 seconds."
6974
7075 train_config_spy .assert_called_once ()
7176 generate_dataset_config_spy .assert_called_once ()
@@ -99,8 +104,11 @@ def test_finetune(
99104
100105 args , kwargs = update_config_spy .call_args
101106 train_config = args [0 ]
107+ assert max_train_step >= train_config .gradient_accumulation_steps , (
108+ "Total training step should be more than 4 which is gradient accumulation steps."
109+ )
102110
103- saved_file = os .path .join (train_config .output_dir , "adapter_model.safetensors" )
111+ saved_file = os .path .join (train_config .output_dir , "complete_epoch_1/ adapter_model.safetensors" )
104112 assert os .path .isfile (saved_file )
105113
106114 clean_up (train_config .output_dir )
0 commit comments