@@ -14,27 +14,27 @@ namespace converters {
14
14
namespace impl {
15
15
namespace {
16
16
auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
17
- {" aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" ,
18
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19
- auto self = args[0 ].ITensorOrFreeze (ctx);
20
- auto dim = args[1 ].unwrapToInt ();
21
- auto selfDim = util::toVec (self->getDimensions ());
22
- if (dim < 0 ) {
23
- dim = selfDim.size () + dim;
24
- }
25
- uint32_t shiftDim = 1 << dim;
26
- auto TopKOperation = nvinfer1::TopKOperation::kMAX ;
27
- auto new_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
28
- TORCHTRT_CHECK (new_layer, " Unable to create max layer from node: " << *n);
17
+ {" aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)" ,
18
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19
+ auto self = args[0 ].ITensorOrFreeze (ctx);
20
+ auto dim = args[1 ].unwrapToInt ();
21
+ auto selfDim = util::toVec (self->getDimensions ());
22
+ if (dim < 0 ) {
23
+ dim = selfDim.size () + dim;
24
+ }
25
+ uint32_t shiftDim = 1 << dim;
26
+ auto TopKOperation = nvinfer1::TopKOperation::kMAX ;
27
+ auto new_layer = ctx->net ->addTopK (*self, TopKOperation, 1 , shiftDim);
28
+ TORCHTRT_CHECK (new_layer, " Unable to create max layer from node: " << *n);
29
29
30
- auto out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
31
- auto out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], new_layer->getOutput (1 ));
30
+ auto out0 = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
31
+ auto out1 = ctx->AssociateValueAndTensor (n->outputs ()[1 ], new_layer->getOutput (1 ));
32
32
33
- LOG_DEBUG (" Output tensor(0) shape: " << out0->getDimensions ());
34
- LOG_DEBUG (" Output tensor(1) shape: " << out1->getDimensions ());
33
+ LOG_DEBUG (" Output tensor(0) shape: " << out0->getDimensions ());
34
+ LOG_DEBUG (" Output tensor(1) shape: " << out1->getDimensions ());
35
35
36
- return true ;
37
- }});
36
+ return true ;
37
+ }});
38
38
} // namespace
39
39
} // namespace impl
40
40
} // namespace converters
0 commit comments