10
10
// const c10::optional<Tensor>& bias_opt /* optional */,
11
11
// const c10::optional<Tensor>& running_mean_opt /* optional */,
12
12
// const c10::optional<Tensor>& running_var_opt /* optional */,
13
- // bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
13
+ // bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
14
14
constexpr auto graph = R"IR(
15
15
graph(%input.1 : Tensor,
16
16
%weight.1 : Tensor?,
@@ -28,23 +28,23 @@ constexpr auto graph = R"IR(
28
28
return (%4)
29
29
)IR" ;
30
30
31
-
32
31
TEST (Converters, ATenInstanceNormConvertsCorrectly) {
33
32
auto g = std::make_shared<torch::jit::Graph>();
34
33
torch::jit::parseIR (graph, g.get ());
35
34
36
35
auto in = at::randint (1 , 10 , {1 , 5 , 5 , 5 }, {at::kCUDA });
37
36
torch::jit::IValue weight, bias, mean, var; // NoneType
38
37
// https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
39
- const bool use_input_stats = true ;
40
-
38
+ const bool use_input_stats = true ;
39
+
41
40
auto trt_in = at::clone (in);
42
41
torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;
43
42
44
43
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {weight, bias, mean, var, use_input_stats});
45
44
auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
46
45
47
- params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
46
+ params = trtorch::core::conversion::get_named_params (
47
+ g->inputs (), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
48
48
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
49
49
50
50
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
@@ -58,8 +58,8 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
58
58
59
59
auto weight = at::randn ({in.size (1 )}).to (at::kCUDA );
60
60
auto bias = at::randn ({in.size (1 )}).to (at::kCUDA );
61
-
62
- torch::jit::IValue mean, var; // NoneType
61
+
62
+ torch::jit::IValue mean, var; // NoneType
63
63
const bool use_input_stats = true ;
64
64
65
65
auto trt_in = at::clone (in);
@@ -70,7 +70,8 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
70
70
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {weight, bias, mean, var, use_input_stats});
71
71
auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
72
72
73
- params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
73
+ params = trtorch::core::conversion::get_named_params (
74
+ g->inputs (), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
74
75
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
75
76
76
77
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
@@ -81,12 +82,12 @@ TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
81
82
torch::jit::parseIR (graph, g.get ());
82
83
83
84
auto in = at::randn ({1 , 5 , 5 , 5 }, {at::kCUDA });
84
-
85
+
85
86
torch::jit::IValue weight, bias;
86
87
auto mean = at::zeros ({in.size (1 )}, {at::kCUDA });
87
88
auto var = at::ones ({in.size (1 )}, {at::kCUDA });
88
89
const bool use_input_stats = false ;
89
-
90
+
90
91
auto trt_in = at::clone (in);
91
92
torch::jit::IValue trt_weight, trt_bias;
92
93
auto trt_mean = at::clone (mean);
@@ -95,7 +96,8 @@ TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
95
96
auto params = trtorch::core::conversion::get_named_params (g->inputs (), {weight, bias, mean, var, use_input_stats});
96
97
auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
97
98
98
- params = trtorch::core::conversion::get_named_params (g->inputs (), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
99
+ params = trtorch::core::conversion::get_named_params (
100
+ g->inputs (), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
99
101
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
100
102
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
101
103
}
0 commit comments