We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 123a486 commit 960b372Copy full SHA for 960b372
py/torch_tensorrt/dynamo/compile.py
@@ -85,8 +85,10 @@ def compile(
85
# Prepare torch_trt inputs
86
inputs = prepare_inputs(inputs)
87
device = to_torch_tensorrt_device(device)
88
-
89
- gm = exported_program.module()
+ if isinstance(exported_program, torch.fx.GraphModule):
+ gm = exported_program
90
+ else:
91
+ gm = exported_program.module()
92
logger.debug("Input graph: " + str(gm.graph))
93
94
# Apply lowering on the graph module
0 commit comments