From bb3263ffae1010b07fa04e028b16bc23d2c1e235 Mon Sep 17 00:00:00 2001
From: Michael Feliz
Date: Wed, 19 Apr 2023 14:30:17 -0700
Subject: [PATCH 1/2] Add support for aten::all (#128)

* Add support for aten::all

* add comment
---
 core/conversion/converters/impl/reduce.cpp    | 76 +++++++++++++------
 .../conversion/converters/test_reduce.cpp     | 53 ++++++++++++-
 2 files changed, 105 insertions(+), 24 deletions(-)

diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 249ae916ef..f759f7f00c 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -9,6 +9,36 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* anyDimImplementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in_tensor,
+    int dim,
+    bool keepdim) {
+  auto in_dims = in_tensor->getDimensions();
+  LOG_DEBUG("Dim to reduce (original): " << dim);
+  dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
+  LOG_DEBUG("Dim to reduce (converted): " << dim);
+
+  uint32_t axis_mask = 1 << dim;
+  LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
+  LOG_DEBUG("Keep dims: " << keepdim);
+
+  // Reduce does not work on bool inputs
+  if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+    in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
+  }
+  auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+  sum_layer->setName(util::node_info(n).c_str());
+  auto out_tensor =
+      castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+  return out_tensor;
+}
+
 auto reduce_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
             {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in_tensor = args[0].ITensorOrFreeze(ctx);
-               auto in_dims = in_tensor->getDimensions();
                auto dim = args[1].unwrapToInt();
-               LOG_DEBUG("Dim to reduce (original): " << dim);
-               dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
-               LOG_DEBUG("Dim to reduce (converted): " << dim);
-
-               uint32_t axis_mask = 1 << dim;
-               LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
-
                auto keepdim = args[2].unwrapToBool();
-               LOG_DEBUG("Keep dims: " << keepdim);
-
-               // Reduce does not work on bool inputs
-               if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
-                 in_tensor =
-                     castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
-               }
-               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
-
-               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-
-               sum_layer->setName(util::node_info(n).c_str());
-               auto out_tensor = castITensor(
-                   ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+               auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
+             }})
+        .pattern(
+            {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               //use Not(Any(Not(input))) to calculate all without a direct all reduction
+               auto in_tensor = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
+               auto keepdim = args[2].unwrapToBool();
+               if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
+                 // unary not layer only supports bool inputs
+                 in_tensor = castITensor(
+                     ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
+               }
+               auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
+               not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
+               auto not_in = not_input_layer->getOutput(0);
+               auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
+               auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
             }});
 } // namespace
 } // namespace impl
diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp
index 40835a8dea..47e8b8d154 100644
--- a/tests/core/conversion/converters/test_reduce.cpp
+++ b/tests/core/conversion/converters/test_reduce.cpp
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
         return (%5))IR";
 }
 
-void test_body(const std::string& graph, at::Tensor& in) {
+void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {
   in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  std::vector<at::Tensor> trt_results;
+  if (dynamic) {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
+  } else {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  }
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 } // namespace
@@ -344,6 +349,50 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAllDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimAllTrueConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::ones({2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in, true);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):

From 7b0ad9e9c35f06b79fe666a8b0c8f9f58c68f29f Mon Sep 17 00:00:00 2001
From: Michael Feliz
Date: Wed, 19 Apr 2023 14:37:16 -0700
Subject: [PATCH 2/2] lint

---
 core/conversion/converters/impl/reduce.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index f759f7f00c..e3c7498c47 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -264,7 +264,7 @@ auto reduce_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               //use Not(Any(Not(input))) to calculate all without a direct all reduction
+               // use Not(Any(Not(input))) to calculate all without a direct all reduction
                auto in_tensor = args[0].ITensorOrFreeze(ctx);
                auto dim = args[1].unwrapToInt();
                auto keepdim = args[2].unwrapToBool();
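Note (not part of the patch): the aten::all.dim converter above relies on the identity all(x) == Not(Any(Not(x))), with Any in turn lowered through anyDimImplementation as a kSUM reduction over an int-cast of the boolean input, followed by a cast back to bool. A minimal standalone C++ sketch of that identity on plain host data, with no TensorRT types involved; the helper name all_via_not_any_not is purely illustrative:

// Illustrative only: all(x) computed as !any(!x), mirroring the
// Not(Any(Not(input))) scheme, with any() expressed as "sum of
// casted ints != 0", analogous to the kSUM reduce layer on int32.
#include <cassert>
#include <vector>

bool all_via_not_any_not(const std::vector<bool>& x) {
  int sum_of_negations = 0; // bool reduce is unsupported, so count as ints
  for (bool v : x) {
    sum_of_negations += static_cast<int>(!v); // Not(input), then sum-reduce
  }
  bool any_not = (sum_of_negations != 0); // Any(Not(input))
  return !any_not;                        // Not(Any(Not(input)))
}

int main() {
  assert(all_via_not_any_not({true, true, true}));
  assert(!all_via_not_any_not({true, false, true}));
  assert(all_via_not_any_not({})); // all() over an empty range is true
  return 0;
}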