@@ -29,14 +29,14 @@ The following code illustrates this approach.
2929 import torch_tensorrt
3030
3131 model = MyModel().eval().cuda()
32- inputs = torch.randn((1 , 3 , 224 , 224 )).cuda()
32+ inputs = [ torch.randn((1 , 3 , 224 , 224 )).cuda()]
3333 trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
3434 trt_script_model = torch.jit.trace(trt_gm, inputs)
3535 torch.jit.save(trt_script_model, " trt_model.ts" )
3636
3737 # Later, you can load it and run inference
3838 model = torch.jit.load(" trt_model.ts" ).cuda()
39- model(inputs)
39+ model(* inputs)
4040
4141 b) ExportedProgram
4242^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,15 +50,15 @@ b) ExportedProgram
5050 import torch_tensorrt
5151
5252 model = MyModel().eval().cuda()
53- inputs = torch.randn((1 , 3 , 224 , 224 )).cuda()
53+ inputs = [ torch.randn((1 , 3 , 224 , 224 )).cuda()]
5454 trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
5555 # Transform and create an exported program
5656 trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
5757 torch.export.save(trt_exp_program, " trt_model.ep" )
5858
5959 # Later, you can load it and run inference
6060 model = torch.export.load(" trt_model.ep" )
61- model(inputs)
61+ model(* inputs)
6262
6363 `torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
6464This is needed as `torch._export` serialization cannot handle serializing and deserializing submodules (`call_module` nodes).
@@ -78,11 +78,11 @@ Torchscript IR
7878 import torch_tensorrt
7979
8080 model = MyModel().eval().cuda()
81- inputs = torch.randn((1 , 3 , 224 , 224 )).cuda()
81+ inputs = [ torch.randn((1 , 3 , 224 , 224 )).cuda()]
8282 trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
8383 torch.jit.save(trt_ts, " trt_model.ts" )
8484
8585 # Later, you can load it and run inference
8686 model = torch.jit.load(" trt_model.ts" ).cuda()
87- model(inputs)
87+ model(* inputs)
8888
0 commit comments