From 84964a5125eb455dda65eba7e75114687eb8e2c6 Mon Sep 17 00:00:00 2001
From: Michael Feliz
Date: Tue, 28 Feb 2023 13:09:09 -0800
Subject: [PATCH] Add converter for aten::any

---
 core/conversion/converters/impl/reduce.cpp    | 35 ++++++++++++++-
 .../conversion/converters/test_reduce.cpp     | 44 +++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 68d4c41fa5..249ae916ef 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -203,7 +203,8 @@ auto reduce_registrations TORCHTRT_UNUSED =
                return true;
              }})
         .pattern(
-            {"aten::min(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            {"aten::min(Tensor self) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in_tensor = args[0].ITensorOrFreeze(ctx);
               auto in_dims = util::toVec(in_tensor->getDimensions());
 
@@ -216,6 +217,38 @@
               min_layer->setName(util::node_info(n).c_str());
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], min_layer->getOutput(0));
 
+              LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+              return true;
+            }})
+        .pattern(
+            {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in_tensor = args[0].ITensorOrFreeze(ctx);
+               auto in_dims = in_tensor->getDimensions();
+               auto dim = args[1].unwrapToInt();
+               LOG_DEBUG("Dim to reduce (original): " << dim);
+               dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
+               LOG_DEBUG("Dim to reduce (converted): " << dim);
+
+               uint32_t axis_mask = 1 << dim;
+               LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
+
+               auto keepdim = args[2].unwrapToBool();
+               LOG_DEBUG("Keep dims: " << keepdim);
+
+               // Reduce does not work on bool inputs
+               if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+                 in_tensor =
+                     castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
+               }
+               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+               sum_layer->setName(util::node_info(n).c_str());
+               auto out_tensor = castITensor(
+                   ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+               out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
              }});
diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp
index 3bcef3db77..40835a8dea 100644
--- a/tests/core/conversion/converters/test_reduce.cpp
+++ b/tests/core/conversion/converters/test_reduce.cpp
@@ -300,6 +300,50 @@ TEST(Converters, ATenMeanDimNegIndexKeepDimsConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAnyDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {4, 4, 4}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAnyDimAllFalseConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=2]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::zeros({3, 7, 4}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAnyDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {4, 4, 4}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
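
The converter computes any(x, dim) as bool(sum(int32(x), dim)), which is exact for 0/1 inputs like the ones these tests generate. The sketch below is illustrative only, not part of the patch: it assumes a stock libtorch build running on CPU, and the file name is hypothetical, but the cast/reduce/cast chain mirrors the kBOOL -> kINT32 -> kSUM -> kBOOL lowering added above.

// check_any_sum_identity.cpp (hypothetical, illustrative only)
#include <torch/torch.h>
#include <iostream>

int main() {
  // Same kind of 0/1 input the ATenAnyDimConvertsCorrectly test uses.
  auto x = torch::randint(0, 2, {4, 4, 4}).to(torch::kBool);
  auto reference = x.any(/*dim=*/1, /*keepdim=*/false);
  // Emulate the lowering: cast bool -> int32, reduce-sum over dim 1, cast back to bool.
  auto emulated = x.to(torch::kInt32).sum(/*dim=*/1).to(torch::kBool);
  // The identity is exact for 0/1 data; a sum of signed raw values could
  // cancel to zero, so this check sticks to boolean inputs.
  std::cout << std::boolalpha << torch::equal(reference, emulated) << std::endl;
  return 0;
}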