Fix batchnorm affine false #866

Merged (6 commits) on Apr 5, 2022
28 changes: 18 additions & 10 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -50,27 +50,35 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
            auto orig_shape = input->getDimensions();
            auto shape = util::toVec(orig_shape);
            auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-           auto options = torch::TensorOptions().dtype(tensor_type);
+           auto options =
+               torch::TensorOptions().dtype(tensor_type).device(torch::kCUDA, ctx->settings.device.gpu_id);

            torch::Tensor gamma, beta, mean, var;
+           LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
+           // affine=True
+           LOG_DEBUG("Args[1] gamma : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+           LOG_DEBUG("Args[2] beta : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+           // track_running_stats=True
+           LOG_DEBUG("Args[3] mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+           LOG_DEBUG("Args[4] var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+           LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
+           LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);

+           auto channel_dim = shape[1];
            if (ctx->input_is_dynamic) {
-             gamma = args[1].unwrapToTensor();
-             beta = args[2].unwrapToTensor();
+             gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+             beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
              mean = args[3].unwrapToTensor();
              var = args[4].unwrapToTensor();
            } else {
-             gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-             beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-             mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-             var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+             gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+             beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
+             mean = args[3].unwrapToTensor(at::full(channel_dim, 0, options));
+             var = args[4].unwrapToTensor(at::full(channel_dim, 0, options));
            }

            auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));

-           LOG_DEBUG("momentum disregarded");
-           LOG_DEBUG("training disregarded");
-           LOG_DEBUG("cudnn disregarded");
            TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);

            // Expand spatial dims from 1D to 2D if needed
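Note on the change above: with `affine=False`, `aten::batch_norm` receives `None` for gamma and beta, so the old `unwrapToTensor()` call (no fallback) failed in the dynamic-shape path; the old static path also sized its defaults to the full input `shape` and filled beta with 1 instead of 0. The fix supplies per-channel defaults, ones for gamma and zeros for beta of length `shape[1]`, built on the engine's GPU. A minimal Python repro of the user-facing case this converter now handles (the module, sizes, and compile call are illustrative, not taken from the PR):

```python
# Hypothetical repro: a BatchNorm layer with affine=False, whose scripted
# graph passes None for gamma/beta. After this fix the converter fills in
# per-channel defaults (gamma=1, beta=0) instead of failing.
import torch
import torch_tensorrt

class Norm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(5, affine=False)  # no learnable gamma/beta

    def forward(self, x):
        return self.bn(x)

model = Norm().eval().cuda()
trt_mod = torch_tensorrt.ts.compile(
    torch.jit.script(model),
    inputs=[torch_tensorrt.Input((1, 5, 5, 5))],
)
print(trt_mod(torch.randn(1, 5, 5, 5, device="cuda")))
```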
6 changes: 0 additions & 6 deletions py/torch_tensorrt/ts/_compiler.py
@@ -17,7 +17,6 @@ def compile(module: torch.jit.ScriptModule,
             enabled_precisions=set(),
             refit=False,
             debug=False,
-            strict_types=False,
             capability=_enums.EngineCapability.default,
             num_min_timing_iters=2,
             num_avg_timing_iters=1,
@@ -65,7 +64,6 @@ def compile(module: torch.jit.ScriptModule,
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         refit (bool): Enable refitting
         debug (bool): Enable debuggable engine
-        strict_types (bool): Kernels should strictly run in a particular operating precision. Enabled precision should only have one type in the set
         capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
         num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
         num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -98,7 +96,6 @@ def compile(module: torch.jit.ScriptModule,
         "enabled_precisions": enabled_precisions,  # Enabling FP16 kernels
         "refit": refit,  # enable refit
         "debug": debug,  # enable debuggable engine
-        "strict_types": strict_types,  # kernels should strictly run in operating precision
         "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
         "num_min_timing_iters": num_min_timing_iters,  # Number of minimization timing iterations used to select kernels
         "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
@@ -127,7 +124,6 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
                                  enabled_precisions=set(),
                                  refit=False,
                                  debug=False,
-                                 strict_types=False,
                                  capability=_enums.EngineCapability.default,
                                  num_min_timing_iters=2,
                                  num_avg_timing_iters=1,
@@ -169,7 +165,6 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
        enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
        refit (bool): Enable refitting
        debug (bool): Enable debuggable engine
-       strict_types (bool): Kernels should strictly run in a particular operating precision. Enabled precision should only have one type in the set
        capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
        num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -193,7 +188,6 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
        "enabled_precisions": enabled_precisions,  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
-       "strict_types": strict_types,  # kernels should strictly run in operating precision
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_min_timing_iters": num_min_timing_iters,  # Number of minimization timing iterations used to select kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
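For callers this is a breaking keyword change: `strict_types` is dropped from both `compile` and `convert_method_to_trt_engine`, so passing it now raises a `TypeError`. A hedged migration sketch (the single-dtype workaround is my inference from the `enabled_precisions` docstring above, not a documented one-to-one replacement):

```python
import torch
import torch_tensorrt

scripted = torch.jit.script(torch.nn.Conv2d(3, 8, 3).eval().cuda())

# Before this PR: torch_tensorrt.ts.compile(scripted, ..., strict_types=True)
# After: drop the keyword. To keep kernels in one operating precision,
# restrict enabled_precisions to a single dtype (an assumption, not a
# documented equivalent of strict_types).
trt_mod = torch_tensorrt.ts.compile(
    scripted,
    inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
    enabled_precisions={torch.float},  # single precision for kernel selection
)
```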
31 changes: 31 additions & 0 deletions tests/core/conversion/converters/test_batch_norm.cpp
@@ -36,6 +36,37 @@ TEST(Converters, ATenBatchNormConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

+TEST(Converters, ATenBatchNormAffineFalseConvertsCorrectly) {
+  // BatchNorm(ch, affine=False)
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %3: Float(5, strides=[1]),
+            %4: Float(5, strides=[1])):
+        %1 : None = prim::Constant()
+        %5 : bool = prim::Constant[value=0]()
+        %6 : float = prim::Constant[value=1.0000000000000001e-05]()
+        %7 : float = prim::Constant[value=0.10000000000000001]()
+        %8 : Tensor = aten::batch_norm(%0, %1, %1, %3, %4, %5, %6, %7, %5)
+        return (%8))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
+
+  auto mean = at::randint(1, 10, {5}, {at::kCUDA});
+  auto var = at::randint(1, 10, {5}, {at::kCUDA});
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Converters, ATenBatchNorm1DConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,
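The new `ATenBatchNormAffineFalseConvertsCorrectly` test passes `None` (`%1`) for gamma and beta and feeds the running mean/var as graph inputs, exactly the eval-mode shape of `BatchNorm2d(5, affine=False)`. The identity it leans on: with no scale or shift, batch norm reduces to `(x - mean) / sqrt(var + eps)`, which matches the converter's substituted defaults of gamma = 1 and beta = 0. A quick PyTorch check of that identity (tensor values are arbitrary):

```python
# With weight=None and bias=None, eval-mode batch_norm is
#   y = (x - running_mean) / sqrt(running_var + eps)
# i.e. gamma defaults to 1 and beta to 0, the same defaults the converter
# now substitutes when affine=False.
import torch

x = torch.randint(1, 10, (1, 5, 5, 5)).float()
mean = torch.randint(1, 10, (5,)).float()
var = torch.randint(1, 10, (5,)).float()
eps = 1e-5

out = torch.nn.functional.batch_norm(
    x, mean, var, weight=None, bias=None, training=False, eps=eps)
ref = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + eps)
assert torch.allclose(out, ref, atol=2e-6)
```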