Skip to content

Commit 027217b

Browse files
committed
refactor: Apply linting
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent d861f4a commit 027217b

File tree

2 files changed: +29 additions, -16 deletions

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void _batch_norm(
3636
// Un-pad bn output if needed
3737
auto out_tensor = addUnpadding(ctx, n, bn->getOutput(0), orig_shape.nbDims);
3838
ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
39+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
3940
}
4041

4142
auto batch_norm_registrations TRTORCH_UNUSED =
@@ -103,23 +104,32 @@ auto batch_norm_registrations TRTORCH_UNUSED =
103104
LOG_DEBUG("Args[4] running_var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
104105
LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
105106
LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
106-
107+
107108
// Expand spatial dims from 1D to 2D if needed
108109
bool expandDims = (orig_shape.nbDims < 4);
109110
if (expandDims) {
110111
input = addPadding(ctx, n, input, 4);
111112
}
112113

113114
auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
114-
115+
115116
auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
116117
auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
117-
118+
118119
// track_running_stats=True
119120
if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
120121
auto running_mean = args[3].unwrapToTensor();
121122
auto running_var = args[4].unwrapToTensor();
122-
_batch_norm(ctx, n, input, orig_shape, scales.to(running_mean.options()), bias.to(running_mean.options()), running_mean, running_var, eps);
123+
_batch_norm(
124+
ctx,
125+
n,
126+
input,
127+
orig_shape,
128+
scales.to(running_mean.options()),
129+
bias.to(running_mean.options()),
130+
running_mean,
131+
running_var,
132+
eps);
123133
return true;
124134
}
125135

@@ -132,7 +142,7 @@ auto batch_norm_registrations TRTORCH_UNUSED =
132142
Type Parameter Description
133143
float epsilon A small number to prevent being divided by zero during normalization.
134144
Weights * scale A pointer to weights which contains information about scale factors for
135-
normalization. The definition of Weights can be found in the NvInfer.h header.
145+
normalization. The definition of Weights can be found in the NvInfer.h header.
136146
Weights * bias A pointer to weights which contains information about the bias values for
137147
normalization. The definition of Weights can be found in the NvInfer.h header.
138148
int relu A value used to enable leaky relu activation
@@ -162,6 +172,7 @@ auto batch_norm_registrations TRTORCH_UNUSED =
162172

163173
new_layer->setName(util::node_info(n).c_str());
164174
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
175+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
165176
return true;
166177
}});
167178
} // namespace

tests/core/conversion/converters/test_instance_norm.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// const c10::optional<Tensor>& bias_opt /* optional */,
1111
// const c10::optional<Tensor>& running_mean_opt /* optional */,
1212
// const c10::optional<Tensor>& running_var_opt /* optional */,
13-
// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
13+
// bool use_input_stats, double momentum, double eps, bool cudnn_enabled)
1414
constexpr auto graph = R"IR(
1515
graph(%input.1 : Tensor,
1616
%weight.1 : Tensor?,
@@ -28,23 +28,23 @@ constexpr auto graph = R"IR(
2828
return (%4)
2929
)IR";
3030

31-
3231
TEST(Converters, ATenInstanceNormConvertsCorrectly) {
3332
auto g = std::make_shared<torch::jit::Graph>();
3433
torch::jit::parseIR(graph, g.get());
3534

3635
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
3736
torch::jit::IValue weight, bias, mean, var; // NoneType
3837
// https://github.com/pytorch/pytorch/blob/79693bb86a3f601a5c0d3da52d99acec95bb48c1/torch/nn/modules/instancenorm.py#L59
39-
const bool use_input_stats = true;
40-
38+
const bool use_input_stats = true;
39+
4140
auto trt_in = at::clone(in);
4241
torch::jit::IValue trt_weight, trt_bias, trt_mean, trt_var;
4342

4443
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
4544
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
4645

47-
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
46+
params = trtorch::core::conversion::get_named_params(
47+
g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
4848
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
4949

5050
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
@@ -58,8 +58,8 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
5858

5959
auto weight = at::randn({in.size(1)}).to(at::kCUDA);
6060
auto bias = at::randn({in.size(1)}).to(at::kCUDA);
61-
62-
torch::jit::IValue mean, var; // NoneType
61+
62+
torch::jit::IValue mean, var; // NoneType
6363
const bool use_input_stats = true;
6464

6565
auto trt_in = at::clone(in);
@@ -70,7 +70,8 @@ TEST(Converters, ATenInstanceNormAffineConvertsCorrectly) {
7070
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
7171
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
7272

73-
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
73+
params = trtorch::core::conversion::get_named_params(
74+
g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
7475
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
7576

7677
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
@@ -81,12 +82,12 @@ TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
8182
torch::jit::parseIR(graph, g.get());
8283

8384
auto in = at::randn({1, 5, 5, 5}, {at::kCUDA});
84-
85+
8586
torch::jit::IValue weight, bias;
8687
auto mean = at::zeros({in.size(1)}, {at::kCUDA});
8788
auto var = at::ones({in.size(1)}, {at::kCUDA});
8889
const bool use_input_stats = false;
89-
90+
9091
auto trt_in = at::clone(in);
9192
torch::jit::IValue trt_weight, trt_bias;
9293
auto trt_mean = at::clone(mean);
@@ -95,7 +96,8 @@ TEST(Converters, ATenInstanceNormRunningStatsConvertsCorrectly) {
9596
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {weight, bias, mean, var, use_input_stats});
9697
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
9798

98-
params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
99+
params = trtorch::core::conversion::get_named_params(
100+
g->inputs(), {trt_weight, trt_bias, trt_mean, trt_var, use_input_stats});
99101
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
100102
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
101103
}

0 commit comments

Comments (0)