🐛 [Bug] Fallback for torch.nn.functional.one_hot fails #814

Closed
@chaoz-dev

Bug Description

Fallback for torch.nn.functional.one_hot, whether automatic or forced, appears to fail with the following message:

WARNING: [Torch-TensorRT] - Input type for doing shape analysis could not be determined, defaulting to F32
Traceback (most recent call last):
  File "/home/chaoz/av/experimental/chaoz/examples/test_trtorch.py", line 32, in <module>
    model_trt = torchtrt.compile(
  File "/home/chaoz/.anaconda3/envs/trt-8/lib/python3.9/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/home/chaoz/.anaconda3/envs/trt-8/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/chaoz/av/experimental/chaoz/examples/test_trtorch.py", line 21, in forward
    def forward(self, a):
        return torch.nn.functional.one_hot(a)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: one_hot is only applicable to index tensor.

It appears that floating-point values are passed to one_hot during compilation (the warning above shows the input type for shape analysis defaulting to F32). This fails because one_hot accepts only integer index tensors.
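The dtype restriction can be checked in eager PyTorch, without Torch-TensorRT or a GPU. The sketch below is only an illustration of that restriction, not a reproduction of the compile-time failure; `try_one_hot` is a hypothetical helper introduced here, and the float32 tensor merely stands in for the F32 default mentioned in the warning.

```python
import torch
import torch.nn.functional as F


def try_one_hot(indices: torch.Tensor) -> str:
    """Return the one_hot output shape, or the error message if the call fails."""
    try:
        return f"ok: shape {tuple(F.one_hot(indices).shape)}"
    except RuntimeError as err:
        return f"RuntimeError: {err}"


# int64 (long) indices are the documented input type for one_hot.
print(try_one_hot(torch.ones(10, dtype=torch.int64)))

# float32 stands in for the dtype that shape analysis falls back to (assumption).
print(try_one_hot(torch.ones(10, dtype=torch.float32)))
```

The PyTorch documentation describes the input as a LongTensor, so it may also be worth checking whether the int32 input used in the repro is accepted in eager mode.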

To Reproduce

Run the following:

import torch
import torch_tensorrt as torchtrt

import torch_tensorrt.logging as logging
logging.set_reportable_log_level(logging.Level.Warning)

torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (10,)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a):
        return torch.nn.functional.one_hot(a)


if __name__ == "__main__":
    tensor = torch.ones(SHAPE, dtype=torch.int32, device=DEVICE)

    with torch.no_grad():
        model = Model().eval().to(DEVICE)

        model_trt = torchtrt.compile(
            model,
            inputs=[
                torchtrt.Input(shape=SHAPE, dtype=torch.int32),
            ],
            enabled_precisions={torch.float},
            torch_executed_ops=["aten::one_hot"],
        )
        out_trt = model_trt(tensor)

Expected behavior

Expect the above to compile without issues.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.0.0
  • PyTorch Version (e.g. 1.0): 1.10
  • CPU Architecture: x86-64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): conda
  • Build command you used (if compiling from source): python setup.py install
  • Are you using local sources or building from archives: local
  • Python version: 3.9
  • CUDA version: 11.4
  • GPU models and configuration: T4
  • Any other relevant information:

Additional context

Labels: bug
