@@ -27,7 +27,7 @@ InterpolatePlugin::InterpolatePlugin(
       align_corners_(align_corners),
       use_scales_(use_scales) {
   if (use_scales) {
-    TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
+    TRTORCH_ASSERT(mode_ != "adaptive_avg_pool2d", "use_scales is not valid for adaptive_avg_pool2d");
     TRTORCH_ASSERT(
         scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
     at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});
@@ -106,7 +106,11 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
 }
 
 int InterpolatePlugin::getNbOutputs() const {
-  return 1;
+  if (mode_ == "adaptive_max_pool2d") {
+    return 2;
+  } else {
+    return 1;
+  }
 }
 
 const char* InterpolatePlugin::getPluginType() const {
@@ -166,15 +170,6 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
 }
 
 int InterpolatePlugin::initialize() {
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-  tensor_options_ = tensor_options_.device(c10::kCUDA);
-#else
-  tensor_options_ = tensor_options_.device(c10::kCPU);
-#endif
-
-  // c10::kFloat = FLOAT32
-  tensor_options_ = tensor_options_.dtype(c10::kFloat);
-
   return 0;
 }
 
@@ -211,9 +206,15 @@ bool InterpolatePlugin::supportsFormatCombination(
     const nvinfer1::PluginTensorDesc* inOut,
     int nbInputs,
     int nbOutputs) {
-  TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
   TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
-  TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
+
+  if (mode_ == "adaptive_max_pool2d") {
+    TRTORCH_ASSERT(nbOutputs == 2, "Expected 2 tensors as output to interpolate plugin");
+    TRTORCH_ASSERT(0 <= pos && pos <= 2, "There should be exactly 3 connections to the plugin - 1 input, 2 output");
+  } else {
+    TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
+    TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
+  }
 
   const nvinfer1::PluginTensorDesc& in = inOut[0];
 
@@ -250,10 +251,10 @@ int InterpolatePlugin::enqueue(
     void* const* outputs,
     void* workspace,
     cudaStream_t stream) {
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-  at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
-  at::Tensor output = at::from_blob(
-      outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);
+  at::Tensor input =
+      at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
+  at::Tensor output =
+      at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
 
   at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
   at::cuda::CUDAStreamGuard torch_guard(torch_stream);
@@ -263,27 +264,30 @@ int InterpolatePlugin::enqueue(
   cudaEventRecord(event, stream);
 
   cudaStreamWaitEvent(torch_stream.stream(), event, 0);
-
+  at::Tensor out;
   if (use_scales_) {
     if (mode_ == "linear") {
-      at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
+      out = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
     } else if (mode_ == "bilinear") {
-      at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
+      out = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
     } else if (mode_ == "trilinear") {
-      at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
+      out = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
     }
   } else {
     if (mode_ == "linear") {
-      at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
+      out = at::upsample_linear1d(input, {size_[0]}, align_corners_);
     } else if (mode_ == "bilinear") {
-      at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
+      out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
     } else if (mode_ == "trilinear") {
-      at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
-    } else if (mode_ == "adaptive_pool2d") {
-      at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
+      out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
+    } else if (mode_ == "adaptive_avg_pool2d") {
+      out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
+    } else if (mode_ == "adaptive_max_pool2d") {
+      out = std::get<0>(at::adaptive_max_pool2d(input, {size_[0], size_[1]}));
     }
   }
 
+  output.copy_(out);
   cudaEvent_t torch_event;
   cudaEventCreate(&torch_event);
   cudaEventRecord(torch_event, torch_stream.stream());
@@ -294,49 +298,6 @@ int InterpolatePlugin::enqueue(
   cudaEventDestroy(torch_event);
 
   return 0;
-#else
-  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
-  // kernels HACK: WAR because there is a segfault if you try to create a CUDA
-  // Tensor in the context of TensorRT execution
-  float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
-  cudaMemcpyAsync(
-      input_blob,
-      static_cast<const void*>(inputs[0]),
-      util::volume(inputDesc->dims) * sizeof(float),
-      cudaMemcpyDeviceToHost,
-      stream);
-  cudaStreamSynchronize(stream);
-
-  at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
-  at::Tensor output;
-  if (use_scales_) {
-    if (mode_ == "linear") {
-      output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
-    } else if (mode_ == "bilinear") {
-      output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
-    } else if (mode_ == "trilinear") {
-      output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
-    }
-  } else {
-    if (mode_ == "linear") {
-      output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
-    } else if (mode_ == "bilinear") {
-      output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
-    } else if (mode_ == "trilinear") {
-      output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
-    } else if (mode_ == "adaptive_pool2d") {
-      output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
-    }
-  }
-
-  cudaMemcpyAsync(
-      outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
-  cudaStreamSynchronize(stream);
-
-  free(input_blob);
-
-  return 0;
-#endif
 }
 
 /*
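The rewritten enqueue() above relies on a pattern worth calling out: the raw device buffers TensorRT hands to the plugin are wrapped as ATen tensors with at::from_blob and a no-op deleter (so ATen never frees TensorRT's memory), the ATen kernel runs on a side CUDA stream that is event-synchronized with TensorRT's stream, and the out-of-place result is copy_'d into the wrapped output tensor. Below is a minimal standalone sketch of that pattern, assuming fixed shapes and an illustrative function name run_interpolate_plugin; it is not the plugin itself.

    // Sketch only: wrap TensorRT-owned device buffers, run an ATen op on a
    // side stream, and copy the result back. Shapes and names are assumed.
    #include <ATen/ATen.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>
    #include <cuda_runtime.h>

    void run_interpolate_plugin(const void* in_gpu, void* out_gpu, cudaStream_t trt_stream) {
      auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
      // No-op deleters: ATen must not free buffers it does not own.
      at::Tensor input = at::from_blob(const_cast<void*>(in_gpu), {1, 3, 16, 16}, [](void*) {}, opts);
      at::Tensor output = at::from_blob(out_gpu, {1, 3, 32, 32}, [](void*) {}, opts);

      // Run ATen work on its own stream, but only after TensorRT has
      // finished producing the input buffer on trt_stream.
      c10::cuda::CUDAStream torch_stream = c10::cuda::getStreamFromPool();
      c10::cuda::CUDAStreamGuard guard(torch_stream);
      cudaEvent_t input_ready;
      cudaEventCreate(&input_ready);
      cudaEventRecord(input_ready, trt_stream);
      cudaStreamWaitEvent(torch_stream.stream(), input_ready, 0);

      // Compute out-of-place, then copy into the TensorRT output buffer.
      at::Tensor out = at::upsample_bilinear2d(input, {32, 32}, /*align_corners=*/false);
      output.copy_(out);

      // Make TensorRT's stream wait for the ATen work before continuing.
      cudaEvent_t done;
      cudaEventCreate(&done);
      cudaEventRecord(done, torch_stream.stream());
      cudaStreamWaitEvent(trt_stream, done, 0);

      cudaEventDestroy(input_ready);
      cudaEventDestroy(done);
    }

Because everything is expressed as stream/event dependencies rather than host-side synchronization, the plugin never blocks the CPU, which is why the diff can drop the old malloc/cudaMemcpyAsync/cudaStreamSynchronize fallback path entirely.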