@@ -36,6 +36,39 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1: Float(5, strides=[1]),
+            %2: Float(5, strides=[1]),
+            %3: Float(5, strides=[1]),
+            %4: Float(5, strides=[1])):
+        %5 : bool = prim::Constant[value=0]()
+        %6 : float = prim::Constant[value=1.0000000000000001e-05]()
+        %7 : float = prim::Constant[value=0.10000000000000001]()
+        %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
+        return (%8))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Test a 2D input tensor, which is a valid shape for BatchNorm1d ops.
+  auto in = at::randint(1, 10, {1, 5}, {at::kCUDA});
+  auto gamma = at::randint(1, 10, {5}, {at::kCUDA});
+  auto beta = at::randint(1, 10, {5}, {at::kCUDA});
+  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
+  auto var = at::randint(1, 10, {5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,
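For reference, the 2D case this test exercises mirrors eager-mode `BatchNorm1d`, which accepts both `(N, C)` and `(N, C, L)` inputs; both lower to the same `aten::batch_norm` call seen in the parsed IR above. Below is a minimal libtorch sketch (illustrative only, not part of this diff); `eval()` is used because training-mode batch norm rejects a batch with only one value per channel.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // BatchNorm1d over 5 channels, matching the 5-element gamma/beta/mean/var
  // parameters in the converter test above.
  torch::nn::BatchNorm1d bn(5);
  bn->eval(); // use running stats; training mode needs more than one value per channel

  auto out2d = bn->forward(torch::randn({1, 5}));    // 2D (N, C) input, as in the test
  auto out3d = bn->forward(torch::randn({1, 5, 8})); // 3D (N, C, L) input is also accepted

  std::cout << out2d.sizes() << " " << out3d.sizes() << "\n";
}
```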