Skip to content

Commit 4fc7f62

Browse files
committed
Fix batchnorm to support 1d conversion; add unit test case.
1 parent 55c3bab commit 4fc7f62

File tree

2 files changed: +33 −1 lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
7171
LOG_DEBUG("momentum disregarded");
7272
LOG_DEBUG("training disregarded");
7373
LOG_DEBUG("cudnn disregarded");
74-
TORCHTRT_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
74+
TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);
7575

7676
// Expand spatial dims from 1D to 2D if needed
7777
bool expandDims = (orig_shape.nbDims < 4);

tests/core/conversion/converters/test_batch_norm.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,38 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
3636
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
3737
}
3838

39+
TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
40+
const auto graph = R"IR(
41+
graph(%0 : Tensor,
42+
%1: Float(5, strides=[1]),
43+
%2: Float(5, strides=[1]),
44+
%3: Float(5, strides=[1]),
45+
%4: Float(5, strides=[1])):
46+
%5 : bool = prim::Constant[value=0]()
47+
%6 : float = prim::Constant[value=1.0000000000000001e-05]()
48+
%7 : float = prim::Constant[value=0.10000000000000001]()
49+
%8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
50+
return (%8))IR";
51+
52+
auto g = std::make_shared<torch::jit::Graph>();
53+
torch::jit::parseIR(graph, g.get());
54+
55+
auto in = at::randint(1, 10, {1, 5}, {at::kCUDA});
56+
auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
57+
auto beta = at::randint(1, 10, {5}, {at::kCUDA});
58+
auto mean = at::randint(1, 10, {5}, {at::kCUDA});
59+
auto var = at::randint(1, 10, {5}, {at::kCUDA});
60+
61+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
62+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
63+
64+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
65+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
66+
67+
ASSERT_TRUE(
68+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
69+
}
70+
3971
TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
4072
const auto graph = R"IR(
4173
graph(%0 : Tensor,

0 commit comments

Comments (0)