Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,37 @@ def aten_ops_batch_norm(
)


@dynamo_tensorrt_converter(
torch.ops.aten._native_batch_norm_legit_no_training.default,
capability_validator=one_user_validator,
)
def aten_ops_batch_norm_legit_no_training(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.batch_norm(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
weight=args[1],
bias=args[2],
running_mean=args[3],
running_var=args[4],
training=False,
momentum=args[5],
eps=args[6],
cudnn_enabled=False,
return_mean_rstd=(
target == torch.ops.aten._native_batch_norm_legit_no_training.default
),
)


@dynamo_tensorrt_converter(
torch.ops.aten.native_layer_norm.default, capability_validator=one_user_validator
)
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
aten.native_batch_norm_backward,
aten._native_batch_norm_legit,
aten._native_batch_norm_legit_functional,
aten._native_batch_norm_legit_no_training,
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
Expand Down
19 changes: 19 additions & 0 deletions tests/py/dynamo/conversion/test_batch_norm_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,25 @@ def forward(self, x):
inputs,
)

def test_batchnorm_legit_no_training(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._native_batch_norm_legit_no_training.default(
x,
torch.ones((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
torch.zeros((FEATURE_NUM,)),
torch.ones((FEATURE_NUM,)),
0.1,
1e-05,
)[0]

inputs = [torch.randn(1, 3, 224, 224)]
self.run_test(
BatchNorm(),
inputs,
)

def test_batchnorm1d_with_dynamic_shape(self):
class BatchNorm(torch.nn.Module):
def forward(self, x):
Expand Down