@@ -36,6 +36,39 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
36
36
torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
37
37
}
38
38
39
TEST(Converters, ATenBatchNormAffineFalseConvertsCorrectly) {
  // BatchNorm(ch, affine=False): gamma (%1) and beta (%2) are NoneType, so the
  // converter must handle missing affine weights while still applying the
  // running mean (%3) and running var (%4).
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1: NoneType = prim::Constant(),
            %2: NoneType = prim::Constant(),
            %3: Float(5, strides=[1]),
            %4: Float(5, strides=[1])):
        %5 : bool = prim::Constant[value=0]()
        %6 : float = prim::Constant[value=1.0000000000000001e-05]()
        %7 : float = prim::Constant[value=0.10000000000000001]()
        %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
        return (%8))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Input: N=1, C=5 (matches the 5-element mean/var), H=W=5, on CUDA.
  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});

  // Default-constructed IValues are None — stand-ins for gamma/beta.
  torch::jit::IValue gamma, beta; // NoneType
  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
  auto var = at::randint(1, 10, {5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

  // NOTE(review): params is rebuilt before the TRT run, presumably because
  // RunGraph consumes/moves it — confirm against the test util implementation.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

  // TRT output may be flattened; reshape to the JIT result before comparing.
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

39
72
TEST (Converters, ATenBatchNorm1DConvertsCorrectly) {
40
73
const auto graph = R"IR(
41
74
graph(%0 : Tensor,
0 commit comments