Skip to content

Commit c348842

Browse files
committed
Add test
Signed-off-by: root <[email protected]>
1 parent 26ed9f1 commit c348842

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/core/conversion/converters/test_batch_norm.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,39 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
3636
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
3737
}
3838

39+
TEST(Converters, ATenBatchNormAffineFalseConvertsCorrectly) {
40+
// BatchNorm(ch, affine=False)
41+
const auto graph = R"IR(
42+
graph(%0 : Tensor,
43+
%1: NoneType = prim::Constant(),
44+
%2: NoneType = prim::Constant(),
45+
%3: Float(5, strides=[1]),
46+
%4: Float(5, strides=[1])):
47+
%5 : bool = prim::Constant[value=0]()
48+
%6 : float = prim::Constant[value=1.0000000000000001e-05]()
49+
%7 : float = prim::Constant[value=0.10000000000000001]()
50+
%8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5)
51+
return (%8))IR";
52+
53+
auto g = std::make_shared<torch::jit::Graph>();
54+
torch::jit::parseIR(graph, g.get());
55+
56+
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
57+
58+
torch::jit::IValue gamma, beta; // NoneType
59+
auto mean = at::randint(1, 10, {5}, {at::kCUDA});
60+
auto var = at::randint(1, 10, {5}, {at::kCUDA});
61+
62+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
63+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
64+
65+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {gamma, beta, mean, var});
66+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
67+
68+
ASSERT_TRUE(
69+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
70+
}
71+
3972
TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
4073
const auto graph = R"IR(
4174
graph(%0 : Tensor,

0 commit comments

Comments
 (0)