🐛 [Bug] RuntimeError in affine=False with BatchNorm2d #860

@zsef123

Description

Bug Description

```
Traceback (most recent call last):
  File "tt.py", line 20, in <module>
    convert(
  File "tt.py", line 13, in convert
    trt = torch_tensorrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The size of tensor a (8) must match the size of tensor b (4) at non-singleton dimension 3
```
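For context, the same layer runs fine in eager PyTorch, which suggests the 8-vs-4 shape mismatch arises inside the Torch-TensorRT converter (plausibly when it substitutes per-channel scale/shift tensors for the missing `weight`/`bias`), not in PyTorch itself. A quick eager sanity check:

```python
import torch
import torch.nn as nn

# Eager execution of the exact layer from the repro succeeds,
# so the RuntimeError is specific to torch_tensorrt compilation.
bn = nn.BatchNorm2d(4, momentum=0.5, affine=False).eval()
x = torch.randn(1, 4, 8, 8)
y = bn(x)
print(y.shape)  # torch.Size([1, 4, 8, 8])
```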

To Reproduce

  • V100 32GB
  • Both NGC Pytorch 21.02 and NGC Pytorch 22.01
```python
import torch
import torch.nn as nn
import torch_tensorrt

device = "cuda"

def convert(x, net, dtype=torch.float32):
    net = net.eval().to(device)
    net(x)
    traced = torch.jit.trace(net, [x])

    trt = torch_tensorrt.compile(
        traced,
        inputs=[torch_tensorrt.Input(list(x.shape), dtype=dtype)],
        enabled_precisions={dtype}
    )
    trt(x)

convert(
    torch.randn((1, 4, 8, 8), device=device),
    nn.BatchNorm2d(4, momentum=0.5, affine=False).eval()
)
```

Expected behavior

Compilation should succeed with `affine=False`; the identical model compiles and runs correctly when `affine=True`.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • V100 32GB
  • Driver Version: 470.103.01
  • Both NGC Pytorch 21.02 and NGC Pytorch 22.01
