@@ -199,5 +199,62 @@ def forward(self, x):
         )
 
 
+class TestTRTModuleFloat64Input(TestCase):
+    def test_save_and_load_trt_module(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return x + x
+
+        inputs = [torch.randn(5, 5).double()]
+        mod = TestModule().eval()
+        ref_output = mod(*inputs)
+
+        mod = acc_tracer.trace(mod, inputs)
+        interp = TRTInterpreter(
+            mod,
+            input_specs=InputTensorSpec.from_tensors(inputs),
+        )
+        trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
+        torch.save(trt_mod, "trt.pt")
+        reload_trt_mod = torch.load("trt.pt")
+
+        torch.testing.assert_close(
+            reload_trt_mod(inputs[0].cuda()).cpu(),
+            ref_output,
+            rtol=1e-04,
+            atol=1e-04,
+            check_dtype=False,
+        )
+        os.remove(f"{os.getcwd()}/trt.pt")
+
+    def test_save_and_load_state_dict(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return x + x
+
+        inputs = [torch.randn(5, 5).double()]
+        mod = TestModule().eval()
+        ref_output = mod(*inputs)
+
+        mod = acc_tracer.trace(mod, inputs)
+        interp = TRTInterpreter(
+            mod,
+            input_specs=InputTensorSpec.from_tensors(inputs),
+        )
+        trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
+        st = trt_mod.state_dict()
+
+        new_trt_mod = TRTModule()
+        new_trt_mod.load_state_dict(st)
+
+        torch.testing.assert_close(
+            new_trt_mod(inputs[0].cuda()).cpu(),
+            ref_output,
+            rtol=1e-04,
+            atol=1e-04,
+            check_dtype=False,
+        )
+
+
202259if __name__ == "__main__" :
203260 run_tests ()