diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 8b5338c08b..f9948eea7d 100755 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -79,6 +79,7 @@ cc_library( "impl/squeeze.cpp", "impl/stack.cpp", "impl/topk.cpp", + "impl/max.cpp", "impl/unary.cpp", "impl/unsqueeze.cpp", ], diff --git a/core/conversion/converters/impl/max.cpp b/core/conversion/converters/impl/max.cpp new file mode 100644 index 0000000000..adc8d06ed0 --- /dev/null +++ b/core/conversion/converters/impl/max.cpp @@ -0,0 +1,43 @@ +#include "NvInfer.h" +#include "core/conversion/converters/converters.h" +#include "core/conversion/tensorcontainer/TensorContainer.h" +#include "core/util/prelude.h" +#include "torch/torch.h" + +#include +#include + +namespace torch_tensorrt { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { +auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( + {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto self = args[0].ITensorOrFreeze(ctx); + auto dim = args[1].unwrapToInt(); + auto selfDim = util::toVec(self->getDimensions()); + if (dim < 0) { + dim = selfDim.size() + dim; + } + uint32_t shiftDim = 1 << dim; + auto TopKOperation = nvinfer1::TopKOperation::kMAX; + auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); + TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n); + + auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); + auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1)); + + LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); + LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); + + return true; + }}); +} // namespace +} // namespace impl +} // namespace converters +} // namespace conversion +} // namespace core +} // namespace torch_tensorrt