3030if USE_HOST_DEPS :
3131 print ("Using dependencies from host python" )
3232
33+ # Set epochs to train VGG model for accuracy tests
34+ EPOCHS = 25
35+
3336SUPPORTED_PYTHON_VERSIONS = ["3.7" , "3.8" , "3.9" , "3.10" ]
3437
3538nox .options .sessions = [
@@ -63,31 +66,6 @@ def install_torch_trt(session):
6366 session .run ("python" , "setup.py" , "develop" )
6467
6568
66- def download_datasets (session ):
67- print (
68- "Downloading dataset to path" ,
69- os .path .join (TOP_DIR , "examples/int8/training/vgg16" ),
70- )
71- session .chdir (os .path .join (TOP_DIR , "examples/int8/training/vgg16" ))
72- session .run_always (
73- "wget" , "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz" , external = True
74- )
75- session .run_always ("tar" , "-xvzf" , "cifar-10-binary.tar.gz" , external = True )
76- session .run_always (
77- "mkdir" ,
78- "-p" ,
79- os .path .join (TOP_DIR , "tests/accuracy/datasets/data" ),
80- external = True ,
81- )
82- session .run_always (
83- "cp" ,
84- "-rpf" ,
85- os .path .join (TOP_DIR , "examples/int8/training/vgg16/cifar-10-batches-bin" ),
86- os .path .join (TOP_DIR , "tests/accuracy/datasets/data/cidar-10-batches-bin" ),
87- external = True ,
88- )
89-
90-
9169def train_model (session ):
9270 session .chdir (os .path .join (TOP_DIR , "examples/int8/training/vgg16" ))
9371 session .install ("-r" , "requirements.txt" )
@@ -107,14 +85,14 @@ def train_model(session):
10785 "--ckpt-dir" ,
10886 "vgg16_ckpts" ,
10987 "--epochs" ,
110- "25" ,
88+ str ( EPOCHS ) ,
11189 env = {"PYTHONPATH" : PYT_PATH },
11290 )
11391
11492 session .run_always (
11593 "python" ,
11694 "export_ckpt.py" ,
117- "vgg16_ckpts/ckpt_epoch25 .pth" ,
95+ "vgg16_ckpts/ckpt_epoch" + str ( EPOCHS ) + " .pth" ,
11896 env = {"PYTHONPATH" : PYT_PATH },
11997 )
12098 else :
@@ -130,10 +108,12 @@ def train_model(session):
130108 "--ckpt-dir" ,
131109 "vgg16_ckpts" ,
132110 "--epochs" ,
133- "25" ,
111+ str ( EPOCHS ) ,
134112 )
135113
136- session .run_always ("python" , "export_ckpt.py" , "vgg16_ckpts/ckpt_epoch25.pth" )
114+ session .run_always (
115+ "python" , "export_ckpt.py" , "vgg16_ckpts/ckpt_epoch" + str (EPOCHS ) + ".pth"
116+ )
137117
138118
139119def finetune_model (session ):
@@ -156,17 +136,17 @@ def finetune_model(session):
156136 "--ckpt-dir" ,
157137 "vgg16_ckpts" ,
158138 "--start-from" ,
159- "25" ,
139+ str ( EPOCHS ) ,
160140 "--epochs" ,
161- "26" ,
141+ str ( EPOCHS + 1 ) ,
162142 env = {"PYTHONPATH" : PYT_PATH },
163143 )
164144
165145 # Export model
166146 session .run_always (
167147 "python" ,
168148 "export_qat.py" ,
169- "vgg16_ckpts/ckpt_epoch26 .pth" ,
149+ "vgg16_ckpts/ckpt_epoch" + str ( EPOCHS + 1 ) + " .pth" ,
170150 env = {"PYTHONPATH" : PYT_PATH },
171151 )
172152 else :
@@ -182,13 +162,17 @@ def finetune_model(session):
182162 "--ckpt-dir" ,
183163 "vgg16_ckpts" ,
184164 "--start-from" ,
185- "25" ,
165+ str ( EPOCHS ) ,
186166 "--epochs" ,
187- "26" ,
167+ str ( EPOCHS + 1 ) ,
188168 )
189169
190170 # Export model
191- session .run_always ("python" , "export_qat.py" , "vgg16_ckpts/ckpt_epoch26.pth" )
171+ session .run_always (
172+ "python" ,
173+ "export_qat.py" ,
174+ "vgg16_ckpts/ckpt_epoch" + str (EPOCHS + 1 ) + ".pth" ,
175+ )
192176
193177
194178def cleanup (session ):
@@ -219,6 +203,19 @@ def run_base_tests(session):
219203 session .run_always ("pytest" , test )
220204
221205
206+ def run_model_tests (session ):
207+ print ("Running model tests" )
208+ session .chdir (os .path .join (TOP_DIR , "tests/py" ))
209+ tests = [
210+ "models" ,
211+ ]
212+ for test in tests :
213+ if USE_HOST_DEPS :
214+ session .run_always ("pytest" , test , env = {"PYTHONPATH" : PYT_PATH })
215+ else :
216+ session .run_always ("pytest" , test )
217+
218+
222219def run_accuracy_tests (session ):
223220 print ("Running accuracy tests" )
224221 session .chdir (os .path .join (TOP_DIR , "tests/py" ))
@@ -282,7 +279,7 @@ def run_dla_tests(session):
282279 print ("Running DLA tests" )
283280 session .chdir (os .path .join (TOP_DIR , "tests/py" ))
284281 tests = [
285- "test_api_dla.py" ,
282+ "hw/ test_api_dla.py" ,
286283 ]
287284 for test in tests :
288285 if USE_HOST_DEPS :
@@ -322,21 +319,19 @@ def run_l0_dla_tests(session):
322319 cleanup (session )
323320
324321
325- def run_l1_accuracy_tests (session ):
322+ def run_l1_model_tests (session ):
326323 if not USE_HOST_DEPS :
327324 install_deps (session )
328325 install_torch_trt (session )
329- download_datasets (session )
330- train_model (session )
331- run_accuracy_tests (session )
326+ download_models (session )
327+ run_model_tests (session )
332328 cleanup (session )
333329
334330
335331def run_l1_int8_accuracy_tests (session ):
336332 if not USE_HOST_DEPS :
337333 install_deps (session )
338334 install_torch_trt (session )
339- download_datasets (session )
340335 train_model (session )
341336 finetune_model (session )
342337 run_int8_accuracy_tests (session )
@@ -347,9 +342,6 @@ def run_l2_trt_compatibility_tests(session):
347342 if not USE_HOST_DEPS :
348343 install_deps (session )
349344 install_torch_trt (session )
350- download_models (session )
351- download_datasets (session )
352- train_model (session )
353345 run_trt_compatibility_tests (session )
354346 cleanup (session )
355347
@@ -376,9 +368,9 @@ def l0_dla_tests(session):
376368
377369
378370@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
379- def l1_accuracy_tests (session ):
380- """Checking accuracy performance on various usecases """
381- run_l1_accuracy_tests (session )
371+ def l1_model_tests (session ):
372+ """When a user needs to test the functionality of standard models compilation and results """
373+ run_l1_model_tests (session )
382374
383375
384376@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
@@ -397,13 +389,3 @@ def l2_trt_compatibility_tests(session):
397389def l2_multi_gpu_tests (session ):
398390 """Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
399391 run_l2_multi_gpu_tests (session )
400-
401-
402- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
403- def download_test_models (session ):
404- """Grab all the models needed for testing"""
405- try :
406- import torch
407- except ModuleNotFoundError :
408- install_deps (session )
409- download_models (session )
0 commit comments