Bug Description
The fallback for torch.nn.functional.one_hot, whether automatic or forced, appears to fail with the following message:
WARNING: [Torch-TensorRT] - Input type for doing shape analysis could not be determined, defaulting to F32
Traceback (most recent call last):
  File "/home/chaoz/av/experimental/chaoz/examples/test_trtorch.py", line 32, in <module>
    model_trt = torchtrt.compile(
  File "/home/chaoz/.anaconda3/envs/trt-8/lib/python3.9/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/home/chaoz/.anaconda3/envs/trt-8/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/chaoz/av/experimental/chaoz/examples/test_trtorch.py", line 21, in forward
    def forward(self, a):
        return torch.nn.functional.one_hot(a)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: one_hot is only applicable to index tensor.
It appears that floating-point values are passed to one_hot during compilation's shape analysis (note the F32 default in the warning above), which fails because one_hot only accepts integer index tensors.
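For reference, here is a minimal standalone sketch (plain PyTorch, no Torch-TensorRT) of the underlying constraint: one_hot accepts integer index tensors and raises the same error on floating-point input, which is presumably what the shape-analysis pass produces once the input type defaults to F32.

import torch
import torch.nn.functional as F

idx = torch.arange(4)        # int64 index tensor
print(F.one_hot(idx))        # works: 4x4 one-hot matrix

try:
    F.one_hot(idx.float())   # same values, but float32
except RuntimeError as e:
    print(e)                 # "one_hot is only applicable to index tensor."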
To Reproduce
Run the following:
import torch
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging

logging.set_reportable_log_level(logging.Level.Warning)

torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (10,)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a):
        return torch.nn.functional.one_hot(a)


if __name__ == "__main__":
    tensor = torch.ones(SHAPE, dtype=torch.int32, device=DEVICE)

    with torch.no_grad():
        model = Model().eval().to(DEVICE)
        model_trt = torchtrt.compile(
            model,
            inputs=[
                torchtrt.Input(shape=SHAPE, dtype=torch.int32),
            ],
            enabled_precisions={torch.float},
            # Force aten::one_hot to fall back to Torch execution.
            torch_executed_ops=["aten::one_hot"],
        )
        out_trt = model_trt(tensor)
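A possible workaround, assuming the failure comes only from shape analysis feeding a floating-point tensor to one_hot (not verified against this build), is to cast explicitly inside forward so one_hot always receives an integer tensor:

    def forward(self, a):
        # Assumed workaround: force an integer dtype before one_hot so the
        # F32 default from shape analysis never reaches the op.
        return torch.nn.functional.one_hot(a.to(torch.int64))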
Expected behavior
The script above is expected to compile without issues.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.0.0
- PyTorch Version (e.g. 1.0): 1.10
- CPU Architecture: x86-64
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, libtorch, source): conda
- Build command you used (if compiling from source): python setup.py install
- Are you using local sources or building from archives: local
- Python version: 3.9
- CUDA version: 11.4
- GPU models and configuration: T4
- Any other relevant information: