@@ -50,27 +50,34 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
 auto orig_shape = input->getDimensions();
 auto shape = util::toVec(orig_shape);
 auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
-auto options = torch::TensorOptions().dtype(tensor_type);
-
+auto options = torch::TensorOptions().dtype(tensor_type).device(torch::kCUDA, ctx->settings.device.gpu_id);
+
 torch::Tensor gamma, beta, mean, var;
+LOG_DEBUG("Input: " << orig_shape << " / " << input->getType());
+// affine=True
+LOG_DEBUG("Args[1] gamma : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
+LOG_DEBUG("Args[2] beta : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
+// track_running_stats=True
+LOG_DEBUG("Args[3] mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
+LOG_DEBUG("Args[4] var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
+LOG_DEBUG("use_input_stats, momentum, cudnn_enabled disregarded");
+LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
 
+auto channel_dim = shape[1];
 if (ctx->input_is_dynamic) {
-  gamma = args[1].unwrapToTensor();
-  beta = args[2].unwrapToTensor();
+  gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+  beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
   mean = args[3].unwrapToTensor();
   var = args[4].unwrapToTensor();
 } else {
-  gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
-  beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
-  mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
-  var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
+  gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
+  beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
+  mean = args[3].unwrapToTensor(at::full(channel_dim, 0, options));
+  var = args[4].unwrapToTensor(at::full(channel_dim, 0, options));
 }
 
 auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
 
-LOG_DEBUG("momentum disregarded");
-LOG_DEBUG("training disregarded");
-LOG_DEBUG("cudnn disregarded");
 TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);
 
 // Expand spatial dims from 1D to 2D if needed
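A note on the defaults introduced above: gamma, beta, mean, and var are per-channel 1-D tensors of length C, so the old at::full({shape}, ...) fallback built tensors of the full input shape rather than length C, and under dynamic input shapes it cannot be built at all, since unknown dimensions are reported as -1. Defaulting from shape[1] (the channel dimension) addresses both, and gamma = 1 with beta = 0 is exactly the identity affine, which is what affine=False should mean. Below is a minimal standalone sketch (libtorch only; the variable names merely echo the diff, and the scale/shift folding shown is the conventional way batch norm is lowered to a per-channel scale layer, not code from this hunk) showing why those defaults are the correct no-op:

// Sketch: the defaulted gamma = 1, beta = 0 leave the input unchanged once
// batch norm is folded into a per-channel scale/shift.
#include <torch/torch.h>
#include <iostream>

int main() {
  int64_t channel_dim = 4;  // stands in for shape[1] in the converter
  auto options = torch::TensorOptions().dtype(torch::kFloat32);

  // Defaults mirroring the diff (the affine=False case).
  auto gamma = at::full(channel_dim, 1, options);
  auto beta = at::full(channel_dim, 0, options);

  // Running stats; var = 1 is chosen here purely to make the identity visible.
  auto mean = at::full(channel_dim, 0, options);
  auto var = at::full(channel_dim, 1, options);
  float eps = 1e-5f;

  // y = gamma * (x - mean) / sqrt(var + eps) + beta
  // folds into y = scale * x + shift, with:
  auto scale = gamma / torch::sqrt(var + eps);
  auto shift = beta - mean * scale;

  std::cout << "scale: " << scale << "\nshift: " << shift << std::endl;
  // scale ~= 1, shift == 0: the defaulted parameters are a no-op, as intended.
  return 0;
}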