11import torch
22import timm
33import pytest
4+ import unittest
45
56import torch_tensorrt as torchtrt
67import torchvision .models as models
1213 cosine_similarity ,
1314)
1415
16+ assertions = unittest .TestCase ()
17+
1518
1619@pytest .mark .unit
1720def test_resnet18 (ir ):
@@ -32,9 +35,9 @@ def test_resnet18(ir):
3235
3336 trt_mod = torchtrt .compile (model , ** compile_spec )
3437 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
35- assert (
38+ assertions . assertTrue (
3639 cos_sim > COSINE_THRESHOLD ,
37- f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
40+ msg = f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
3841 )
3942
4043 # Clean up model env
@@ -63,9 +66,9 @@ def test_mobilenet_v2(ir):
6366
6467 trt_mod = torchtrt .compile (model , ** compile_spec )
6568 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
66- assert (
69+ assertions . assertTrue (
6770 cos_sim > COSINE_THRESHOLD ,
68- f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
71+ msg = f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
6972 )
7073
7174 # Clean up model env
@@ -94,9 +97,9 @@ def test_efficientnet_b0(ir):
9497
9598 trt_mod = torchtrt .compile (model , ** compile_spec )
9699 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
97- assert (
100+ assertions . assertTrue (
98101 cos_sim > COSINE_THRESHOLD ,
99- f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
102+ msg = f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
100103 )
101104
102105 # Clean up model env
@@ -138,9 +141,9 @@ def test_bert_base_uncased(ir):
138141 for key in model_outputs .keys ():
139142 out , trt_out = model_outputs [key ], trt_model_outputs [key ]
140143 cos_sim = cosine_similarity (out , trt_out )
141- assert (
144+ assertions . assertTrue (
142145 cos_sim > COSINE_THRESHOLD ,
143- f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
146+ msg = f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
144147 )
145148
146149 # Clean up model env
@@ -169,9 +172,9 @@ def test_resnet18_half(ir):
169172
170173 trt_mod = torchtrt .compile (model , ** compile_spec )
171174 cos_sim = cosine_similarity (model (input ), trt_mod (input ))
172- assert (
175+ assertions . assertTrue (
173176 cos_sim > COSINE_THRESHOLD ,
174- f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
177+ msg = f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
175178 )
176179
177180 # Clean up model env
0 commit comments