
Commit 621a04a

Change default tensor

Signed-off-by: root <[email protected]>
1 parent 7dda059

File tree

1 file changed (+18 −11 lines)


core/conversion/converters/impl/batch_norm.cpp (18 additions, 11 deletions)
@@ -50,27 +50,34 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
     auto orig_shape = input->getDimensions();
     auto shape = util::toVec(orig_shape);
     auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-    auto options = torch::TensorOptions().dtype(tensor_type);
-
+    auto options = torch::TensorOptions().dtype(tensor_type).device(torch::kCUDA, ctx->settings.device.gpu_id);
+
     torch::Tensor gamma, beta, mean, var;
+    LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
+    // affine=True
+    LOG_DEBUG("Args[1] gamma : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+    LOG_DEBUG("Args[2] beta : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+    // track_running_stats=True
+    LOG_DEBUG("Args[3] mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+    LOG_DEBUG("Args[4] var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+    LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
+    LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);

+    auto channel_dim = shape[1];
     if (ctx->input_is_dynamic) {
-      gamma = args[1].unwrapToTensor();
-      beta = args[2].unwrapToTensor();
+      gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+      beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
       mean = args[3].unwrapToTensor();
       var = args[4].unwrapToTensor();
     } else {
-      gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-      beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-      mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-      var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+      gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+      beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
+      mean = args[3].unwrapToTensor(at::full(channel_dim, 0, options));
+      var = args[4].unwrapToTensor(at::full(channel_dim, 0, options));
     }

     auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));

-    LOG_DEBUG("momentum disregarded");
-    LOG_DEBUG("training disregarded");
-    LOG_DEBUG("cudnn disregarded");
     TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);

     // Expand spatial dims from 1D to 2D if needed
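Note: the substance of this commit is that the fallback gamma/beta tensors are now built per channel (shape[1]) with the input's dtype and placed on the target CUDA device, instead of being CPU tensors shaped like the full input. The sketch below is a minimal standalone illustration of that default-tensor construction, assuming libtorch is installed; the channel count and GPU id are made-up stand-ins for shape[1] and ctx->settings.device.gpu_id, not the converter's actual context.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Assumed values standing in for shape[1] and ctx->settings.device.gpu_id.
  int64_t channel_dim = 16;
  int gpu_id = 0;

  // Old default: dtype only, so at::full produced a CPU tensor.
  auto cpu_options = torch::TensorOptions().dtype(torch::kFloat32);
  // New default: dtype plus an explicit CUDA device (falling back to CPU here
  // so the sketch still runs on machines without a GPU).
  auto options = torch::cuda::is_available()
      ? torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, gpu_id)
      : cpu_options;

  // Fallbacks used when gamma/beta are None (affine=False): ones for gamma,
  // zeros for beta, sized by channel count rather than by the full input shape.
  auto gamma = at::full({channel_dim}, 1, options);
  auto beta = at::full({channel_dim}, 0, options);

  std::cout << "gamma: " << gamma.sizes() << " on " << gamma.device() << "\n";
  std::cout << "beta:  " << beta.sizes() << " on " << beta.device() << "\n";
  return 0;
}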
