Skip to content

Commit 0c7493a

Browse files
authored
Merge pull request #795 from chaoz-dev/chaoz/batch-norm1d
Update batchnorm to support 2D tensors (i.e., BatchNorm1d) and add a unit test case.
2 parents 7191959 + acddb41 commit 0c7493a

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
7171
LOG_DEBUG("momentum disregarded");
7272
LOG_DEBUG("training disregarded");
7373
LOG_DEBUG("cudnn disregarded");
74-
TORCHTRT_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
74+
TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);
7575

7676
// Expand spatial dims from 1D to 2D if needed
7777
bool expandDims = (orig_shape.nbDims < 4);

tests/core/conversion/converters/test_batch_norm.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,39 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
3636
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
3737
}
3838

39+
// Verifies that aten::batch_norm converts correctly for a 2D (N, C) input,
// i.e. the BatchNorm1d case enabled by relaxing the converter's rank check
// from nbDims > 2 to nbDims >= 2. The JIT interpreter result is compared
// against the TensorRT engine result on the same randomized inputs.
TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor,
          %1: Float(5, strides=[1]),
          %2: Float(5, strides=[1]),
          %3: Float(5, strides=[1]),
          %4: Float(5, strides=[1])):
      %5 : bool = prim::Constant[value=0]()
      %6 : float = prim::Constant[value=1.0000000000000001e-05]()
      %7 : float = prim::Constant[value=0.10000000000000001]()
      %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
      return (%8))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Test 2D tensor, which is valid shape for BatchNorm1D ops.
  auto in = at::randint(1, 10, {1, 5}, {at::kCUDA});
  auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
  auto beta = at::randint(1, 10, {5}, {at::kCUDA});
  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
  auto var = at::randint(1, 10, {5}, {at::kCUDA});

  // Run through the JIT interpreter to get the reference output.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  // Re-bind params (RunGraph may consume them) and run the TRT engine.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
71+
3972
TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
4073
const auto graph = R"IR(
4174
graph(%0 : Tensor,

0 commit comments

Comments
 (0)