@@ -18,17 +18,36 @@ auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patter
1818 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1919 auto self = args[0 ].ITensorOrFreeze (ctx);
2020 auto dim = args[1 ].unwrapToInt ();
21+ auto keep_dims = args[2 ].unwrapToBool ();
2122 auto selfDim = util::toVec (self->getDimensions ());
2223 if (dim < 0 ) {
2324 dim = selfDim.size () + dim;
2425 }
2526 uint32_t shiftDim = 1 << dim;
2627 auto TopKOperation = nvinfer1::TopKOperation::kMAX ;
27- auto new_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
28- TORCHTRT_CHECK (new_layer, " Unable to create max layer from node: " << *n);
28+ auto topk_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
29+ TORCHTRT_CHECK (topk_layer, " Unable to create max layer from node: " << *n);
30+ auto topk_dims = util::toVec (topk_layer->getOutput (0 )->getDimensions ());
2931
30- auto out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
31- auto out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], new_layer->getOutput (1 ));
32+ nvinfer1::ITensor* out0;
33+ nvinfer1::ITensor* out1;
34+ if (!keep_dims) {
35+ if (topk_dims[dim] == 1 ) {
36+ auto squeeze_layer = ctx->net ->addShuffle (*topk_layer->getOutput (0 ));
37+ squeeze_layer->setReshapeDimensions (util::squeezeDims (topk_layer->getOutput (0 )->getDimensions (), dim));
38+ TORCHTRT_CHECK (squeeze_layer, " Unable to create squeeze_layer layer from node: " << *n);
39+ out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], squeeze_layer->getOutput (0 ));
40+
41+ auto squeeze_layer_indices = ctx->net ->addShuffle (*topk_layer->getOutput (1 ));
42+ squeeze_layer_indices->setReshapeDimensions (
43+ util::squeezeDims (topk_layer->getOutput (1 )->getDimensions (), dim));
44+ TORCHTRT_CHECK (squeeze_layer_indices, " Unable to create squeeze_layer_indices layer from node: " << *n);
45+ out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], squeeze_layer_indices->getOutput (0 ));
46+ }
47+ } else {
48+ out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], topk_layer->getOutput (0 ));
49+ out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], topk_layer->getOutput (1 ));
50+ }
3251
3352 LOG_DEBUG (" Output tensor(0) shape: " << out0->getDimensions ());
3453 LOG_DEBUG (" Output tensor(1) shape: " << out1->getDimensions ());
0 commit comments