@@ -29,14 +29,14 @@ The following code illustrates this approach.
import torch_tensorrt

model = MyModel().eval().cuda()
- inputs = torch.randn((1, 3, 224, 224)).cuda()
+ inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
- trt_script_model = torch.jit.trace(trt_gm, inputs)
- torch.jit.save(trt_script_model, "trt_model.ts")
+ trt_traced_model = torch.jit.trace(trt_gm, inputs)
+ torch.jit.save(trt_traced_model, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
- model(inputs)
+ model(*inputs)

b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,16 +50,15 @@ b) ExportedProgram
import torch_tensorrt

model = MyModel().eval().cuda()
- inputs = torch.randn((1, 3, 224, 224)).cuda()
+ inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # Output is a torch.fx.GraphModule
# Transform and create an exported program
- trt_gm = torch_tensorrt.dynamo.export(trt_gm, inputs)
- trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
- torch._export.save(trt_exp_program, "trt_model.ep")
+ trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
+ torch.export.save(trt_exp_program, "trt_model.ep")

# Later, you can load it and run inference
- model = torch._export.load("trt_model.ep")
- model(inputs)
+ model = torch.export.load("trt_model.ep")
+ model(*inputs)

`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule into their corresponding nodes and stitches all the nodes together.
This is needed because `torch._export` serialization cannot handle serializing and deserializing submodules (`call_module` nodes).
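As a rough sketch of what this inlining means in practice (not part of the diff above; it assumes the `trt_model.ep` file saved in the example and the public `torch.export` / `torch.fx` APIs), you can reload the saved program and check that no `call_module` nodes remain in its graph:

    import torch

    # Load the ExportedProgram saved in the example above (hypothetical path).
    trt_exp_program = torch.export.load("trt_model.ep")

    # After torch_tensorrt.dynamo.export has inlined the submodules, the graph
    # should contain no call_module nodes, which is what makes it serializable.
    assert all(node.op != "call_module" for node in trt_exp_program.graph.nodes)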
@@ -79,11 +78,11 @@ Torchscript IR
import torch_tensorrt

model = MyModel().eval().cuda()
- inputs = torch.randn((1, 3, 224, 224)).cuda()
+ inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)  # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
- model(inputs)
+ model(*inputs)
