Bug Description
Compiling the graph below fails with the following error:
ValueError: stoi
It's not clear to me what the root cause is yet; I've seen this appear in several of our models internally, but this is the first time I've been able to isolate the error outside of our repo.
To Reproduce
```python
import torch
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging

logging.set_reportable_log_level(logging.Level.Warning)
torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 2)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a):
        # Squeeze the last dimension, then subtract the result from itself
        y = torch.squeeze(a, -1)
        return y - y


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
    model = Model().eval().to(DEVICE)
    out = model(tensor)
    print(f"Model: {out}")

    model_trt = torchtrt.compile(
        model,
        inputs=[
            torchtrt.Input(shape=SHAPE),
        ],
        enabled_precisions={torch.float},
    )
    # Run the compiled module (not the eager one) so the comparison is meaningful
    out_trt = model_trt(tensor)
    print(f"Model TRT: {out_trt}")
    assert torch.max(torch.abs(out - out_trt)) < 1e-6
```
Running it produces the following output:
```
root@1ccd9cb0f739:/workspace# python /scripts/squeeze.py
Model: tensor([[0., 0.]], device='cuda:0')
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input a.1. Assuming it is Float32. If not, specify input type explicity
Traceback (most recent call last):
  File "/scripts/squeeze.py", line 31, in <module>
    model_trt = torchtrt.compile(
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py", line 116, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
ValueError: stoi
```
Expected behavior
Compilation should succeed, and the compiled module should produce the same values as the eager (non-TRT) module.
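The reproducer's final `assert` encodes this criterion: the maximum element-wise absolute difference between the eager and compiled outputs must stay below `1e-6`. As a torch-free sketch of that same check, assuming nested lists stand in for tensors (the helper `max_abs_diff` is hypothetical, not part of Torch-TensorRT):

```python
def max_abs_diff(a, b):
    """Return the largest element-wise |a - b| for two equally shaped nested lists."""
    if isinstance(a, (list, tuple)):
        if not isinstance(b, (list, tuple)) or len(a) != len(b):
            raise ValueError("outputs have different shapes")
        # Recurse into each row/element and keep the worst-case difference
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(a - b)


eager_out = [[0.0, 0.0]]     # eager output printed above for SHAPE = (1, 2)
compiled_out = [[0.0, 0.0]]  # hypothetical: what a working TRT build should return
assert max_abs_diff(eager_out, compiled_out) < 1e-6
```

Since `forward` returns `y - y`, any correct build should return all zeros with the same shape as the eager output.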
Environment
Ubuntu 18.04 x86-64
Built from source:
- Master (11bcb98d3cd680c3c34e6cc4c4efdc7512c144cc) using the TRT NGC 22.02-py3 container
- v1.0 using the TRT NGC 21.11-py3 container
Error appears in both.
Additional context
In this particular case, we can work around the error by adding a trailing dimension of size 1 to the input, so that `torch.squeeze(a, -1)` actually has a dimension to remove, i.e., `SHAPE = (1, 2, 1)`.
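PyTorch documents that `torch.squeeze(input, dim)` removes `dim` only when its size is 1 and otherwise returns the input unchanged. For `SHAPE = (1, 2)` the squeeze in `forward` is therefore a no-op, while for `(1, 2, 1)` it drops the trailing dimension. The sketch below is a hypothetical, torch-free helper `squeeze_shape` that handles only a single integer `dim` and reproduces that shape rule. Whether the no-op case is what trips the conversion is a guess; the report does not establish it.

```python
def squeeze_shape(shape, dim):
    """Shape after squeezing `dim`, following torch.squeeze's documented rule:
    the dimension is removed only if its size is 1 (single int `dim` only)."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d shape")
    dim %= ndim  # normalize negative indices such as -1
    if shape[dim] != 1:
        return tuple(shape)  # no-op: PyTorch leaves the tensor unchanged
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeeze_shape((1, 2), -1))     # (1, 2): nothing removed, the failing case
print(squeeze_shape((1, 2, 1), -1))  # (1, 2): trailing 1 removed, the working case
```

Both inputs end up with shape `(1, 2)` after the squeeze; the difference is only whether the squeeze does any work.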