Skip to content

Commit 840ca89

Browse files
committed
fix(//py): Fix some api import issues
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 9982855 commit 840ca89

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __str__(self) -> str:
125125
return "Input(shape={}, dtype={}, format={})".format(self.shape, str(self.dtype), str(self.format))
126126
elif self.shape_mode == Input._ShapeMode.DYNAMIC:
127127
return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format(
128-
self.shape["min_shape"], self.shape["min_shape"], self.shape["min_shape"], str(self.dtype),
128+
self.shape["min_shape"], self.shape["opt_shape"], self.shape["max_shape"], str(self.dtype),
129129
str(self.format))
130130
else:
131131
raise RuntimeError("Unknown input shape mode")
@@ -145,14 +145,14 @@ def _to_internal(self) -> _C.Input:
145145
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
146146
+ str(type(self.shape["opt_shape"])) + " for opt_shape")
147147
else:
148-
internal_in.min = self.shape["op_shape"]
148+
internal_in.opt = self.shape["opt_shape"]
149149

150150
if not Input._supported_input_size_type(self.shape["max_shape"]):
151151
raise TypeError(
152152
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: "
153153
+ str(type(self.shape["max_shape"])) + " for max_shape")
154154
else:
155-
internal_in.min = self.shape["opt_shape"]
155+
internal_in.max = self.shape["max_shape"]
156156
internal_in.input_is_dynamic = True
157157
else:
158158
if not Input._supported_input_size_type(self.shape):

py/torch_tensorrt/_compile.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import torch_tensorrt.ts
44
from torch_tensorrt import logging
55
import torch
6+
from torch import fx
67
from enum import Enum
78

9+
print(torch.fx.GraphModule)
810

911
class _IRType(Enum):
1012
"""Enum to set the minimum required logging level to print a message to stdout
@@ -89,6 +91,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
8991
ts_mod = module
9092
if isinstance(module, torch.nn.Module):
9193
logging.log(
94+
logging.Level.Info,
9295
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
9396
)
9497
ts_mod = torch.jit.script(module)

0 commit comments

Comments
 (0)