
Fix batchnorm affine false #866


Merged
merged 6 commits on Apr 5, 2022

Conversation

zsef123
Contributor

@zsef123 zsef123 commented Feb 12, 2022

Description

  • Remove leftover references to strict_types and max_batch_size
  • Fix batchnorm when affine=False

Fixes

  1. 🐛 [Bug] RuntimeError in affine=False with BatchNorm2d #860
  2. Build error from refactor: removing the strict_types and max_batch_size apis #782

Type of change

  • Bug fix (non-breaking change which fixes an issue)
    Previously, the Python side still referenced the removed strict_types and max_batch_size variables, which caused a build error, so those references were removed.

  • New feature (non-breaking change which adds functionality)
    With nn.BatchNorm2d(C, affine=False), gamma and beta are set to None,
    whereas with affine=True they are tensors of shape [C]:

>>> bn = nn.BatchNorm2d(4) 
>>> bn.weight.shape
torch.Size([4])

In the converter, however, gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options})); builds the fallback tensor with the input tensor's shape rather than the channel dimension, and that shape mismatch causes a runtime error.
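
For illustration, here is a minimal PyTorch sketch of the two cases (the variable names are illustrative; nothing here is Torch-TensorRT specific):

import torch
import torch.nn as nn

C = 4

# affine=True: learnable gamma/beta exist and have shape [C], one value per channel.
bn_affine = nn.BatchNorm2d(C, affine=True)
print(bn_affine.weight.shape, bn_affine.bias.shape)   # torch.Size([4]) torch.Size([4])

# affine=False: gamma and beta are simply None.
bn_no_affine = nn.BatchNorm2d(C, affine=False)
print(bn_no_affine.weight, bn_no_affine.bias)         # None None

# A correct fallback for the missing gamma/beta therefore has the channel shape,
# not the shape of the full input tensor.
x = torch.randn(2, C, 8, 8)
gamma_ok = torch.ones(C)          # shape [C], matches running_mean/running_var
# gamma_bad = torch.ones_like(x)  # shape [2, C, 8, 8], the kind of mismatch behind the reported error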

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

@zsef123 zsef123 force-pushed the fix_batchnorm_affine_false branch from 88441a2 to 621a04a on February 12, 2022 17:05
@narendasan
Collaborator

Hi, thanks for filing this! Would it be possible for you to add a test case verifying your fix? You can add it to the converter tests. If you need help, we are happy to assist.

@zsef123 zsef123 force-pushed the fix_batchnorm_affine_false branch from f0e2180 to 26ed9f1 on February 13, 2022 06:41
@zsef123
Contributor Author

zsef123 commented Feb 13, 2022

@narendasan

>>> x = torch.jit.trace(bn, torch.randn((2, 2, 8, 8)))
>>> x
BatchNorm2d(original_name=BatchNorm2d)
>>> x.graph
graph(%self : __torch__.torch.nn.modules.batchnorm.___torch_mangle_2.BatchNorm2d,
      %input : Float(2, 2, 8, 8, strides=[128, 64, 8, 1], requires_grad=0, device=cpu)):
  %running_var : Tensor = prim::GetAttr[name="running_var"](%self)
  %running_mean : Tensor = prim::GetAttr[name="running_mean"](%self)
  %7 : NoneType = prim::Constant()
  %8 : NoneType = prim::Constant()
  %9 : bool = prim::Constant[value=0]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %10 : float = prim::Constant[value=0.10000000000000001]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %11 : float = prim::Constant[value=1.0000000000000001e-05]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %12 : bool = prim::Constant[value=1]() # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %13 : Float(2, 2, 8, 8, strides=[128, 64, 8, 1], requires_grad=0, device=cpu) = aten::batch_norm(%input, %7, %8, %running_mean, %running_var, %9, %10, %11, %12) # /opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  return (%13)

In this case, how can I pass a NoneType into torch_tensorrt::core::ir::get_static_params?

@narendasan
Collaborator

You should be able to just set it as an inline constant in the graph, like this:

  %7 : NoneType = prim::Constant()
  %8 : NoneType = prim::Constant()

instead of passing it to get_static_params

@narendasan
Collaborator

So likely your test graph will look like:

graph(%input : Tensor, %running_var: Tensor, %running_mean: Tensor):
  %7 : NoneType = prim::Constant()
  %8 : NoneType = prim::Constant()
  %9 : bool = prim::Constant[value=0]()
  %10 : float = prim::Constant[value=0.10000000000000001]()
  %11 : float = prim::Constant[value=1.0000000000000001e-05]()
  %12 : bool = prim::Constant[value=1]()
  %13 : Tensor = aten::batch_norm(%input, %7, %8, %running_mean, %running_var, %9, %10, %11, %12)
  return (%13)

Where you provide tensors for input, running_var and running_mean like we do in other tests.
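
As a sanity check on the expected numbers, here is a plain PyTorch sketch (not the project's C++ converter test harness): with gamma and beta as None, inference-mode batch norm is just normalization by the running statistics, so the reference output the test should reproduce can be computed by hand.

import torch
import torch.nn.functional as F

C = 2
x = torch.randn(2, C, 8, 8)
running_mean = torch.randn(C)
running_var = torch.rand(C) + 0.1   # keep the variance positive

eps = 1e-5
# weight=None, bias=None mirrors the NoneType constants in the graph above.
out = F.batch_norm(x, running_mean, running_var, weight=None, bias=None,
                   training=False, momentum=0.1, eps=eps)

expected = (x - running_mean[None, :, None, None]) / torch.sqrt(running_var[None, :, None, None] + eps)
assert torch.allclose(out, expected, atol=1e-6)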

@zsef123
Contributor Author

zsef123 commented Feb 15, 2022

@narendasan Test done!
There are now merge conflicts; it seems better to resolve them on this branch.

Collaborator

@narendasan narendasan left a comment


Cool, changes-wise this looks good. Can you rebase? Then we will test and merge.

// BatchNorm(ch, affine=False)
const auto graph = R"IR(
graph(%0 : Tensor,
%1: NoneType = prim::Constant(),
Collaborator


Just FYI, the Nones in the graph inputs can be consolidated into a single %1 : NoneType = prim::Constant() in the graph body so that you don't need to pass them as arguments, but this should be fine.
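
For reference, a sketch of what the consolidated graph might look like, with a single NoneType constant feeding both the gamma and beta slots of aten::batch_norm. The value names are illustrative; the IR is parsed here with torch._C.parse_ir only to check that it is well formed, while the project's C++ tests would go through torch::jit::parseIR instead.

import torch

# Hypothetical consolidated test graph: one NoneType constant (%none) serves as
# both gamma and beta, so only the input and running stats remain as graph inputs.
CONSOLIDATED_IR = """
graph(%input : Tensor, %running_var : Tensor, %running_mean : Tensor):
  %none : NoneType = prim::Constant()
  %training : bool = prim::Constant[value=0]()
  %momentum : float = prim::Constant[value=0.10000000000000001]()
  %eps : float = prim::Constant[value=1.0000000000000001e-05]()
  %cudnn : bool = prim::Constant[value=1]()
  %out : Tensor = aten::batch_norm(%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn)
  return (%out)
"""

graph = torch._C.parse_ir(CONSOLIDATED_IR)  # raises if the IR does not parse
print(graph)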

@zsef123 zsef123 force-pushed the fix_batchnorm_affine_false branch from c348842 to 56a2043 on February 16, 2022 16:21
@zsef123
Contributor Author

zsef123 commented Feb 16, 2022

@narendasan Rebased 👍
