diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index fb42a41dbf..9ad77aa056 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -174,23 +174,34 @@ void AddInputs(ConversionCtx* ctx, at::ArrayRef inputs #endif } +void MarkIValueOutputs(ConversionCtx* ctx, c10::IValue out_ivalue, const torch::jit::Value* out) { + if (out_ivalue.isCustomClass()) { + std::string name = std::string("output_") + std::to_string(ctx->num_outputs); + auto output_container = out_ivalue.toCustomClass(); + nvinfer1::ITensor* out_tensor = output_container.get()->tensor(); + out_tensor->setName(name.c_str()); + ctx->net->markOutput(*out_tensor); + LOG_INFO( + ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); + ctx->num_outputs += 1; + } else { + TRTORCH_THROW_ERROR("Unsupported output type, only Tensors or unwrapped collections of Tensors can be marked as engine outputs but found type: " << out_ivalue.tagKind()); + } +} + void MarkOutputs(ConversionCtx* ctx, at::ArrayRef outputs) { for (auto out : outputs) { auto it = ctx->value_tensor_map.find(out); if (it == ctx->value_tensor_map.end()) { if (ctx->evaluated_value_map.find(out) != ctx->evaluated_value_map.end()) { auto out_ivalue = ctx->evaluated_value_map[out]; - if (out_ivalue.isCustomClass()) { - std::string name = std::string("output_") + std::to_string(ctx->num_outputs); - auto output_container = out_ivalue.toCustomClass(); - nvinfer1::ITensor* out_tensor = output_container.get()->tensor(); - out_tensor->setName(name.c_str()); - ctx->net->markOutput(*out_tensor); - LOG_INFO( - ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)"); - ctx->num_outputs += 1; + if (out_ivalue.isList()) { + c10::List value_list = out_ivalue.toList(); + for(auto it = value_list.begin(); it != value_list.end(); it++) { + MarkIValueOutputs(ctx, *it, out); + } } else { - TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported."); + MarkIValueOutputs(ctx, out_ivalue, out); } } } else {