diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp
index 62c6f929e5..a17e8548e2 100644
--- a/core/conversion/converters/impl/batch_norm.cpp
+++ b/core/conversion/converters/impl/batch_norm.cpp
@@ -71,7 +71,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
                LOG_DEBUG("momentum disregarded");
                LOG_DEBUG("training disregarded");
                LOG_DEBUG("cudnn disregarded");
-               TORCHTRT_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
+               TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);

                // Expand spatial dims from 1D to 2D if needed
                bool expandDims = (orig_shape.nbDims < 4);
diff --git a/tests/core/conversion/converters/test_batch_norm.cpp b/tests/core/conversion/converters/test_batch_norm.cpp
index b0e23e9e54..aa7782552f 100644
--- a/tests/core/conversion/converters/test_batch_norm.cpp
+++ b/tests/core/conversion/converters/test_batch_norm.cpp
@@ -36,6 +36,39 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1: Float(5, strides=[1]),
+            %2: Float(5, strides=[1]),
+            %3: Float(5, strides=[1]),
+            %4: Float(5, strides=[1])):
+        %5 : bool = prim::Constant[value=0]()
+        %6 : float = prim::Constant[value=1.0000000000000001e-05]()
+        %7 : float = prim::Constant[value=0.10000000000000001]()
+        %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
+        return (%8))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Test a 2D tensor, which is a valid shape for BatchNorm1D ops.
+  auto in = at::randint(1, 10, {1, 5}, {at::kCUDA});
+  auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
+  auto beta = at::randint(1, 10, {5}, {at::kCUDA});
+  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
+  auto var = at::randint(1, 10, {5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,
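
Background on the relaxed bound (nbDims > 2 becoming nbDims >= 2): PyTorch's BatchNorm1d accepts both 2-D (N, C) and 3-D (N, C, L) inputs, so a graph containing aten::batch_norm can legitimately hand the converter a 2-dimensional tensor, which the old check rejected. The LibTorch sketch below is illustrative only, not part of the patch; the 5-feature size mirrors the new test and is otherwise arbitrary.

#include <torch/torch.h>
#include <iostream>

// Sketch: BatchNorm1d handles both (N, C) and (N, C, L) inputs in eager mode,
// which is the case the relaxed nbDims >= 2 check now admits during conversion.
int main() {
  torch::nn::BatchNorm1d bn(5);
  bn->eval();  // use running stats, matching the inference path the converter targets

  auto in_2d = torch::randn({1, 5});     // (N, C): previously rejected by the > 2 check
  auto in_3d = torch::randn({1, 5, 8});  // (N, C, L): already supported

  std::cout << bn->forward(in_2d).sizes() << "\n"
            << bn->forward(in_3d).sizes() << "\n";
  return 0;
}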