66import torch_tensorrt as torchtrt
77import torchvision .models as models
88from torch ._export .serde .serialize import deserialize , serialize
9- from torch_tensorrt .dynamo .export import create_trt_exp_program , transform
109from torch_tensorrt .dynamo .utils import COSINE_THRESHOLD , cosine_similarity
1110
1211assertions = unittest .TestCase ()
@@ -45,9 +44,8 @@ def forward(self, x):
4544
4645 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
4746 trt_gm = torchtrt .dynamo .compile (exp_program , ** compile_spec )
48- trt_gm = transform (trt_gm , [input ])
49- trt_exp_program = create_trt_exp_program (
50- trt_gm , exp_program .call_spec , trt_gm .state_dict ()
47+ trt_exp_program = torchtrt .dynamo .serialize (
48+ trt_gm , [input ], call_spec = exp_program .call_spec , ir = "exported_program"
5149 )
5250 serialized_prog = serialize (trt_exp_program )
5351 deserialized_prog = deserialize (* serialized_prog )
@@ -100,11 +98,9 @@ def forward(self, x):
10098
10199 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
102100 trt_gm = torchtrt .dynamo .compile (exp_program , ** compile_spec )
103- trt_gm = transform (trt_gm , [input ])
104- trt_exp_program = create_trt_exp_program (
105- trt_gm , exp_program .call_spec , trt_gm .state_dict ()
101+ trt_exp_program = torchtrt .dynamo .serialize (
102+ trt_gm , [input ], call_spec = exp_program .call_spec , ir = "exported_program"
106103 )
107-
108104 serialized_prog = serialize (trt_exp_program )
109105 deserialized_prog = deserialize (* serialized_prog )
110106 # Check Pyt and TRT exported program outputs
@@ -161,11 +157,9 @@ def forward(self, x):
161157
162158 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
163159 trt_gm = torchtrt .dynamo .compile (exp_program , ** compile_spec )
164- trt_gm = transform (trt_gm , [input ])
165- trt_exp_program = create_trt_exp_program (
166- trt_gm , exp_program .call_spec , trt_gm .state_dict ()
160+ trt_exp_program = torchtrt .dynamo .serialize (
161+ trt_gm , [input ], call_spec = exp_program .call_spec , ir = "exported_program"
167162 )
168-
169163 torch ._export .save (trt_exp_program , "/tmp/trt.ep" )
170164 deser_trt_exp_program = torch ._export .load ("/tmp/trt.ep" )
171165
@@ -224,11 +218,9 @@ def forward(self, x):
224218
225219 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
226220 trt_gm = torchtrt .dynamo .compile (exp_program , ** compile_spec )
227- trt_gm = transform (trt_gm , [input ])
228- trt_exp_program = create_trt_exp_program (
229- trt_gm , exp_program .call_spec , trt_gm .state_dict ()
221+ trt_exp_program = torchtrt .dynamo .serialize (
222+ trt_gm , [input ], call_spec = exp_program .call_spec , ir = "exported_program"
230223 )
231-
232224 torch ._export .save (trt_exp_program , "/tmp/trt.ep" )
233225 deser_trt_exp_program = torch ._export .load ("/tmp/trt.ep" )
234226
@@ -270,9 +262,8 @@ def test_resnet18_save_load(ir):
270262
271263 exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
272264 trt_gm = torchtrt .dynamo .compile (exp_program , ** compile_spec )
273- trt_gm = transform (trt_gm , [input ])
274- trt_exp_program = create_trt_exp_program (
275- trt_gm , exp_program .call_spec , trt_gm .state_dict ()
265+ trt_exp_program = torchtrt .dynamo .serialize (
266+ trt_gm , [input ], call_spec = exp_program .call_spec , ir = "exported_program"
276267 )
277268 torch ._export .save (trt_exp_program , "/tmp/trt.ep" )
278269 deser_trt_exp_program = torch ._export .load ("/tmp/trt.ep" )
@@ -291,59 +282,3 @@ def test_resnet18_save_load(ir):
291282 cos_sim > COSINE_THRESHOLD ,
292283 msg = f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
293284 )
294-
295-
296- # Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
297- # @pytest.mark.unit
298- # def test_hybrid_conv_fallback(ir):
299- # """
300- # This tests export save and load functionality on a hybrid
301- # model where a conv (a weighted layer) has been forced to fallback to Pytorch.
302- # """
303-
304- # class MyModule(torch.nn.Module):
305- # def __init__(self):
306- # super().__init__()
307- # self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
308- # self.relu = torch.nn.ReLU()
309-
310- # def forward(self, x):
311- # conv = self.conv(x)
312- # relu = self.relu(conv)
313- # mul = relu * 0.5
314- # return mul
315-
316- # model = MyModule().eval().cuda()
317- # input = torch.randn((1, 3, 224, 224)).to("cuda")
318-
319- # compile_spec = {
320- # "inputs": [
321- # torchtrt.Input(
322- # input.shape, dtype=torch.float, format=torch.contiguous_format
323- # )
324- # ],
325- # "ir": ir,
326- # "min_block_size": 1,
327- # "torch_executed_ops": "torch.ops.aten.convolution.default",
328- # }
329-
330- # trt_exp_program = torchtrt.compile(model, **compile_spec)
331- # torch._export.save(trt_exp_program, "/tmp/trt.ep")
332- # deser_trt_exp_program = torch._export.load("/tmp/trt.ep")
333-
334- # outputs_pyt = model(input)
335- # outputs_trt = trt_exp_program(input)
336- # for idx in range(len(outputs_pyt)):
337- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
338- # assertions.assertTrue(
339- # cos_sim > COSINE_THRESHOLD,
340- # msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
341- # )
342-
343- # outputs_trt_deser = deser_trt_exp_program(input)
344- # for idx in range(len(outputs_pyt)):
345- # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
346- # assertions.assertTrue(
347- # cos_sim > COSINE_THRESHOLD,
348- # msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
349- # )
0 commit comments