
🐛 [Bug] inception_v3 pretrained compilation - Unsupported ATen data type Double #1096

Closed

@apivovarov

Bug Description

Compiling a traced, pretrained inception_v3 model with Torch-TensorRT fails with "Unsupported ATen data type Double".

To Reproduce

Steps to reproduce the behavior:

import torch
import torchvision.models as models
import torch_tensorrt

# Get the pretrained inception_v3 model. pretrained=True also implicitly
# sets transform_input=True, which causes the Double-type issue during
# compilation.
model = models.inception_v3(pretrained=True).eval()
x = torch.rand(1, 3, 299, 299)
y = model(x)
tmodel = torch.jit.trace(model, x)
trt_model = torch_tensorrt.compile(
    tmodel,
    inputs=[torch_tensorrt.Input((1, 3, 299, 299))],
)

Error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 116, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/util/trt_util.cpp:283] Expected type to be true but got false
Unsupported ATen data type Double
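
The offending constants come from Inception3._transform_input (details below). As a diagnostic aid, the Double constants can be located in the trace using only public TorchScript graph APIs; find_double_constants is a hypothetical helper written for this sketch, not part of Torch-TensorRT:

import torch
import torchvision.models as models

# Hypothetical helper: list the prim::Constant nodes whose output type
# mentions Double, i.e. the constants the TRT converter rejects.
def find_double_constants(traced: torch.jit.ScriptModule):
    return [
        n for n in traced.graph.nodes()
        if n.kind() == "prim::Constant"
        and any("Double" in str(o.type()) for o in n.outputs())
    ]

model = models.inception_v3(pretrained=True).eval()
tmodel = torch.jit.trace(model, torch.rand(1, 3, 299, 299))
for node in find_double_constants(tmodel):
    print(node)  # e.g. %1512 : Double(...) = prim::Constant[value={0.458}]()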

Environment

I use the latest NVIDIA PyTorch Docker image, nvcr.io/nvidia/pytorch:22.04-py3:

docker run -ti --gpus all \
--ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
nvcr.io/nvidia/pytorch:22.04-py3
>>> import torch, torch_tensorrt
>>> torch.__version__
'1.12.0a0+bd13bc6'
>>> torch_tensorrt.__version__
'1.1.0a0'

Additional context

To work around the issue, I have to replace the existing Inception model method _transform_input() with a fixed one.

Existing _transform_input() method:

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

Fixed _transform_input(), where I explicitly wrap the constants in float32 tensors. It works, but PyTorch developers usually do not write model code this way:

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            a0 = torch.tensor(0.229 / 0.5, dtype=torch.float32)
            a1 = torch.tensor(0.224 / 0.5, dtype=torch.float32)
            a2 = torch.tensor(0.225 / 0.5, dtype=torch.float32)
            b0 = torch.tensor((0.485 - 0.5) / 0.5, dtype=torch.float32)
            b1 = torch.tensor((0.456 - 0.5) / 0.5, dtype=torch.float32)
            b2 = torch.tensor((0.406 - 0.5) / 0.5, dtype=torch.float32)
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * a0 + b0
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * a1 + b1
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * a2 + b2
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x
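
An alternative that avoids patching torchvision at all, sketched under the assumption that torchvision honors an explicit transform_input=False together with pretrained=True: disable the internal transform and perform the equivalent rescaling in float32 outside the model, so no Double constants enter the trace. transform_input_fp32 is a hypothetical helper, not a torchvision API:

import torch
import torchvision.models as models

# Assumption: an explicit transform_input=False overrides the implicit
# True that pretrained=True would otherwise set.
model = models.inception_v3(pretrained=True, transform_input=False).eval()

def transform_input_fp32(x: torch.Tensor) -> torch.Tensor:
    # Same per-channel affine rescaling as Inception3._transform_input,
    # with the constants pre-built as float32 tensors.
    scale = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32) / 0.5
    shift = (torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32) - 0.5) / 0.5
    return x * scale.view(1, 3, 1, 1) + shift.view(1, 3, 1, 1)

x = torch.rand(1, 3, 299, 299)
tmodel = torch.jit.trace(model, transform_input_fp32(x))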

Is it possible to ask torch.jit.trace to automatically wrap constants with float32 tensors?
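
For reference, a minimal standalone reproduction of the tracer behavior, assumed to match this build based on the graph dumps below: a bare Python float in a tensor expression is recorded as a 0-dim Double constant in the traced graph.

import torch

def scale(x: torch.Tensor) -> torch.Tensor:
    # Python float literal: the tracer records it as a Double tensor constant.
    return x * (0.229 / 0.5)

traced = torch.jit.trace(scale, torch.rand(1, 3))
print(traced.graph)
# Expected to contain something like:
#   %1 : Double(requires_grad=0, device=cpu) = prim::Constant[value={0.458}]()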

tmodel.graph shows Double for the original model:

  %1512 : Double(requires_grad=0, device=cpu) = prim::Constant[value={0.458}]()

For the fixed model (with the fixed _transform_input() method) it shows Float, which the TRT compiler accepts:

 %1502 : Float(requires_grad=0, device=cpu) = prim::Constant[value={0.458}]()
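
Separately, rather than changing how the trace is produced, the demotion can be requested at compile time. A sketch, assuming the installed build exposes the truncate_long_and_double compile option (present in recent Torch-TensorRT releases): it truncates Double constants and weights to Float (and Long to Int) during conversion.

import torch_tensorrt

trt_model = torch_tensorrt.compile(
    tmodel,
    inputs=[torch_tensorrt.Input((1, 3, 299, 299))],
    # Assumption: this flag is available in this build.
    truncate_long_and_double=True,
)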
