@@ -37,7 +37,7 @@ bool AdaptivePoolingConverter(
3737 ConversionCtx* ctx,
3838 const torch::jit::Node* n,
3939 args& args,
40- nvinfer1::PoolingType pool_type) {
40+ nvinfer1::PoolingType pool_type, const std::string& mode ) {
4141 auto in = args[0 ].ITensorOrFreeze (ctx);
4242 auto out_size = util::toDims (args[1 ].unwrapToIntList ());
4343
@@ -48,15 +48,7 @@ bool AdaptivePoolingConverter(
4848 }
4949
5050 auto orig_dims = in->getDimensions ();
51- bool expandDims = (orig_dims.nbDims < 4 );
52- TORCHTRT_CHECK (orig_dims.nbDims > 2 , " Unable to create pooling layer from node: " << *n);
53- if (expandDims) {
54- in = addPadding (ctx, n, in, 4 , false , false );
55- }
56-
57- if (out_size.nbDims == 1 ) {
58- out_size = util::unsqueezeDims (out_size, 0 , 1 );
59- }
51+ TORCHTRT_CHECK (orig_dims.nbDims > 1 , " Unable to create pooling layer from node: " << *n);
6052
6153 auto in_shape = util::toVec (in->getDimensions ());
6254 nvinfer1::ILayer* new_layer = nullptr ;
@@ -90,10 +82,6 @@ bool AdaptivePoolingConverter(
9082 int32_t use_scales_casted = 0 ;
9183 f.emplace_back (nvinfer1::PluginField (" use_scales" , &use_scales_casted, nvinfer1::PluginFieldType::kINT32 , 1 ));
9284
93- std::string mode = " adaptive_avg_pool2d" ;
94- if (pool_type == nvinfer1::PoolingType::kMAX ) {
95- mode = " adaptive_max_pool2d" ;
96- }
9785 f.emplace_back (nvinfer1::PluginField (" mode" , &mode, nvinfer1::PluginFieldType::kCHAR , 1 ));
9886
9987 fc.nbFields = f.size ();
@@ -110,7 +98,7 @@ bool AdaptivePoolingConverter(
11098 TORCHTRT_CHECK (new_layer, " Unable to create pooling (interpolation) plugin from node" << *n);
11199
112100 new_layer->setName (util::node_info (n).c_str ());
113- auto layer_output = addUnpadding (ctx, n, new_layer->getOutput (0 ), orig_dims. nbDims , false , false );
101+ auto layer_output = new_layer->getOutput (0 );
114102
115103 ctx->AssociateValueAndTensor (n->outputs ()[0 ], layer_output);
116104 LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
@@ -238,15 +226,15 @@ auto pooling_registrations TORCHTRT_UNUSED =
238226 }})
239227 .pattern({" aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)" ,
240228 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
241- return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE );
229+ return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE , " adaptive_avg_pool1d " );
242230 }})
243231 .pattern({" aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)" ,
244232 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
245- return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE );
233+ return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kAVERAGE , " adaptive_avg_pool2d " );
246234 }})
247235 .pattern({" aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)" ,
248236 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
249- return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kMAX );
237+ return AdaptivePoolingConverter (ctx, n, args, nvinfer1::PoolingType::kMAX , " adaptive_max_pool2d " );
250238 }});
251239} // namespace
252240} // namespace impl
0 commit comments