
Commit d93659d

Merge pull request #866 from zsef123/fix_batchnorm_affine_false

Fix batchnorm affine false

2 parents: 9c5031a + 12942ac

File tree: 3 files changed, 49 additions and 16 deletions

core/conversion/converters/impl/batch_norm.cpp (18 additions, 10 deletions)

@@ -50,27 +50,35 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
        auto orig_shape = input->getDimensions();
        auto shape = util::toVec(orig_shape);
        auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-       auto options = torch::TensorOptions().dtype(tensor_type);
+       auto options =
+           torch::TensorOptions().dtype(tensor_type).device(torch::kCUDA, ctx->settings.device.gpu_id);

        torch::Tensor gamma, beta, mean, var;
+       LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
+       // affine=True
+       LOG_DEBUG("Args[1] gamma : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+       LOG_DEBUG("Args[2] beta : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+       // track_running_stats=True
+       LOG_DEBUG("Args[3] mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+       LOG_DEBUG("Args[4] var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+       LOG_DEBUG("use_input_stats, momentum, cudnn_enabled disregarded");
+       LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);

+       auto channel_dim = shape[1];
        if (ctx->input_is_dynamic) {
-         gamma = args[1].unwrapToTensor();
-         beta = args[2].unwrapToTensor();
+         gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+         beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
          mean = args[3].unwrapToTensor();
          var = args[4].unwrapToTensor();
        } else {
-         gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-         beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-         mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-         var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+         gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+         beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
+         mean = args[3].unwrapToTensor(at::full(channel_dim, 0, options));
+         var = args[4].unwrapToTensor(at::full(channel_dim, 0, options));
        }

        auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));

-       LOG_DEBUG("momentum disregarded");
-       LOG_DEBUG("training disregarded");
-       LOG_DEBUG("cudnn disregarded");
        TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);

        // Expand spatial dims from 1D to 2D if needed
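The substance of this fix is easier to see outside the diff. Previously, the static-shape path built default gamma/beta tensors of the full input shape (and filled beta with 1 instead of 0), while the dynamic-shape path called `unwrapToTensor()` with no fallback at all, so a `None` weight or bias (i.e. `affine=False`) failed. The new code builds per-channel defaults (gamma = 1, beta = 0) on the engine's GPU. A rough Python mirror of that default logic, for illustration only (the converter itself does this in C++ via `at::full`):

```python
import torch

# Illustrative mirror of the converter's fallback defaults, not the real API.
shape = [1, 5, 5, 5]     # full NCHW input shape
channel_dim = shape[1]   # 5 channels

# Old fallback: tensors of the whole input shape, and beta wrongly filled with 1.
old_gamma = torch.full(shape, 1.0)
old_beta = torch.full(shape, 1.0)

# New fallback: per-channel (5,) tensors on the target GPU, beta = 0.
opts = dict(dtype=torch.float32, device="cuda:0")  # assumes a CUDA device
new_gamma = torch.full((channel_dim,), 1.0, **opts)
new_beta = torch.full((channel_dim,), 0.0, **opts)
```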

py/torch_tensorrt/ts/_compiler.py (0 additions, 6 deletions)

@@ -17,7 +17,6 @@ def compile(module: torch.jit.ScriptModule,
             enabled_precisions=set(),
             refit=False,
             debug=False,
-            strict_types=False,
             capability=_enums.EngineCapability.default,
             num_min_timing_iters=2,
             num_avg_timing_iters=1,
@@ -65,7 +64,6 @@ def compile(module: torch.jit.ScriptModule,
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         refit (bool): Enable refitting
         debug (bool): Enable debuggable engine
-        strict_types (bool): Kernels should strictly run in a particular operating precision. Enabled precision should only have one type in the set
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -98,7 +96,6 @@ def compile(module: torch.jit.ScriptModule,
         "enabled_precisions": enabled_precisions,  # Enabling FP16 kernels
         "refit": refit,  # enable refit
         "debug": debug,  # enable debuggable engine
-        "strict_types": strict_types,  # kernels should strictly run in operating precision
         "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
         "num_min_timing_iters": num_min_timing_iters,  # Number of minimization timing iterations used to select kernels
         "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
@@ -127,7 +124,6 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
                                  enabled_precisions=set(),
                                  refit=False,
                                  debug=False,
-                                 strict_types=False,
                                  capability=_enums.EngineCapability.default,
                                  num_min_timing_iters=2,
                                  num_avg_timing_iters=1,
@@ -169,7 +165,6 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         refit (bool): Enable refitting
         debug (bool): Enable debuggable engine
-        strict_types (bool): Kernels should strictly run in a particular operating precision. Enabled precision should only have one type in the set
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -193,7 +188,6 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
         "enabled_precisions": enabled_precisions,  # Enabling FP16 kernels
         "refit": refit,  # enable refit
         "debug": debug,  # enable debuggable engine
-        "strict_types": strict_types,  # kernels should strictly run in operating precision
         "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
         "num_min_timing_iters": num_min_timing_iters,  # Number of minimization timing iterations used to select kernels
         "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels

tests/core/conversion/converters/test_batch_norm.cpp (31 additions, 0 deletions)

@@ -36,6 +36,37 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

+TEST(Converters, ATenBatchNormAffineFalseConvertsCorrectly) {
+  // BatchNorm(ch, affine=False)
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %3: Float(5, strides=[1]),
+            %4: Float(5, strides=[1])):
+        %1 : None = prim::Constant()
+        %5 : bool = prim::Constant[value=0]()
+        %6 : float = prim::Constant[value=1.0000000000000001e-05]()
+        %7 : float = prim::Constant[value=0.10000000000000001]()
+        %8 : Tensor = aten::batch_norm(%0, %1, %1, %3, %4, %5, %6, %7, %5)
+        return (%8))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
+  auto var = at::randint(1, 10, {5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,