diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD index 8b5338c08b..f9948eea7d 100755 --- a/core/conversion/converters/BUILD +++ b/core/conversion/converters/BUILD @@ -79,6 +79,7 @@ cc_library( "impl/squeeze.cpp", "impl/stack.cpp", "impl/topk.cpp", + "impl/max.cpp", "impl/unary.cpp", "impl/unsqueeze.cpp", ], diff --git a/core/conversion/converters/impl/max.cpp b/core/conversion/converters/impl/max.cpp new file mode 100644 index 0000000000..adc8d06ed0 --- /dev/null +++ b/core/conversion/converters/impl/max.cpp @@ -0,0 +1,43 @@ +#include "NvInfer.h" +#include "core/conversion/converters/converters.h" +#include "core/conversion/tensorcontainer/TensorContainer.h" +#include "core/util/prelude.h" +#include "torch/torch.h" + +#include <ATen/ATen.h> +#include <vector> + +namespace torch_tensorrt { +namespace core { +namespace conversion { +namespace converters { +namespace impl { +namespace { +auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( +    {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", +     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { +       auto self = args[0].ITensorOrFreeze(ctx); +       auto dim = args[1].unwrapToInt(); +       auto selfDim = util::toVec(self->getDimensions()); +       if (dim < 0) { +         dim = selfDim.size() + dim; +       } +       uint32_t shiftDim = 1 << dim; +       auto TopKOperation = nvinfer1::TopKOperation::kMAX; +       auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); +       TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n); + +       auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); +       auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1)); + +       LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); +       LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); + +       return true; +     }}); +} // namespace +} // namespace impl +} // namespace converters +} // namespace 
conversion +} // namespace core +} // namespace torch_tensorrt diff --git a/tests/core/conversion/converters/test_topk.cpp b/tests/core/conversion/converters/test_topk.cpp index dddddc9d2c..608e397a1f 100644 --- a/tests/core/conversion/converters/test_topk.cpp +++ b/tests/core/conversion/converters/test_topk.cpp @@ -30,3 +30,28 @@ TEST(Converters, ATenTopKConvertsCorrectly) { ASSERT_TRUE( torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); } + +TEST(Converters, ATenMaxDimConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3) + return (%4, %5))IR"; + + auto g = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); +}