@@ -11,15 +11,95 @@ namespace impl {
namespace {
bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-  auto in = args[0].ITensor(); // assumes non-static input Tensor
-  auto w = Weights(ctx, args[1].unwrapToTensor());
+  // Input to conv/deconv
+  auto in = args[0].ITensor();
+
+  // Conv/deconv parameters
  auto stride = util::toDims(args[3].unwrapToIntList());
  auto padding = util::toDims(args[4].unwrapToIntList());
  auto dilation = util::toDims(args[5].unwrapToIntList());
  bool transposed = args[6].unwrapToBool();
  auto out_padding = util::toDims(args[7].unwrapToIntList());
  int64_t groups = args[8].unwrapToInt();

+  // Reshape the parameters to 2D if needed
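+  // (TensorRT's addConvolutionNd/addDeconvolutionNd handle 2D/3D kernels, so 1D parameters are promoted to 2D with a unit extent.)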
+  if (stride.nbDims == 1) {
+    stride = util::unsqueezeDims(stride, 1, 1);
+    LOG_DEBUG("Reshaped stride: " << stride);
+  }
+  if (dilation.nbDims == 1) {
+    dilation = util::unsqueezeDims(dilation, 1, 1);
+    LOG_DEBUG("Reshaped dilation: " << dilation);
+  }
+  if (padding.nbDims == 1) {
+    padding = util::unsqueezeDims(padding, 1, 0);
+    LOG_DEBUG("Reshaped padding: " << padding);
+  }
+  if (out_padding.nbDims == 1) {
+    out_padding = util::unsqueezeDims(out_padding, 1, 0);
+    LOG_DEBUG("Reshaped out_padding: " << out_padding);
+  }
+
+  // Get the bias tensor, or leave it empty (treated as no bias).
+  Weights bias;
+  if (args[2].IValue()->isTensor()) {
+    bias = Weights(ctx, args[2].unwrapToTensor());
+  } else {
+    bias = Weights(); // i.e. nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0}
+  }
+
+  // Handle the case where the conv/deconv weights are an ITensor. This happens for QAT networks, where
+  //   conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
+  // new_conv_weights is an ITensor because it is the output of the Dequantize layer defined in impl/quantization.cpp.
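+  // The kernel values are therefore not known at conversion time, so they are passed to TensorRT as a layer input instead of as constant Weights.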
+  if (args[1].isITensor()) {
+    // Get the kernel tensor
+    auto kernel = args[1].ITensor();
+    auto kernel_dims = kernel->getDimensions();
+
+    // Make a new Dims with only the spatial dimensions.
+    nvinfer1::Dims filter_dim;
+    int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
+    TRTORCH_CHECK(nbSpatialDims == kernel_dims.nbDims - 2, "Number of input spatial dimensions should match the kernel spatial dimensions");
+    filter_dim.nbDims = nbSpatialDims;
+    filter_dim.d[0] = kernel_dims.d[2];
+    filter_dim.d[1] = kernel_dims.d[3];
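+    // Note: only two spatial extents are filled in here, so this branch handles 2D (de)convolutions.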
+
+    // Initialize a dummy constant kernel to pass to the INetworkDefinition::addConvolutionNd/addDeconvolutionNd API.
+    auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
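+    // The real kernel ITensor is attached below via setInput(1, *kernel).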
+
+    nvinfer1::ILayer* layer = nullptr;
+    if (transposed) {
+      nvinfer1::IDeconvolutionLayer* deconvLayer
+          = ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      TRTORCH_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
+      deconvLayer->setStrideNd(stride);
+      deconvLayer->setDilationNd(dilation);
+      deconvLayer->setNbGroups(groups);
+      deconvLayer->setPaddingNd(padding);
+      // Set deconv kernel weights
+      deconvLayer->setInput(1, *kernel);
+      layer = deconvLayer;
+    } else {
+      nvinfer1::IConvolutionLayer* convLayer
+          = ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      TRTORCH_CHECK(convLayer, "Unable to create conv layer with non-const weights from node: " << *n);
+      convLayer->setStrideNd(stride);
+      convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+      convLayer->setPaddingNd(padding);
+      convLayer->setPostPadding(out_padding);
+      convLayer->setDilationNd(dilation);
+      convLayer->setNbGroups(groups);
+
+      // Set conv kernel weights
+      convLayer->setInput(1, *kernel);
+      layer = convLayer;
+    }
+
+    ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+    LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
+    return true;
+  }
+
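+  // Otherwise the weights are a static at::Tensor; convert them to constant Weights and use the standard path below.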
+  auto w = Weights(ctx, args[1].unwrapToTensor());
  auto dims = in->getDimensions();
  auto orig_dims = dims;
  LOG_DEBUG("Input dims: " << orig_dims);
@@ -46,32 +126,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
    w.kernel_shape.d[1] = 1;
    LOG_DEBUG("Reshaped Weights: " << w);
  }
-  if (stride.nbDims == 1) {
-    stride = util::unsqueezeDims(stride, 1, 1);
-    LOG_DEBUG("Reshaped stride: " << stride);
-  }
-  if (dilation.nbDims == 1) {
-    dilation = util::unsqueezeDims(dilation, 1, 1);
-    LOG_DEBUG("Reshaped dilation: " << dilation);
-  }
-  if (padding.nbDims == 1) {
-    padding = util::unsqueezeDims(padding, 1, 0);
-    LOG_DEBUG("Reshaped padding: " << padding);
-  }
-  if (out_padding.nbDims == 1) {
-    out_padding = util::unsqueezeDims(out_padding, 1, 0);
-    LOG_DEBUG("Reshaped out_padding: " << out_padding);
-  }

  nvinfer1::ILayer* new_layer;
  if (transposed) {
-    Weights bias;
-    if (args[2].IValue()->isTensor()) {
-      bias = Weights(ctx, args[2].unwrapToTensor());
-    } else {
-      bias = Weights(ctx, torch::zeros(w.shape.d[1] * groups));
-    }
-
    // shape of deconvolution's weight: [in, out/groups, ...]
    auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data);
    TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
@@ -89,12 +146,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
#endif
    new_layer = deconv;
  } else {
-    Weights bias;
-    if (args[2].IValue()->isTensor()) {
-      bias = Weights(ctx, args[2].unwrapToTensor());
-    } else {
-      bias = Weights(ctx, torch::zeros(w.shape.d[0]));
-    }
+    // Weights bias;
+    // if (args[2].IValue()->isTensor()) {
+    //   bias = Weights(ctx, args[2].unwrapToTensor());
+    // } else {
+    //   bias = Weights(ctx, torch::zeros(w.shape.d[0]));
+    // }

    // shape of convolution's weight: [out, in/groups, ...]
    auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
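
For reference, the ITensor-weights branch above relies on TensorRT's support for supplying a convolution kernel as a layer input rather than as constant Weights. Below is a minimal sketch of that pattern, assuming a network, an activation input, and a kernel tensor (e.g. the output of a dequantize layer) already exist; the helper name and parameters are illustrative only, not TRTorch API.

#include <NvInfer.h>

// Sketch: create the layer with empty (dummy) kernel Weights, then attach the
// real kernel ITensor as the layer's second input, as the converter above does.
nvinfer1::ILayer* add_conv_with_tensor_kernel(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor* input,        // activation input
    nvinfer1::ITensor* kernel,       // kernel as a tensor, e.g. output of a dequantize layer
    int32_t out_channels,            // number of output maps (kernel dim 0 in the converter above)
    nvinfer1::Dims spatial_kernel) { // spatial extents only, e.g. {kH, kW}
  nvinfer1::Weights empty{nvinfer1::DataType::kFLOAT, nullptr, 0};
  auto* conv = net->addConvolutionNd(*input, out_channels, spatial_kernel, empty, /*biasWeights=*/empty);
  conv->setInput(1, *kernel); // input index 1 carries the kernel tensor
  return conv;
}

The converter does the same, additionally transferring stride, padding, dilation, and groups from the TorchScript node onto the layer.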