Skip to content

Commit 98eb03f

Browse files
committed
chore: fix docs for export
Signed-off-by: Dheeraj Peri <[email protected]> chore: updates Signed-off-by: Dheeraj Peri <[email protected]> chore: updates Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 88f6812 commit 98eb03f

File tree

3 files changed

+13
-16
lines changed

3 files changed

+13
-16
lines changed

docsrc/dynamo/dynamo_export.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. _dynamo_export:
22

3-
Compiling ``ExportedPrograms`` with Torch-TensorRT
3+
Compiling Exported Programs with Torch-TensorRT
44
=============================================
55
.. currentmodule:: torch_tensorrt.dynamo
66

@@ -9,8 +9,6 @@ Compiling ``ExportedPrograms`` with Torch-TensorRT
99
:undoc-members:
1010
:show-inheritance:
1111

12-
Using the Torch-TensorRT Frontend for ``torch.export.ExportedPrograms``
13-
--------------------------------------------------------
1412
Pytorch 2.1 introduced ``torch.export`` APIs which
1513
can export graphs from Pytorch programs into ``ExportedProgram`` objects. Torch-TensorRT dynamo
1614
frontend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here's a simple

docsrc/user_guide/saving_models.rst

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ The following code illustrates this approach.
2929
import torch_tensorrt
3030
3131
model = MyModel().eval().cuda()
32-
inputs = torch.randn((1, 3, 224, 224)).cuda()
32+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
3333
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
34-
trt_script_model = torch.jit.trace(trt_gm, inputs)
35-
torch.jit.save(trt_script_model, "trt_model.ts")
34+
trt_traced_model = torch.jit.trace(trt_gm, inputs)
35+
torch.jit.save(trt_traced_model, "trt_model.ts")
3636
3737
# Later, you can load it and run inference
3838
model = torch.jit.load("trt_model.ts").cuda()
39-
model(inputs)
39+
model(*inputs)
4040
4141
b) ExportedProgram
4242
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,16 +50,15 @@ b) ExportedProgram
5050
import torch_tensorrt
5151
5252
model = MyModel().eval().cuda()
53-
inputs = torch.randn((1, 3, 224, 224)).cuda()
53+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
5454
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
5555
# Transform and create an exported program
56-
trt_gm = torch_tensorrt.dynamo.export(trt_gm, inputs)
57-
trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
58-
torch._export.save(trt_exp_program, "trt_model.ep")
56+
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
57+
torch.export.save(trt_exp_program, "trt_model.ep")
5958
6059
# Later, you can load it and run inference
61-
model = torch._export.load("trt_model.ep")
62-
model(inputs)
60+
model = torch.export.load("trt_model.ep")
61+
model(*inputs)
6362
6463
`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
6564
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
@@ -79,11 +78,11 @@ Torchscript IR
7978
import torch_tensorrt
8079
8180
model = MyModel().eval().cuda()
82-
inputs = torch.randn((1, 3, 224, 224)).cuda()
81+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
8382
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
8483
torch.jit.save(trt_ts, "trt_model.ts")
8584
8685
# Later, you can load it and run inference
8786
model = torch.jit.load("trt_model.ts").cuda()
88-
model(inputs)
87+
model(*inputs)
8988

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def export(
5656
return exp_program
5757
else:
5858
raise ValueError(
59-
"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
59+
f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
6060
)
6161

6262

0 commit comments

Comments
 (0)