Bug Description
The batch norm converter (core/conversion/converters/impl/batch_norm.cpp) does not support 2D input tensors.
This prevents conversion of torch.nn.BatchNorm1d, for which 2D input tensors are valid (inputs may be (N, C) or (N, C, L)).
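In eval mode, BatchNorm1d normalizes each channel c with its running statistics, y = (x - running_mean[c]) / sqrt(running_var[c] + eps) * weight[c] + bias[c], whether or not a length dimension L is present. The dependency-free sketch below uses plain Python lists in place of tensors, and its helper names are hypothetical (not TRTorch or PyTorch APIs). It shows that an (N, C) input gives exactly the same result as the same data laid out as (N, C, 1), which suggests a converter could handle the 2D case as the L = 1 case.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def bn_eval(x: float, c: int, mean: Vec, var: Vec, weight: Vec, bias: Vec, eps: float) -> float:
    """Eval-mode batch norm for one element of channel c (uses running statistics)."""
    return (x - mean[c]) / math.sqrt(var[c] + eps) * weight[c] + bias[c]


def bn_eval_nc(x: List[List[float]], mean: Vec, var: Vec, weight: Vec, bias: Vec,
               eps: float = 1e-5) -> List[List[float]]:
    """Input laid out as (N, C): one value per (sample, channel)."""
    return [[bn_eval(v, c, mean, var, weight, bias, eps) for c, v in enumerate(row)]
            for row in x]


def bn_eval_ncl(x: List[List[List[float]]], mean: Vec, var: Vec, weight: Vec, bias: Vec,
                eps: float = 1e-5) -> List[List[List[float]]]:
    """Input laid out as (N, C, L): the same per-channel map applied along L."""
    return [[[bn_eval(v, c, mean, var, weight, bias, eps) for v in seq]
             for c, seq in enumerate(sample)]
            for sample in x]


if __name__ == "__main__":
    # Illustrative statistics and affine parameters for C = 3 channels.
    mean, var = [0.1, -0.2, 0.3], [1.0, 0.5, 2.0]
    weight, bias = [1.0, 2.0, 0.5], [0.0, 0.1, -0.1]

    x_nc = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]   # shape (N=2, C=3)
    x_ncl = [[[v] for v in row] for row in x_nc]    # same data as (N=2, C=3, L=1)

    y_nc = bn_eval_nc(x_nc, mean, var, weight, bias)
    y_ncl = bn_eval_ncl(x_ncl, mean, var, weight, bias)

    # Squeezing the unit L dimension of the 3D result recovers the 2D result exactly.
    print(y_nc == [[seq[0] for seq in sample] for sample in y_ncl])  # True
```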
To Reproduce
Steps to reproduce the behavior:
import logging
import torch
import trtorch

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 5)  # (N, C): a 2D input, which BatchNorm1d accepts


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(5)

    def forward(self, a):
        return self.bn(a)


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
    model = Model().eval().to(DEVICE)
    model = torch.jit.script(model)
    model = trtorch.compile(model, {
        "inputs": [trtorch.Input(SHAPE)],
        "enabled_precisions": {torch.float},
    })
    out = model(tensor)
    print(out)
Output:
Traceback (most recent call last):
  File "test_torch2trt.py", line 31, in <module>
    "enabled_precisions": {torch.float},
  File "/home/chaoz/.local/lib/python3.6/site-packages/trtorch/_compiler.py", line 81, in compile
    compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_compile_spec(compile_spec))
RuntimeError: [Error thrown at core/conversion/converters/impl/batch_norm.cpp:74] Expected orig_shape.nbDims > 2 to be true but got false
Unable to create batch normalization layer from node: %16 : Tensor = aten::batch_norm(%a.1, %self.bn.running_var, %self.bn.running_mean, %self.bn.running_mean, %self.bn.running_var, %self.bn.training, %3, %2, %7) # /home/chaoz/.local/lib/python3.6/site-packages/torch/nn/functional.py:2149:11
Expected behavior
Conversion of torch.nn.BatchNorm1d should succeed for 2D (N, C) inputs.
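As a sanity check for the reproduction above: a freshly constructed BatchNorm1d(5) has PyTorch's default parameters (running_mean = 0, running_var = 1, weight = 1, bias = 0, eps = 1e-5), so in eval mode a correct conversion should return the input scaled by 1 / sqrt(1 + 1e-5), roughly 0.999995. The standard-library sketch below computes that reference output; `expected_bn_output` and `outputs_match` are hypothetical helpers, and the tolerances are an assumption chosen for FP32, not values taken from TRTorch.

```python
import math
from typing import List, Sequence

# PyTorch defaults for a freshly constructed BatchNorm1d; eval mode uses the running stats.
RUNNING_MEAN, RUNNING_VAR, WEIGHT, BIAS, EPS = 0.0, 1.0, 1.0, 0.0, 1e-5


def expected_bn_output(x: Sequence[float]) -> List[float]:
    """Reference eval-mode output for one (C,) row under the default parameters."""
    scale = WEIGHT / math.sqrt(RUNNING_VAR + EPS)
    return [(v - RUNNING_MEAN) * scale + BIAS for v in x]


def outputs_match(actual: Sequence[float], expected: Sequence[float],
                  rel_tol: float = 1e-5, abs_tol: float = 1e-6) -> bool:
    """Hypothetical element-wise closeness check with FP32-friendly tolerances."""
    return len(actual) == len(expected) and all(
        math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, e in zip(actual, expected))


if __name__ == "__main__":
    row = [0.25, -1.0, 2.0, 0.0, 3.5]  # one illustrative (N=1, C=5) sample
    print(expected_bn_output(row))
    print(1.0 / math.sqrt(1.0 + EPS))  # the scale factor, roughly 0.999995
```

Comparing against an analytic reference like this avoids depending on a second run of the model when checking the converted output.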
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 0.4.1
- PyTorch Version (e.g. 1.0): 1.8.1
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 18.04 LTS
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): python3 setup.py install --user
- Are you using local sources or building from archives: local sources
- Python version: Python 3.6
- CUDA version: 11.0
Additional context
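The error is raised by the shape check at core/conversion/converters/impl/batch_norm.cpp:74 (orig_shape.nbDims > 2). Changing the > to >= should resolve the issue; with that change, the subsequent output also appears to be as expected.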
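The effect of the proposed one-character change can be illustrated with a Python mirror of the guard. The function names below are hypothetical; the real check is the C++ assertion on orig_shape.nbDims in batch_norm.cpp. With `>`, an (N, C) input (nbDims == 2) is rejected, while `>=` accepts it and still accepts the (N, C, L) and (N, C, H, W) cases. This sketch only models the guard, not the layer the converter builds afterwards.

```python
from typing import Tuple


def passes_original_check(shape: Tuple[int, ...]) -> bool:
    # Mirrors the failing assertion: orig_shape.nbDims > 2
    return len(shape) > 2


def passes_relaxed_check(shape: Tuple[int, ...]) -> bool:
    # Proposed change: orig_shape.nbDims >= 2, which admits BatchNorm1d's (N, C) input
    return len(shape) >= 2


if __name__ == "__main__":
    for shape in [(1, 5), (1, 5, 7), (1, 5, 7, 7)]:
        print(shape, passes_original_check(shape), passes_relaxed_check(shape))
    # (1, 5) False True
    # (1, 5, 7) True True
    # (1, 5, 7, 7) True True
```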