@@ -11,15 +11,95 @@ namespace impl {
1111namespace {
1212
1313bool add_conv_deconv (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
14- auto in = args[0 ].ITensor (); // assumes non-static input Tensor
15- auto w = Weights (ctx, args[1 ].unwrapToTensor ());
14+ // Input to conv/deconv
15+ auto in = args[0 ].ITensor ();
16+
17+ // Conv /deconv parameters
1618 auto stride = util::toDims (args[3 ].unwrapToIntList ());
1719 auto padding = util::toDims (args[4 ].unwrapToIntList ());
1820 auto dilation = util::toDims (args[5 ].unwrapToIntList ());
1921 bool transposed = args[6 ].unwrapToBool ();
2022 auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
2123 int64_t groups = args[8 ].unwrapToInt ();
2224
25+ // Reshape the parameters to 2D if needed
26+ if (stride.nbDims == 1 ) {
27+ stride = util::unsqueezeDims (stride, 1 , 1 );
28+ LOG_DEBUG (" Reshaped stride: " << stride);
29+ }
30+ if (dilation.nbDims == 1 ) {
31+ dilation = util::unsqueezeDims (dilation, 1 , 1 );
32+ LOG_DEBUG (" Reshaped dilation: " << dilation);
33+ }
34+ if (padding.nbDims == 1 ) {
35+ padding = util::unsqueezeDims (padding, 1 , 0 );
36+ LOG_DEBUG (" Reshaped padding: " << padding);
37+ }
38+ if (out_padding.nbDims == 1 ) {
39+ out_padding = util::unsqueezeDims (out_padding, 1 , 0 );
40+ LOG_DEBUG (" Reshaped out_padding: " << out_padding);
41+ }
42+
43+ // Get bias tensor or initialize it to zeros.
44+ Weights bias;
45+ if (args[2 ].IValue ()->isTensor ()) {
46+ bias = Weights (ctx, args[2 ].unwrapToTensor ());
47+ } else {
48+ bias = Weights (); // nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
49+ }
50+
51+ // Handle case when weights of conv/deconv is an ITensor. This case happens for QAT networks where
52+ // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
53+ // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in impl/quantization.cpp
54+ if (args[1 ].isITensor ()){
55+ // Get the kernel tensor
56+ auto kernel = args[1 ].ITensor ();
57+ auto kernel_dims = kernel->getDimensions ();
58+
59+ // Make a new Dims with only the spatial dimensions.
60+ nvinfer1::Dims filter_dim;
61+ int64_t nbSpatialDims = in->getDimensions ().nbDims - 2 ;
62+ TRTORCH_CHECK (nbSpatialDims = kernel_dims.nbDims - 2 , " Number of input spatial dimensions should match the kernel spatial dimensions" );
63+ filter_dim.nbDims = nbSpatialDims;
64+ filter_dim.d [0 ] = kernel_dims.d [2 ];
65+ filter_dim.d [1 ] = kernel_dims.d [3 ];
66+
67+ // Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.
68+ auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT , nullptr , 0 };
69+
70+ nvinfer1::ILayer* layer = nullptr ;
71+ if (transposed){
72+ nvinfer1::IDeconvolutionLayer* deconvLayer
73+ = ctx->net ->addDeconvolutionNd (*in, kernel_dims.d [0 ], filter_dim, kernel_weights, bias.data );
74+ deconvLayer->setStrideNd (stride);
75+ deconvLayer->setDilationNd (dilation);
76+ deconvLayer->setNbGroups (groups);
77+ deconvLayer->setPaddingNd (padding);
78+ // Set deconv kernel weights
79+ deconvLayer->setInput (1 , *kernel);
80+ TRTORCH_CHECK (deconvLayer, " Unable to create deconv layer with non-const weights from node: " << *n);
81+ layer = deconvLayer;
82+ } else {
83+ nvinfer1::IConvolutionLayer* convLayer
84+ = ctx->net ->addConvolutionNd (*in, kernel_dims.d [0 ], filter_dim, kernel_weights, bias.data );
85+ convLayer->setStrideNd (stride);
86+ convLayer->setPaddingMode (nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN );
87+ convLayer->setPaddingNd (padding);
88+ convLayer->setPostPadding (out_padding);
89+ convLayer->setDilationNd (dilation);
90+ convLayer->setNbGroups (groups);
91+
92+ // Set conv kernel weights
93+ convLayer->setInput (1 , *kernel);
94+ layer = convLayer;
95+ }
96+
97+ ctx->AssociateValueAndTensor (n->outputs ()[0 ], layer->getOutput (0 ));
98+ LOG_DEBUG (" Output tensor shape: " << layer->getOutput (0 )->getDimensions ());
99+ return true ;
100+ }
101+
102+ auto w = Weights (ctx, args[1 ].unwrapToTensor ());
23103 auto dims = in->getDimensions ();
24104 auto orig_dims = dims;
25105 LOG_DEBUG (" Input dims: " << orig_dims);
@@ -46,32 +126,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
46126 w.kernel_shape .d [1 ] = 1 ;
47127 LOG_DEBUG (" Reshaped Weights: " << w);
48128 }
49- if (stride.nbDims == 1 ) {
50- stride = util::unsqueezeDims (stride, 1 , 1 );
51- LOG_DEBUG (" Reshaped stride: " << stride);
52- }
53- if (dilation.nbDims == 1 ) {
54- dilation = util::unsqueezeDims (dilation, 1 , 1 );
55- LOG_DEBUG (" Reshaped dilation: " << dilation);
56- }
57- if (padding.nbDims == 1 ) {
58- padding = util::unsqueezeDims (padding, 1 , 0 );
59- LOG_DEBUG (" Reshaped padding: " << padding);
60- }
61- if (out_padding.nbDims == 1 ) {
62- out_padding = util::unsqueezeDims (out_padding, 1 , 0 );
63- LOG_DEBUG (" Reshaped out_padding: " << out_padding);
64- }
65129
66130 nvinfer1::ILayer* new_layer;
67131 if (transposed) {
68- Weights bias;
69- if (args[2 ].IValue ()->isTensor ()) {
70- bias = Weights (ctx, args[2 ].unwrapToTensor ());
71- } else {
72- bias = Weights (ctx, torch::zeros (w.shape .d [1 ] * groups));
73- }
74-
75132 // shape of deconvolution's weight: [in, out/groups, ...]
76133 auto deconv = ctx->net ->addDeconvolutionNd (*in, w.shape .d [1 ] * groups, w.kernel_shape , w.data , bias.data );
77134 TRTORCH_CHECK (deconv, " Unable to create deconvolution layer from node: " << *n);
@@ -89,12 +146,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
89146#endif
90147 new_layer = deconv;
91148 } else {
92- Weights bias;
93- if (args[2 ].IValue ()->isTensor ()) {
94- bias = Weights (ctx, args[2 ].unwrapToTensor ());
95- } else {
96- bias = Weights (ctx, torch::zeros (w.shape .d [0 ]));
97- }
149+ // Weights bias;
150+ // if (args[2].IValue()->isTensor()) {
151+ // bias = Weights(ctx, args[2].unwrapToTensor());
152+ // } else {
153+ // bias = Weights(ctx, torch::zeros(w.shape.d[0]));
154+ // }
98155
99156 // shape of convolution's weight: [out, in/groups, ...]
100157 auto conv = ctx->net ->addConvolutionNd (*in, w.shape .d [0 ], w.kernel_shape , w.data , bias.data );
0 commit comments