@@ -136,28 +136,31 @@ def test_resnet18_torch_exec_ops(ir):
     not importlib.util.find_spec("torchvision"),
     "torchvision is not installed",
 )
-def test_mobilenet_v2(ir):
-    model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_mobilenet_v2(ir, dtype):
+    model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
+    input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
     compile_spec = {
         "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
+            torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    pyt_output = model(input)
+    trt_output = trt_mod(input)
+    assert pyt_output.dtype == trt_output.dtype
+    assert pyt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -172,28 +175,36 @@ def test_mobilenet_v2(ir):
     not importlib.util.find_spec("timm") or not importlib.util.find_spec("torchvision"),
     "timm or torchvision not installed",
 )
-def test_efficientnet_b0(ir):
-    model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_efficientnet_b0(ir, dtype):
+    model = (
+        timm.create_model("efficientnet_b0", pretrained=True)
+        .eval()
+        .to("cuda")
+        .to(dtype)
+    )
+    input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
     compile_spec = {
         "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
+            torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    pyt_output = model(input)
+    trt_output = trt_mod(input)
+    assert pyt_output.dtype == trt_output.dtype
+    assert pyt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -208,10 +219,11 @@ def test_efficientnet_b0(ir):
     not importlib.util.find_spec("transformers"),
     "transformers is required to run this test",
 )
-def test_bert_base_uncased(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_bert_base_uncased(ir, dtype):
     from transformers import BertModel
 
-    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
+    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
     input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
 
@@ -229,21 +241,23 @@ def test_bert_base_uncased(ir):
             ),
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "truncate_double": True,
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 15,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
 
     model_outputs = model(input, input2)
     trt_model_outputs = trt_mod(input, input2)
     for key in model_outputs.keys():
         out, trt_out = model_outputs[key], trt_model_outputs[key]
+        assert out.dtype == trt_out.dtype
+        assert out.dtype == dtype
         cos_sim = cosine_similarity(out, trt_out)
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
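Outside the diff, the compile pattern these hunks converge on looks roughly like the sketch below: pass "use_explicit_typing": True instead of "enabled_precisions", hand over weights and inputs already cast to the target dtype, and check that the TRT module preserves that dtype. This is a minimal sketch under assumptions, not the suite's code: the cosine_similarity helper here is a hypothetical stand-in for the suite's utility, and ir="dynamo" is assumed where the tests take ir as a fixture.

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models


def cosine_similarity(gt_tensor, pred_tensor):
    # Hypothetical stand-in for the test suite's helper: compare flattened outputs in fp32.
    gt = gt_tensor.flatten().to(torch.float32)
    pred = pred_tensor.flatten().to(torch.float32)
    return torch.nn.functional.cosine_similarity(gt, pred, dim=0, eps=1e-6).item()


dtype = torch.float16  # the tests parametrize over float16, bfloat16, and float32
model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
inp = torch.randn((1, 3, 224, 224), device="cuda", dtype=dtype)

trt_mod = torchtrt.compile(
    model,
    inputs=[torchtrt.Input(inp.shape, dtype=dtype, format=torch.contiguous_format)],
    ir="dynamo",               # assumed; the tests receive `ir` as a fixture
    use_explicit_typing=True,  # no enabled_precisions: the engine keeps the module's dtypes
    min_block_size=10,
)

# Output dtype should match the requested precision, and outputs should agree closely.
assert trt_mod(inp).dtype == dtype
print(cosine_similarity(model(inp), trt_mod(inp)))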