@@ -72,6 +72,7 @@ def forward(self, x):
7272 interp = TRTInterpreter (
7373 mod ,
7474 input_specs = InputTensorSpec .from_tensors (inputs ),
75+ truncate_long_and_double = True ,
7576 )
7677 trt_mod = TRTModule (* interp .run (lower_precision = LowerPrecision .FP32 ))
7778 torch .save (trt_mod , "trt.pt" )
@@ -99,6 +100,66 @@ def forward(self, x):
99100 interp = TRTInterpreter (
100101 mod ,
101102 input_specs = InputTensorSpec .from_tensors (inputs ),
103+ truncate_long_and_double = True ,
104+ )
105+ trt_mod = TRTModule (* interp .run (lower_precision = LowerPrecision .FP32 ))
106+ st = trt_mod .state_dict ()
107+
108+ new_trt_mod = TRTModule ()
109+ new_trt_mod .load_state_dict (st )
110+
111+ torch .testing .assert_close (
112+ new_trt_mod (inputs [0 ].cuda ()).cpu (),
113+ ref_output ,
114+ rtol = 1e-04 ,
115+ atol = 1e-04 ,
116+ check_dtype = False ,
117+ )
118+
119+
120+ class TestTRTModuleFloat64Input (TestCase ):
121+ def test_save_and_load_trt_module (self ):
122+ class TestModule (torch .nn .Module ):
123+ def forward (self , x ):
124+ return x + x
125+
126+ inputs = [torch .randn (5 , 5 ).double ()]
127+ mod = TestModule ().eval ()
128+ ref_output = mod (* inputs )
129+
130+ mod = acc_tracer .trace (mod , inputs )
131+ interp = TRTInterpreter (
132+ mod ,
133+ input_specs = InputTensorSpec .from_tensors (inputs ),
134+ truncate_long_and_double = True ,
135+ )
136+ trt_mod = TRTModule (* interp .run (lower_precision = LowerPrecision .FP32 ))
137+ torch .save (trt_mod , "trt.pt" )
138+ reload_trt_mod = torch .load ("trt.pt" )
139+
140+ torch .testing .assert_close (
141+ reload_trt_mod (inputs [0 ].cuda ()).cpu (),
142+ ref_output ,
143+ rtol = 1e-04 ,
144+ atol = 1e-04 ,
145+ check_dtype = False ,
146+ )
147+ os .remove (f"{ os .getcwd ()} /trt.pt" )
148+
149+ def test_save_and_load_state_dict (self ):
150+ class TestModule (torch .nn .Module ):
151+ def forward (self , x ):
152+ return x + x
153+
154+ inputs = [torch .randn (5 , 5 ).double ()]
155+ mod = TestModule ().eval ()
156+ ref_output = mod (* inputs )
157+
158+ mod = acc_tracer .trace (mod , inputs )
159+ interp = TRTInterpreter (
160+ mod ,
161+ input_specs = InputTensorSpec .from_tensors (inputs ),
162+ truncate_long_and_double = True ,
102163 )
103164 trt_mod = TRTModule (* interp .run (lower_precision = LowerPrecision .FP32 ))
104165 st = trt_mod .state_dict ()
0 commit comments