diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp
index 5a965878a5..f75cefb4fb 100644
--- a/core/conversion/conversion.cpp
+++ b/core/conversion/conversion.cpp
@@ -7,6 +7,9 @@
 #include "core/conversion/var/Var.h"
 #include "core/util/prelude.h"
 
+#include "c10/util/intrusive_ptr.h"
+#include "core/conversion/tensorcontainer/TensorContainer.h"
+
 namespace trtorch {
 namespace core {
 namespace conversion {
@@ -173,18 +176,35 @@ void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs
 
 void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
   for (auto out : outputs) {
-    std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
     auto it = ctx->value_tensor_map.find(out);
-    // Leaves the potential for unused outputs to be populated with nullptr
-    // "safely"
-    TRTORCH_CHECK(
-        it != ctx->value_tensor_map.end() && it->second,
-        "No corresponding output TRT Tensor found for TorchScript output: " << out->debugName());
-    auto out_tensor = it->second;
-    out_tensor->setName(name.c_str());
-    ctx->net->markOutput(*out_tensor);
-    LOG_INFO(ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
-    ctx->num_outputs += 1;
+    if (it == ctx->value_tensor_map.end()) {
+      if (ctx->evaluated_value_map.find(out) != ctx->evaluated_value_map.end()) {
+        auto out_ivalue = ctx->evaluated_value_map[out];
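+        // A custom-class IValue here is assumed to be a TensorContainer
+        // holding the TensorRT ITensor for this output (e.g. a list element
+        // produced by the aten::split converter).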
+        if (out_ivalue.isCustomClass()) {
+          std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
+          auto output_container = out_ivalue.toCustomClass<TensorContainer>();
+          nvinfer1::ITensor* out_tensor = output_container.get()->tensor();
+          out_tensor->setName(name.c_str());
+          ctx->net->markOutput(*out_tensor);
+          LOG_INFO(
+              ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
+          ctx->num_outputs += 1;
+        } else {
+          TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported.");
+        }
+      }
+    } else {
+      std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
+      auto out_tensor = it->second;
+      out_tensor->setName(name.c_str());
+      ctx->net->markOutput(*out_tensor);
+      LOG_INFO(
+          ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
+      ctx->num_outputs += 1;
+    }
   }
 }
@@ -337,12 +357,30 @@ void ConvertBlockToNetDef(
     } else if (to_eval) {
       auto eval = EvaluateNode(ctx, n);
       if (eval) {
-        if (!eval.value().isTensor()) {
+        if (n->outputs().size() > 1) { // For the ListUnpack scenario
+          if (eval.value().isTuple()) {
+            auto eval_list = eval.value().toTuple();
+            TRTORCH_CHECK(
+                eval_list->elements().size() == n->outputs().size(),
+                "Size of evaluated results: " << eval_list->elements().size()
+                                              << " and node outputs size: " << n->outputs().size() << " must match.");
+            for (size_t i = 0; i < eval_list->elements().size(); i++) {
+              auto eval_output = eval_list->elements()[i];
+              LOG_DEBUG(
+                  ctx->logger,
+                  "Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info(n));
+              ctx->AssociateValueAndIValue(n->output(i), eval_output);
+            }
+          } else {
+            TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
+          }
+        } else if (!eval.value().isTensor()) {
           LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
+          ctx->AssociateValueAndIValue(n->output(0), eval.value());
         } else {
           LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
+          ctx->AssociateValueAndIValue(n->output(0), eval.value());
         }
-        ctx->AssociateValueAndIValue(n->output(0), eval.value());
       }
     } else if (!ignored) {
       // Should error out if something fails
diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index e3e105819d..3c36999a7c 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -1,11 +1,12 @@
+#include <ATen/ATen.h>
+#include <vector>
 #include "NvInfer.h"
+#include "c10/util/intrusive_ptr.h"
 #include "core/conversion/converters/converters.h"
+#include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/util/prelude.h"
 #include "torch/torch.h"
 
-#include <ATen/ATen.h>
-#include <vector>
-
 namespace trtorch {
 namespace core {
 namespace conversion {
@@ -13,6 +14,56 @@ namespace converters {
 namespace impl {
 namespace {
 
+void add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
+  auto in = args[0].ITensor();
+  auto axis = args[2].unwrapToInt();
+  auto inDimSize = in->getDimensions().d[axis];
+  auto numOutputs = 1;
+  std::vector<int64_t> sizes;
+
+  if (split_list) {
+    sizes = args[1].unwrapToIntList().vec();
+    numOutputs = sizes.size();
+  } else {
+    auto split_size = args[1].unwrapToInt();
+    numOutputs = inDimSize / split_size;
+    if (numOutputs == 1) {
+      sizes.push_back(split_size);
+    } else {
+      sizes = std::vector<int64_t>(numOutputs, split_size);
+    }
+  }
+
+  LOG_DEBUG("Number of split outputs: " << numOutputs);
+
+  c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+  c10::TypePtr elementType = lt->getElementType();
+  auto list = c10::impl::GenericList(elementType);
+  list.reserve(numOutputs);
+
+  int start_idx = 0;
+  for (int i = 0; i < numOutputs; i++) {
+    at::Tensor indices = torch::arange(start_idx, start_idx + sizes[i], 1).to(torch::kI32);
+    auto indicesTensor = tensor_to_const(ctx, indices);
+
+    auto gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
+    auto gather_out = gather_layer->getOutput(0);
+
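+    // An ITensor cannot be stored in a TorchScript list directly, so each
+    // slice is wrapped in a TensorContainer custom class and boxed as an
+    // IValue; MarkOutputs unwraps it if the slice becomes an engine output.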
+    auto tensor_holder = TensorContainer();
+    tensor_holder.hold_tensor(gather_out);
+    auto ival = c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
+    list.emplace_back(ival);
+
+    start_idx = start_idx + sizes[i];
+  }
+
+  auto split_output_ivalue = torch::jit::IValue(list);
+  ctx->AssociateValueAndIValue(n->outputs()[0], split_output_ivalue);
+}
+
 auto select_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
@@ -172,11 +223,29 @@ auto select_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
 
                     return true;
-                  }});
+                  }})
+        .pattern({"aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    add_split(ctx, n, args, true);
+                    LOG_DEBUG("Converted split op into a list of IValues");
+                    return true;
+                  }})
+        .pattern({"aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    add_split(ctx, n, args, false);
+                    LOG_DEBUG("Converted split op into a list of IValues");
+                    return true;
+                  }})
+        .pattern({"aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    add_split(ctx, n, args, true);
+                    LOG_DEBUG("Converted split op into a list of IValues");
+                    return true;
+                  }});
 
 } // namespace
 } // namespace impl
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp
index 7ef0e070c3..66cb884317 100644
--- a/core/conversion/evaluators/prim.cpp
+++ b/core/conversion/evaluators/prim.cpp
@@ -32,6 +32,13 @@ auto prim_registrations =
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       return at::scalar_to_tensor(args.at(n->output(0)).IValue()->toScalar());
                     }})
+        .evaluator({torch::jit::prim::ListUnpack,
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      // The evaluated input list lives in ctx->evaluated_value_map; return its elements as a tuple.
+                      const torch::jit::IValue* outputs = args.at(n->input()).IValue();
+                      auto outputVec = outputs->toList().vec();
+                      return c10::ivalue::Tuple::create(outputVec);
+                    }})
         .evaluator({torch::jit::prim::ListConstruct,
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       const auto num_inputs = n->inputs().size();
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index e90a35147c..c2c66aa8e6 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -204,4 +204,89 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
\ No newline at end of file
+}
+
+TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%x.1 : Tensor):
+          %2 : int[] = prim::Constant[value=[1, 2]]()
+          %3 : int = prim::Constant[value=1]()
+          %4 : Tensor[] = aten::split(%x.1, %2, %3)
+          %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4)
+          return (%x1.1, %x2.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});
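+
+  // Run the same input through the TorchScript interpreter and the generated
+  // TensorRT engine, then compare each unpacked output tensor.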
+  auto jit_in = at::clone(in);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
+
+TEST(Converters, ATenSplitSizesInTracingConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%argument_1.1 : Tensor):
+          %2 : int[] = prim::Constant[value=[1, 2]]()
+          %3 : int = prim::Constant[value=1]()
+          %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3)
+          %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4)
+          return (%5, %6))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
+
+TEST(Converters, ATenSplitFixedConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%argument_1.1 : Tensor):
+          %2 : int = prim::Constant[value=1]()
+          %3 : Tensor[] = aten::split(%argument_1.1, %2, %2)
+          %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
+          return (%4, %5, %6))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
diff --git a/tests/core/conversion/evaluators/test_prim_evaluators.cpp b/tests/core/conversion/evaluators/test_prim_evaluators.cpp
index f889a171e8..37aa38e281 100644
--- a/tests/core/conversion/evaluators/test_prim_evaluators.cpp
+++ b/tests/core/conversion/evaluators/test_prim_evaluators.cpp
@@ -17,4 +17,25 @@ TEST(Evaluators, PrimConstantEvaluatesCorrectly) {
   auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
 
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
-}
\ No newline at end of file
+}
+
+TEST(Evaluators, PrimListUnpackEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=3]()
+        %2 : int = prim::Constant[value=4]()
+        %lc : int[] = prim::ListConstruct(%1, %2)
+        %lu.1 : int, %lu.2 : int = prim::ListUnpack(%lc)
+        return (%lu.1, %lu.2))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
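+
+  // prim::ListUnpack is evaluated at conversion time, so both unpacked ints
+  // should match the TorchScript interpreter's results exactly.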
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+  ASSERT_TRUE(jit_results[1] == trt_results[1]);
+}
diff --git a/tests/util/evaluate_graph.cpp b/tests/util/evaluate_graph.cpp
index 997ae32c4d..65e4449b5d 100644
--- a/tests/util/evaluate_graph.cpp
+++ b/tests/util/evaluate_graph.cpp
@@ -5,6 +5,7 @@
 #include "core/conversion/converters/converters.h"
 #include "core/conversion/evaluators/evaluators.h"
 #include "core/conversion/var/Var.h"
+#include "core/util/jit_util.h"
 #include "core/util/prelude.h"
 
 namespace trtorch {
@@ -20,20 +21,31 @@ std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::v
   for (size_t i = 0; i < inputs.size(); i++) {
     ctx->AssociateValueAndIValue(b->inputs()[i], inputs[i]);
   }
-
+  LOG_DEBUG("Checking nodes");
   for (const auto n : b->nodes()) {
     TRTORCH_CHECK(
        core::conversion::evaluators::shouldEvalAtConversionTime(n),
        "Test graph contains non evaluatable nodes: " << *n);
     auto eval = core::conversion::EvaluateNode(ctx, n);
     if (eval) {
-      if (!eval.value().isTensor()) {
+      if (eval.value().isTuple()) {
+        auto eval_list = eval.value().toTuple();
+        for (size_t i = 0; i < eval_list->elements().size(); i++) {
+          auto eval_output = eval_list->elements()[i];
+          LOG_DEBUG(
+              ctx->logger,
+              "Found the evaluated value(s) to be " << eval_output
+                                                    << " for node: " << trtorch::core::util::node_info(n));
+          ctx->AssociateValueAndIValue(n->output(i), eval_output);
+        }
+      } else if (!eval.value().isTensor()) {
         LOG_DEBUG("Found the value to be: " << eval.value());
+        ctx->AssociateValueAndIValue(n->output(0), eval.value());
       } else {
         LOG_DEBUG("Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
+        ctx->AssociateValueAndIValue(n->output(0), eval.value());
       }
     }
-    ctx->AssociateValueAndIValue(n->output(0), eval.value());
   }
 
   std::vector<torch::jit::IValue> outputs;