@@ -36,6 +36,38 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
36
36
torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
37
37
}
38
38
39
+ TEST (Converters, ATenBatchNorm1DConvertsCorrectly) {
40
+ const auto graph = R"IR(
41
+ graph(%0 : Tensor,
42
+ %1: Float(5, strides=[1]),
43
+ %2: Float(5, strides=[1]),
44
+ %3: Float(5, strides=[1]),
45
+ %4: Float(5, strides=[1])):
46
+ %5 : bool = prim::Constant[value=0]()
47
+ %6 : float = prim::Constant[value=1.0000000000000001e-05]()
48
+ %7 : float = prim::Constant[value=0.10000000000000001]()
49
+ %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
50
+ return (%8))IR" ;
51
+
52
+ auto g = std::make_shared<torch::jit::Graph>();
53
+ torch::jit::parseIR (graph, g.get ());
54
+
55
+ auto in = at::randint (1 , 10 , {1 , 5 }, {at::kCUDA });
56
+ auto gamma = at::randint (1 , 10 , {5 }, {at::kCUDA });
57
+ auto beta = at::randint (1 , 10 , {5 }, {at::kCUDA });
58
+ auto mean = at::randint (1 , 10 , {5 }, {at::kCUDA });
59
+ auto var = at::randint (1 , 10 , {5 }, {at::kCUDA });
60
+
61
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {gamma, beta, mean, var});
62
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
63
+
64
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {gamma, beta, mean, var});
65
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
66
+
67
+ ASSERT_TRUE (
68
+ torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
69
+ }
70
+
39
71
TEST (Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
40
72
const auto graph = R"IR(
41
73
graph(%0 : Tensor,
0 commit comments