Skip to content

Commit dd7a44e

Browse files
committed
feat: Add converter files for torch::max
Signed-off-by: hongwei03 <[email protected]>
1 parent 0b5673b commit dd7a44e

File tree

1 file changed

+18
-18
lines changed
  • core/conversion/converters/impl

1 file changed

+18
-18
lines changed

core/conversion/converters/impl/max.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,27 @@ namespace converters {
1414
namespace impl {
1515
namespace {
1616
auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
17-
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
18-
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19-
auto self = args[0].ITensorOrFreeze(ctx);
20-
auto dim = args[1].unwrapToInt();
21-
auto selfDim = util::toVec(self->getDimensions());
22-
if (dim < 0) {
23-
dim = selfDim.size() + dim;
24-
}
25-
uint32_t shiftDim = 1 << dim;
26-
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
27-
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
28-
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);
17+
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
18+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19+
auto self = args[0].ITensorOrFreeze(ctx);
20+
auto dim = args[1].unwrapToInt();
21+
auto selfDim = util::toVec(self->getDimensions());
22+
if (dim < 0) {
23+
dim = selfDim.size() + dim;
24+
}
25+
uint32_t shiftDim = 1 << dim;
26+
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
27+
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
28+
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);
2929

30-
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
31-
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
30+
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
31+
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
3232

33-
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
34-
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
33+
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
34+
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
3535

36-
return true;
37-
}});
36+
return true;
37+
}});
3838
} // namespace
3939
} // namespace impl
4040
} // namespace converters

0 commit comments

Comments
 (0)