
🐛 [Bug] BatchNorm converter does not support 2D tensors #692

Closed
@chaoz-dev

Description


Bug Description

The batch norm converter (core/conversion/converters/impl/batch_norm.cpp) does not support 2D input tensors.
This prevents conversion of torch.nn.BatchNorm1d, for which 2D inputs are valid (inputs may be (N, C) or (N, C, L)).
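To show why a 2D input is legitimate, here is a minimal pure-Python sketch of inference-mode batch norm (nested lists stand in for tensors; the function name batch_norm_eval is illustrative, not a PyTorch or TRTorch API). Normalization is always per channel along axis 1, so an (N, C) input needs no spatial dimension at all; an (N, C, L) input simply applies the same per-channel scale and shift at every position along L.

```python
import math


def batch_norm_eval(x, running_mean, running_var, weight, bias, eps=1e-5):
    """Inference-mode batch norm over nested lists shaped (N, C) or (N, C, L).

    The channel axis is always axis 1, so an (N, C) input needs no spatial axis.
    """
    out = []
    for sample in x:  # loop over the batch dimension N
        normalized = []
        for c, channel in enumerate(sample):  # loop over the channel dimension C
            # Fold (x - mean) / sqrt(var + eps) * weight + bias into x * scale + shift.
            scale = weight[c] / math.sqrt(running_var[c] + eps)
            shift = bias[c] - running_mean[c] * scale
            if isinstance(channel, list):
                # (N, C, L): every position along L shares channel c's statistics.
                normalized.append([v * scale + shift for v in channel])
            else:
                # (N, C): a single scalar per channel.
                normalized.append(channel * scale + shift)
        out.append(normalized)
    return out


# Freshly initialized BatchNorm1d(5) in eval mode: weight 1, bias 0,
# running mean 0, running var 1, eps 1e-5 (PyTorch defaults).
C = 5
x = [[1.0, -2.0, 0.5, 3.0, 0.0]]  # shape (1, 5), like SHAPE in the reproduction script
y = batch_norm_eval(x, [0.0] * C, [1.0] * C, [1.0] * C, [0.0] * C)
print(y)  # each value is approximately x / sqrt(1 + 1e-5)
```

With default parameters the module is therefore close to an identity, which gives a quick sanity check for the converted model's output on the reproduction below.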

To Reproduce

Steps to reproduce the behavior:

import logging
import torch
import trtorch

logging.basicConfig(level=logging.INFO)
torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 5)  # (N, C): a 2D input, which BatchNorm1d accepts

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(5)

    def forward(self, a):
        return self.bn(a)


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

    model = Model().eval().to(DEVICE)
    model = torch.jit.script(model)
    model = trtorch.compile(model, {
        "inputs": [trtorch.Input(SHAPE)],
        "enabled_precisions": {torch.float},
    })
    out = model(tensor)
    print(out)

Output:

Traceback (most recent call last):
  File "test_torch2trt.py", line 31, in <module>
    "enabled_precisions": {torch.float},
  File "/home/chaoz/.local/lib/python3.6/site-packages/trtorch/_compiler.py", line 81, in compile
    compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_compile_spec(compile_spec))
RuntimeError: [Error thrown at core/conversion/converters/impl/batch_norm.cpp:74] Expected orig_shape.nbDims > 2 to be true but got false
Unable to create batch normalization layer from node: %16 : Tensor = aten::batch_norm(%a.1, %self.bn.running_var, %self.bn.running_mean, %self.bn.running_mean, %self.bn.running_var, %self.bn.training, %3, %2, %7) # /home/chaoz/.local/lib/python3.6/site-packages/torch/nn/functional.py:2149:11

Expected behavior

Expect the conversion to succeed for torch.nn.BatchNorm1d.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages.

  • Torch-TensorRT Version (e.g. 1.0.0): 0.4.1
  • PyTorch Version (e.g. 1.0): 1.8.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 18.04 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): python3 setup.py install --user
  • Are you using local sources or building from archives: local sources
  • Python version: Python 3.6
  • CUDA version: 11.0

Additional context

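The error is caused by the rank check at core/conversion/converters/impl/batch_norm.cpp:74 (orig_shape.nbDims > 2). Changing the > to >= should resolve the issue; with that change, the subsequent output also appears to be as expected.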
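The effect of the one-character fix can be illustrated with a small Python sketch. The two check functions mirror the condition from the error message and the proposed >= replacement. The padding step is a HYPOTHETICAL assumption, not taken from the TRTorch source: it supposes that, after the rank check, the converter appends trailing singleton dimensions until the input is 4D (N, C, H, W) so a per-channel scale can be applied. All helper names are illustrative and are not TRTorch functions.

```python
def original_rank_check(dims):
    # Mirrors the failing check from the error message: orig_shape.nbDims > 2
    return len(dims) > 2


def relaxed_rank_check(dims):
    # Proposed fix from this issue: nbDims >= 2, so an (N, C) input is accepted
    return len(dims) >= 2


def pad_to_4d(dims):
    # HYPOTHETICAL: append trailing singleton dims until the shape is (N, C, H, W)
    if not relaxed_rank_check(dims):
        raise ValueError(f"batch norm needs at least (N, C), got {len(dims)} dim(s)")
    return tuple(dims) + (1,) * max(0, 4 - len(dims))


shape = (1, 5)  # SHAPE from the reproduction script
print(original_rank_check(shape))  # False: conversion aborts
print(relaxed_rank_check(shape))   # True
print(pad_to_4d(shape))            # (1, 5, 1, 1)
```

Under this assumption an (N, C) input becomes (N, C, 1, 1), which a per-channel scale handles exactly like an (N, C, L) input padded to (N, C, L, 1).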

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
