diff --git a/core/compiler.cpp b/core/compiler.cpp index fc1cc66aee..7b58dbb2c1 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph( // update the input ranges for each segments convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params); + // TODO mapping Inputs Ivalue to flatten one here auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params); auto temp_g = std::make_shared(); auto device_spec = convert_cfg.engine_settings.device; @@ -306,40 +307,62 @@ void MapInputsAndDetermineDTypes( CompileSpec& cfg, std::shared_ptr& g, ir::StaticParams& static_params, - ir::TypeMap& first_use_type_map) { - // Associate input specs with inputs - cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params)); - - for (auto& in : g->inputs()) { - if (static_params.find(in) == static_params.end()) { - ir::Input& spec = cfg.convert_info.inputs.find(in)->second; - auto est_type_opt = first_use_type_map.find(in)->second; - if (est_type_opt && !spec.dtype_is_user_defined) { + ir::CollectionTypeMap& first_use_type_map) { + cfg.convert_info.collection_input_spec_map = + std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params)); + + auto collection_inputs = ir::get_collection_inputs(g, static_params); + LOG_DEBUG( + "In MapInputsAndDetermineDTypes, the g->inputs() size is " + << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size()); + + for (auto in : collection_inputs) { + std::vector& spec = cfg.convert_info.collection_input_spec_map.find(in)->second; + std::vector> est_type_opt; + + auto est_it = first_use_type_map.find(in); + if (est_it != first_use_type_map.end()) { + est_type_opt = first_use_type_map.find(in)->second; + } + // traverse elements in est_type_out and spec + for (size_t i = 0; i < est_type_opt.size(); i++) { + if (est_type_opt[i] && !spec[i].dtype_is_user_defined) { // If we can calculate the type from the graph and the type was not defined by the user then use the calculated // type LOG_INFO( - "Since input type is not explicitly defined, infering using first tensor calculation\n Found input " - << in->debugName() << " has type " << est_type_opt.value() - << ". If this is incorrect explicitly set dtype for input and file a bug"); - spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value()); - } else if (!est_type_opt && !spec.dtype_is_user_defined) { + "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input " + << in->debugName() << " has type " << est_type_opt[i].value()); + spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value()); + } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) { // If we cannot calculate the type and the user did not define the type, then default to FP32 LOG_WARNING( "Cannot infer input type from calcuations in graph for input " << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity"); - spec.dtype = nvinfer1::DataType::kFLOAT; - } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) { - if (!est_type_opt) { - LOG_INFO("Cannot infer input tensor dtype in graph. 
Using user provided input dtype settings"); - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; + spec[i].dtype = nvinfer1::DataType::kFLOAT; + } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) { + if (!est_type_opt[i]) { + LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting"); + std::stringstream ss; + ss << "For input " << in->debugName() << ", found user specified input dtype as "; + ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + ss << ". The compiler is going to use the user setting " + << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + auto warn_str = ss.str(); + LOG_WARNING(warn_str); + // Overwrite type map with user settings + first_use_type_map[in][i] = { + util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)}; + } else { - if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) { + if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != + est_type_opt[i].value()) { std::stringstream ss; ss << "For input " << in->debugName() << ", found user specified input dtype as "; - ss << cfg.convert_info.inputs.find(in)->second.dtype; + ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; ss << ", however when inspecting the graph, the input type expected was inferred to be "; - ss << est_type_opt.value() << std::endl; - ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype; + ss << est_type_opt[i].value() << std::endl; + ss << "The compiler is going to use the user setting " + << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n"; ss << "compatibility with PyTorch's data type convention is required.\n"; ss << "If you do indeed see errors at runtime either:\n"; @@ -347,16 +370,17 @@ void MapInputsAndDetermineDTypes( ss << "- Disable partial compilation by setting require_full_compilation to True"; auto warn_str = ss.str(); LOG_WARNING(warn_str); + // Overwrite type map with user settings + first_use_type_map[in][i] = { + util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)}; } - // Overwrite type map with user settings - // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } } else { // The user defined the type so no changes are necessary } } } + // } } std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) { @@ -370,7 +394,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std:: auto params = graph_and_parameters.second; auto static_params = ir::get_static_params(g->inputs(), params); // Infer the type of an input from the weights of the calculation - auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block()); + auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block()); MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); @@ -395,10 +419,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) auto params = 
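// --- Editor's illustrative sketch (not part of the patch, not the Torch-TensorRT API) ---
// A minimal, self-contained example of the per-element dtype resolution that
// MapInputsAndDetermineDTypes performs above for collection inputs: a user-specified
// dtype always wins, an inferred first-use dtype is used next, and Float32 is the
// fallback. ElemSpec and resolve_dtypes are hypothetical names used only here.
#include <iostream>
#include <optional>
#include <vector>

enum class DType { kFLOAT, kHALF, kINT32 };

struct ElemSpec {
  DType dtype = DType::kFLOAT;
  bool dtype_is_user_defined = false;
};

std::vector<DType> resolve_dtypes(std::vector<ElemSpec>& specs,
                                  const std::vector<std::optional<DType>>& inferred) {
  std::vector<DType> out;
  for (size_t i = 0; i < specs.size(); i++) {
    if (specs[i].dtype_is_user_defined) {
      out.push_back(specs[i].dtype);            // user setting always wins
    } else if (i < inferred.size() && inferred[i]) {
      specs[i].dtype = *inferred[i];            // adopt the dtype inferred from the graph
      out.push_back(specs[i].dtype);
    } else {
      specs[i].dtype = DType::kFLOAT;           // nothing known: default to FP32
      out.push_back(specs[i].dtype);
    }
  }
  return out;
}

int main() {
  std::vector<ElemSpec> specs{{DType::kHALF, true}, {}};
  std::vector<std::optional<DType>> inferred{std::nullopt, DType::kINT32};
  for (auto t : resolve_dtypes(specs, inferred)) std::cout << static_cast<int>(t) << "\n";
}
// --- end sketch ---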
graph_and_parameters.second; auto static_params = ir::get_static_params(g->inputs(), params); // Infer the type of an input from the weights of the calculation - auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block()); + auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block()); MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true); + auto outputIsCollection = conversion::OutputIsCollection(g->block()); if (cfg.partition_info.enabled && (cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) { @@ -406,12 +431,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) } if (cfg.partition_info.enabled && - !(cfg.lower_info.forced_fallback_modules.size() == 0 && - cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) { - auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types); + (!(cfg.lower_info.forced_fallback_modules.size() == 0 && + cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) || + outputIsCollection)) { std::unordered_map fallback_nodes; - auto graph_and_mapping = - ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes); + auto collection_input_ivalues_map = + partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types); + auto graph_and_mapping = ConstructFallbackGraph( + new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes); new_g = graph_and_mapping.first; // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly for (size_t i = 0; i < new_g->inputs().size(); ++i) { @@ -429,6 +456,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) TORCHTRT_CHECK( conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler"); + // TODO find the right auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params); AddEngineToGraph(new_mod, new_g, engine, cuda_device); } diff --git a/core/compiler.h b/core/compiler.h index c1bb85aa3b..c8dc85020b 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -8,13 +8,15 @@ #include "core/partitioning/partitioning.h" #include "core/runtime/runtime.h" #include "torch/csrc/jit/api/module.h" +#include "torch/csrc/jit/ir/ir.h" namespace torch_tensorrt { namespace core { struct CompileSpec { - CompileSpec(std::vector inputs) : inputs(inputs) {} - std::vector inputs; + CompileSpec(std::vector inputs) : graph_inputs(inputs) {} + CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {} + ir::GraphInputs graph_inputs; conversion::ConversionInfo convert_info; lowering::LowerInfo lower_info; partitioning::PartitionInfo partition_info; diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 3211e7dd98..5f4b20e1b3 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -135,10 +135,11 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { << "please report this error to https://www.github.com/NVIDIA/Torch-TensorRT/issues"); } -void AddInputs( - ConversionCtx* ctx, - c10::ArrayRef inputs, - std::unordered_map& input_specs) { +void AddInputs(ConversionCtx* ctx, c10::ArrayRef 
inputs, ConversionInfo& conversion_info) { + std::unordered_map& input_specs = conversion_info.inputs; + std::unordered_map> collection_input_spec = + conversion_info.collection_input_spec_map; + std::vector input_tensors; for (auto in : inputs) { // Disregarding inputs that are not tensors @@ -166,9 +167,15 @@ void AddInputs( for (auto input : input_tensors) { const torch::jit::Value* in = input; TORCHTRT_CHECK( - input_specs.find(in) != input_specs.end(), + input_specs.find(in) != input_specs.end() || collection_input_spec.find(in) != collection_input_spec.end(), "Cannot find an input spec associated with input: " << in->debugName()); - ir::Input& spec = input_specs.find(in)->second; + ir::Input spec; + if (input_specs.find(in) != input_specs.end()) { + spec = input_specs.find(in)->second; + } else { + spec = collection_input_spec.find(in)->second[0]; // assume input is tensor + } + // ir::Input& spec = input_specs.find(in)->second; std::string name = std::string("input_") + std::to_string(ctx->num_inputs); LOG_INFO( @@ -408,7 +415,7 @@ void ConvertBlockToNetDef( auto inputs = b->inputs(); AddParamsToCtxValueMap(ctx, static_params); - AddInputs(ctx, inputs, build_info.inputs); + AddInputs(ctx, inputs, build_info); auto nodes = b->nodes(); @@ -549,6 +556,16 @@ std::set ConvertableOpsInBlock(const torch::jit::Block* b) { return convertable_ops; } +bool OutputIsCollection(const torch::jit::Block* b) { + for (auto out : b->outputs()) { + if (out->type()->kind() == torch::jit::TypeKind::TupleType || + out->type()->kind() == torch::jit::TypeKind::ListType) { + return true; + } + } + return false; +} + bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) { auto unsupported_ops = GetUnsupportedOpsInBlock(b); if (unsupported_ops.size() != 0) { diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h index 58c06b42a3..a578c4288e 100644 --- a/core/conversion/conversion.h +++ b/core/conversion/conversion.h @@ -13,6 +13,7 @@ namespace conversion { struct ConversionInfo { ir::InputSpecMap inputs; + ir::CollectionInputSpecMap collection_input_spec_map; BuilderSettings engine_settings; }; @@ -25,6 +26,8 @@ std::string ConvertBlockToEngine( bool OpSupported(const torch::jit::Node* n); +bool OutputIsCollection(const torch::jit::Block* b); + bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false); c10::optional EvaluateNode( diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 745261589e..94ac827ef4 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -65,6 +65,13 @@ nvinfer1::ILayer* add_elementwise( nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name) { + if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) { + LOG_DEBUG("Type mismatch, casting other to " << self->getType()); + other = castITensor(ctx, other, self->getType()); + } else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) { + LOG_DEBUG("Type mismatch, casting self to " << other->getType()); + self = castITensor(ctx, self, other->getType()); + } // ensure self to have larger number of dimension bool swapSelfOther = false; if (self->getDimensions().nbDims < other->getDimensions().nbDims) { diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp index 
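// --- Editor's illustrative sketch (not part of the patch) ---
// Hedged sketch of the implicit type promotion added to add_elementwise above: when
// one operand is float and the other int32, the int32 side is cast up to float before
// the elementwise layer is built. Plain enums stand in for nvinfer1::DataType and
// promote_pair is a hypothetical helper name.
#include <iostream>
#include <utility>

enum class DType { kFLOAT, kINT32 };

std::pair<DType, DType> promote_pair(DType self, DType other) {
  if (self == DType::kFLOAT && other == DType::kINT32) {
    other = self;   // cast "other" up to float
  } else if (self == DType::kINT32 && other == DType::kFLOAT) {
    self = other;   // cast "self" up to float
  }
  return {self, other};
}

int main() {
  auto [a, b] = promote_pair(DType::kINT32, DType::kFLOAT);
  std::cout << (a == DType::kFLOAT && b == DType::kFLOAT) << "\n";  // prints 1
}
// --- end sketch ---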
2f0c3a9d13..da9d58ef43 100644 --- a/core/conversion/converters/impl/element_wise.cpp +++ b/core/conversion/converters/impl/element_wise.cpp @@ -412,6 +412,7 @@ auto element_wise_registrations TORCHTRT_UNUSED = // Should implement self * other auto self = args[0].ITensorOrFreeze(ctx); auto other = args[1].ITensorOrFreeze(ctx); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n); @@ -426,6 +427,7 @@ auto element_wise_registrations TORCHTRT_UNUSED = // TODO: Remove with functionalization auto self = args[0].ITensorOrFreeze(ctx); auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar()); + auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n)); TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n); diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 4d8795f378..ca9ff4d488 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -19,14 +19,6 @@ namespace conversion { namespace evaluators { namespace { -int64_t normalizeIndex(int64_t idx, int64_t list_size) { - if (idx < 0) { - // Handle negative indexing - idx = list_size + idx; - } - return idx; -} - DEFINE_GENERIC_TWO_INPUT_EVALUATOR( eq, "aten::eq", diff --git a/core/conversion/evaluators/eval_util.cpp b/core/conversion/evaluators/eval_util.cpp index 79b377cd37..742a4f4938 100644 --- a/core/conversion/evaluators/eval_util.cpp +++ b/core/conversion/evaluators/eval_util.cpp @@ -12,6 +12,15 @@ namespace core { namespace conversion { namespace evaluators { +int64_t normalizeIndex(int64_t idx, int64_t list_size) { + if (idx < 0) { + // Handle negative indexing + idx = list_size + idx; + } + return idx; +} + + // TODO: Switch back to PyTorch canonical implimentation c10::optional toIValue(const torch::jit::Value* v) { if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast()) { diff --git a/core/conversion/evaluators/eval_util.h b/core/conversion/evaluators/eval_util.h index 5e233b4e2d..a9c21339bb 100644 --- a/core/conversion/evaluators/eval_util.h +++ b/core/conversion/evaluators/eval_util.h @@ -13,6 +13,8 @@ at::Tensor createTensorFromList( const torch::jit::IValue& dtype, const torch::jit::IValue& device); +int64_t normalizeIndex(int64_t idx, int64_t list_size); + at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU); } // namespace evaluators diff --git a/core/conversion/evaluators/prim.cpp b/core/conversion/evaluators/prim.cpp index 7d5373a5f9..338c427ccd 100644 --- a/core/conversion/evaluators/prim.cpp +++ b/core/conversion/evaluators/prim.cpp @@ -259,6 +259,56 @@ auto prim_registrations = } }, EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})}) + .evaluator({torch::jit::prim::TupleConstruct, + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + auto num_inputs = n->inputs().size(); + c10::IValue tuple = c10::ivalue::Tuple::create(); + switch (num_inputs) { + case 0: + tuple = c10::ivalue::Tuple::create(); + break; + case 1: + tuple = c10::ivalue::Tuple::create(std::move((*args.at(n->input(0)).IValue()))); + break; + case 2: { + tuple = c10::ivalue::Tuple::create( + std::move(*(args.at(n->input(0)).IValue())), + std::move(*(args.at(n->input(1)).IValue()))); + break; + } + case 3: { + tuple = c10::ivalue::Tuple::create( + std::move(*(args.at(n->input(0)).IValue())), + 
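// --- Editor's illustrative sketch (not part of the patch) ---
// Minimal example of the negative-index normalization shared by the new
// prim::TupleIndex evaluator: indices below zero wrap around to the end of the
// tuple, mirroring the normalizeIndex helper moved into eval_util above.
// tuple_at is a hypothetical wrapper used only for this demonstration.
#include <cassert>
#include <cstdint>
#include <vector>

int64_t normalize_index(int64_t idx, int64_t list_size) {
  return idx < 0 ? list_size + idx : idx;   // -1 refers to the last element
}

int tuple_at(const std::vector<int>& elems, int64_t idx) {
  return elems[normalize_index(idx, static_cast<int64_t>(elems.size()))];
}

int main() {
  std::vector<int> t{10, 20, 30};
  assert(tuple_at(t, -1) == 30);
  assert(tuple_at(t, 0) == 10);
  return 0;
}
// --- end sketch ---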
std::move(*(args.at(n->input(1)).IValue())), + std::move(*(args.at(n->input(2)).IValue()))); + break; + } + default: { + std::vector elems; + for (size_t i = 0; i < num_inputs; i++) { + elems.push_back(*(args.at(n->input(i)).IValue())); + } + tuple = c10::ivalue::Tuple::create(std::move(elems)); + break; + } + } + return c10::optional(std::move(tuple)); + }}) + .evaluator({torch::jit::prim::TupleIndex, + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map + auto tuple = args.at(n->input(0)).IValue()->toTuple(); + int64_t idx = args.at(n->input(1)).IValue()->toInt(); + int64_t norm_idx = normalizeIndex(idx, tuple->elements().size()); + return c10::optional(std::move(tuple->elements()[norm_idx])); + }, + EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})}) + .evaluator({torch::jit::prim::TupleUnpack, + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map + auto output = args.at(n->input()).IValue()->toTuple(); + return c10::optional(std::move(output)); + }}) .evaluator({c10::Symbol::fromQualString("prim::unchecked_cast"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { return *(args.at(n->input(0)).IValue()); @@ -277,4 +327,4 @@ auto prim_registrations = } // namespace evaluators } // namespace conversion } // namespace core -} // namespace torch_tensorrt +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/ir/BUILD b/core/ir/BUILD index a613aaf489..2e9ef7e6a8 100644 --- a/core/ir/BUILD +++ b/core/ir/BUILD @@ -15,7 +15,8 @@ cc_library( srcs = [ "ir.cpp", "Input.cpp", - "StaticParams.cpp" + "StaticParams.cpp", + "GraphInputs.cpp" ], deps = [ "@tensorrt//:nvinfer", diff --git a/core/ir/GraphInputs.cpp b/core/ir/GraphInputs.cpp new file mode 100644 index 0000000000..f3fa889385 --- /dev/null +++ b/core/ir/GraphInputs.cpp @@ -0,0 +1,77 @@ +#include "core/ir/ir.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace ir { + +void flatten_dfs( + std::vector& flattened_inputs, + std::vector>& collection_inputs, + torch::jit::IValue input_ivalue, + int level, + int index) { + if (input_ivalue.isTuple()) { + auto input_tuple = input_ivalue.toTuple(); + int idx = 0; + if (level == 0) { + collection_inputs.resize(input_tuple->elements().size()); + } + for (auto item : input_tuple->elements()) { + torch::jit::IValue converted_item; + int cur_idx = level < 1 ? idx : index; + flatten_dfs(flattened_inputs, collection_inputs, item, level + 1, cur_idx); + idx++; + } + } else if (input_ivalue.isList()) { + auto input_list = input_ivalue.toList().vec(); + if (level == 0) { + collection_inputs.resize(input_list.size()); + } + c10::TypePtr type = input_list[0].type(); + auto converted_elements = c10::impl::GenericList(type); + int idx = 0; + for (auto item : input_list) { + int cur_idx = level < 1 ? 
idx : index; + flatten_dfs(flattened_inputs, collection_inputs, item, level + 1, cur_idx); + idx++; + } + } else if (input_ivalue.isCustomClass()) { + torch_tensorrt::core::ir::Input cur_input = *(input_ivalue.toCustomClass()); + flattened_inputs.push_back(cur_input); + if (level == 0) { // a single value like A + collection_inputs.resize(1); + collection_inputs[0].push_back(cur_input); + } else if (level == 1) { // like A in [A, A] or [(B, B), A] + collection_inputs[index].push_back(cur_input); + } else if (level == 2) { // like A in [(A, A), C] + collection_inputs[index].push_back(cur_input); + } else { // only support 2 level + LOG_ERROR( + "Input nesting depth exceeds currently supported depth (3), use 1 level: [A, B], or 2 level: [A, (B, C)]"); + } + } +} + +GraphInputs::GraphInputs(std::vector inputs_) { + inputs = inputs_; + collection_inputs.resize(inputs_.size()); + for (size_t i = 0; i < inputs_.size(); i++) { + collection_inputs[i].push_back(inputs_[i]); + } +} + +GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) { + std::vector flattened_inputs; + std::vector> collection_inputs_; + + flatten_dfs(flattened_inputs, collection_inputs_, input_signature_, 0, 0); + inputs = flattened_inputs; + input_signature = input_signature_; + collection_inputs = collection_inputs_; + LOG_DEBUG("Collection Input Size: " << collection_inputs_.size()); +} + +} // namespace ir +} // namespace core +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/ir/StaticParams.cpp b/core/ir/StaticParams.cpp index ac16c72d9f..8502c80acf 100644 --- a/core/ir/StaticParams.cpp +++ b/core/ir/StaticParams.cpp @@ -11,7 +11,9 @@ StaticParams get_static_params(c10::ArrayRef inputs, std::ve StaticParams static_params; auto param_it = params.begin(); for (auto in : inputs) { - if (in->type() != c10::TensorType::get() && param_it != params.end()) { + // handle TensorType, TupleType and ListType + if (in->type() != c10::TensorType::get() && in->type()->kind() != torch::jit::TypeKind::TupleType && + in->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.end()) { static_params[in] = *param_it; ++param_it; } diff --git a/core/ir/ir.cpp b/core/ir/ir.cpp index fcca3df33c..99bf4f97b1 100644 --- a/core/ir/ir.cpp +++ b/core/ir/ir.cpp @@ -13,6 +13,14 @@ InputSpecMap associate_specs_with_inputs( return pair_input_vals_with_specs(tensor_inputs, specs); } +CollectionInputSpecMap associate_specs_with_collection_inputs( + std::shared_ptr& g, + ir::GraphInputs graph_inputs, + StaticParams& static_params) { + auto tensor_inputs = get_collection_inputs(g, static_params); + return pair_input_vals_with_specs_collection(tensor_inputs, graph_inputs.collection_inputs); +} + InputSpecMap pair_input_vals_with_specs(std::vector vals, std::vector specs) { TORCHTRT_CHECK( vals.size() == specs.size(), @@ -21,7 +29,23 @@ InputSpecMap pair_input_vals_with_specs(std::vector va std::unordered_map a; for (size_t i = 0; i < vals.size(); i++) { - LOG_DEBUG("Pairing " << i << ": " << vals[i]->debugName() << " : " << specs[i]); + LOG_DEBUG("Pairing " << i << ": " << vals[i]->debugName() << ": " << specs[i]); + a.insert({vals[i], specs[i]}); + } + return a; +} + +CollectionInputSpecMap pair_input_vals_with_specs_collection( + std::vector vals, + std::vector>& specs) { + TORCHTRT_CHECK( + vals.size() == specs.size(), + "Expected dimension specifications for all input tensors" + << ", but found " << vals.size() << " input tensors and " << specs.size() << " dimension specs"); + + CollectionInputSpecMap a; + 
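// --- Editor's illustrative sketch (not part of the patch) ---
// Standalone sketch, with toy types instead of torch::jit::IValue, of the flatten_dfs
// strategy used by GraphInputs above: the nested input signature is walked depth-first,
// every leaf spec is appended to a flat list, and leaves are also grouped by their
// top-level position, which is why only two levels of nesting are representable.
#include <iostream>
#include <vector>

struct Node {                        // leaf when children is empty
  int leaf = -1;
  std::vector<Node> children;
};

void flatten_dfs(const Node& n, int level, int index,
                 std::vector<int>& flat,
                 std::vector<std::vector<int>>& grouped) {
  if (n.children.empty()) {
    flat.push_back(n.leaf);
    if (level == 0) { grouped.resize(1); grouped[0].push_back(n.leaf); }
    else            { grouped[index].push_back(n.leaf); }   // level 1 or 2
    return;
  }
  if (level == 0) grouped.resize(n.children.size());        // one slot per top-level entry
  for (size_t i = 0; i < n.children.size(); i++) {
    int cur = level < 1 ? static_cast<int>(i) : index;       // freeze the slot below level 1
    flatten_dfs(n.children[i], level + 1, cur, flat, grouped);
  }
}

int main() {
  Node sig{-1, {Node{0}, Node{-1, {Node{1}, Node{2}}}}};     // models (A, (B, C))
  std::vector<int> flat;
  std::vector<std::vector<int>> grouped;
  flatten_dfs(sig, 0, 0, flat, grouped);
  std::cout << flat.size() << " leaves, " << grouped.size() << " top-level slots\n";  // 3, 2
}
// --- end sketch ---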
for (size_t i = 0; i < vals.size(); i++) { + LOG_DEBUG("Paring " << i << ": " << vals[i]->debugName() << " : " << specs[i]); a.insert({vals[i], specs[i]}); } return a; @@ -32,7 +56,9 @@ std::vector get_tensor_inputs( StaticParams& static_params) { std::vector input_tensors; auto inputs = g->inputs(); + LOG_DEBUG("Found " << inputs.size() << " inputs to graph"); for (auto in : inputs) { + LOG_DEBUG("Handle input of debug name: " << in->debugName()); // Disregarding inputs that are not tensors or are static // // Ex. @@ -45,6 +71,30 @@ std::vector get_tensor_inputs( return input_tensors; } +std::vector get_collection_inputs( + std::shared_ptr& g, + StaticParams& static_params) { + std::vector input_tensors; + auto inputs = g->inputs(); + LOG_DEBUG("Found " << inputs.size() << " inputs to graph"); + for (auto in : inputs) { + LOG_DEBUG("Handle input of debug name: " << in->debugName()); + if (in->type()->isSubtypeOf(c10::TensorType::get()) && static_params.find(in) == static_params.end()) { + input_tensors.push_back(in); + } else if (in->type()->kind() == torch::jit::TypeKind::TupleType && static_params.find(in) == static_params.end()) { + // } else if (in->type()->isSubtypeOf(c10::TupleType::create()) && static_params.find(in) == static_params.end()) + // { + input_tensors.push_back(in); // push original tuple + at::ArrayRef unpack_tuple = torch::jit::createTupleUnpack(in); + LOG_DEBUG("Input tuple size " << unpack_tuple.size()); + } else if (in->type()->kind() == torch::jit::TypeKind::ListType && static_params.find(in) == static_params.end()) { + LOG_DEBUG("Input list use size " << in->uses().size()); + input_tensors.push_back(in); // push original list + } + } + return input_tensors; +} + c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in) { TORCHTRT_ASSERT(in->owningGraph() == b->owningGraph(), "Provided input is not part of the provided graph"); c10::optional dtype = {}; @@ -52,9 +102,6 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* auto b_ins = b->inputs(); std::unordered_set b_in_set(b_ins.begin(), b_ins.end()); - TORCHTRT_ASSERT( - in->type() == c10::TensorType::get(), "Input is not a tensor, cannot check for dtype based on calculation"); - auto consumers = in->uses(); auto search_list = std::vector(consumers.begin(), consumers.end()); @@ -142,16 +189,56 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) { TypeMap types; - for (auto i : b->inputs()) { if (i->type() == c10::TensorType::get()) { torch::jit::Value* in = i; types.insert({in, get_value_first_calc_dtype_opt(b, i)}); + } else if (i->type()->cast()) { + // make sure very time get the same ptr + at::ArrayRef unpack_tuple = torch::jit::createTupleUnpack(i); + LOG_DEBUG("Tuple size " << unpack_tuple.size()); + for (auto item : unpack_tuple) { + torch::jit::Value* in = item; + types.insert({in, get_value_first_calc_dtype_opt(b, i)}); + } + } else if (i->type()->isSubtypeOf(c10::ListType::ofTensors())) { + LOG_INFO("Unsupported type of c10::ListType::ofTensors()"); } } return types; } +CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block* b) { + CollectionTypeMap types; + for (auto i : b->inputs()) { + if (i->type() == c10::TensorType::get()) { + torch::jit::Value* in = i; + types.insert({in, {get_value_first_calc_dtype_opt(b, i)}}); + + } else if (i->type()->kind() == torch::jit::TypeKind::TupleType) { + // TODO: to evaluate the data type of tuple element + // make sure 
very time get the same ptr + // c10::optional tp = get_value_first_calc_dtype_opt(b, i); + at::ArrayRef unpack_tuple = torch::jit::createTupleUnpack(i); + // TODO: calculate the tuple element type, currently we use {} as default datatype + // std::vector> dytpes(unpack_tuple.size(), tp); + std::vector> dytpes(unpack_tuple.size()); + types.insert({i, dytpes}); // insert an empty + + } else if (i->type()->kind() == torch::jit::TypeKind::ListType) { + // TODO: to decide the size of list and type of list element + LOG_DEBUG("Number of list uses " << i->uses().size()); + c10::optional tp = get_value_first_calc_dtype_opt(b, i); + // std::vector> dytpes(i->uses().size()); + std::vector> dytpes(i->uses().size(), tp); + types.insert({i, dytpes}); // insert an empty + } + } + return types; +} + +static auto core_input_container = torch::class_("_torch_tensorrt_core_ir", "Input").def(torch::init<>()); + } // namespace ir } // namespace core } // namespace torch_tensorrt diff --git a/core/ir/ir.h b/core/ir/ir.h index 2d9acccc69..a5225daa25 100644 --- a/core/ir/ir.h +++ b/core/ir/ir.h @@ -11,9 +11,8 @@ namespace torch_tensorrt { namespace core { namespace ir { -struct Input { - // Input(std::vector shape); - // Input(std::vector min_shape, std::vector opt_shape, std::vector max_shape); +struct Input : torch::CustomClassHolder { + Input(){}; Input( std::vector shape, nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT, @@ -36,27 +35,54 @@ struct Input { nvinfer1::Dims opt; nvinfer1::DataType dtype; nvinfer1::TensorFormat format; + int id; }; +// Add to spec +struct GraphInputs { + GraphInputs(std::vector inputs); + GraphInputs(torch::jit::IValue& input_signature); + torch::jit::IValue input_signature; // nested Input, full input spec + std::vector inputs; // flattend Input + std::vector> collection_inputs; // only support two layer nesting, e.g. 
((a, b), [c, d], e) +}; + +typedef std::pair GraphIO; // Graph input output mapping + using StaticParams = std::map; StaticParams get_static_params(c10::ArrayRef inputs, std::vector params); using InputSpecMap = std::unordered_map; +using CollectionInputSpecMap = std::unordered_map>; +std::vector get_tensor_inputs( + std::shared_ptr& g, + StaticParams& static_params); InputSpecMap associate_specs_with_inputs( std::shared_ptr& g, std::vector specs, StaticParams& static_params); +CollectionInputSpecMap associate_specs_with_collection_inputs( + std::shared_ptr& g, + ir::GraphInputs graph_inputs, + StaticParams& static_params); InputSpecMap pair_input_vals_with_specs(std::vector vals, std::vector specs); +CollectionInputSpecMap pair_input_vals_with_specs_collection( + std::vector vals, + std::vector>& specs); std::vector get_tensor_inputs( std::shared_ptr& g, StaticParams& static_params); +std::vector get_collection_inputs( + std::shared_ptr& g, + StaticParams& static_params); using TypeMap = std::unordered_map>; +using CollectionTypeMap = std::unordered_map>>; c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in); ir::TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b); - +ir::CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block* b); } // namespace ir } // namespace core } // namespace torch_tensorrt diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index d3296c347c..8bbae296c3 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -33,7 +33,6 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { torch::jit::InlineFunctionalGraphs(g); torch::jit::PeepholeOptimize(g, false); torch::jit::FuseLinear(g); - torch::jit::LowerAllTuples(g); if (!lower_info.disable_cse) { torch::jit::EliminateCommonSubexpression(g); } diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 8d54b51089..28bfd0712c 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -17,22 +17,13 @@ struct usage_info { std::vector tensorrt_use_id; // ids of segmented blocks which are of type TensorRT }; -inline bool isTensorOrTensorList(torch::jit::Value* val) { - return val->type()->isSubtypeOf(torch::jit::TensorType::get()) || - val->type()->isSubtypeOf(torch::jit::ListType::ofTensors()); -} - -inline bool isTensorList(torch::jit::Value* val) { - return val->type()->isSubtypeOf(torch::jit::ListType::ofTensors()); -} - inline bool isTensor(torch::jit::Value* val) { return val->type()->isSubtypeOf(torch::jit::TensorType::get()); } bool containNonTensorOutputs(torch::jit::Node* n) { for (auto output : n->outputs()) { - if (!isTensorOrTensorList(output)) { + if (!isTensor(output)) { return true; } } @@ -68,6 +59,7 @@ std::vector findModifyingNodes( return modifying_nodes; } +// this function is only used when a TRT segment produces nonTensor values which are used by later TRT segment std::vector getDependencyNodes( const std::vector& vals, const SegmentedBlock& seg_block) { @@ -88,7 +80,7 @@ std::vector getDependencyNodes( stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend()); stk.push_back(node); for (auto input : node->inputs()) { - if (!isTensorOrTensorList(input)) { + if (!isTensor(input)) { q.push(input); } } @@ -98,6 +90,28 @@ std::vector getDependencyNodes( return stk; } +// check if the input and output of the graph is Tensor after collection is enabled. 
If it is, then fallback related +// nodes +void fallback_graph_nontensor_in_out( + torch::jit::Block* block, + std::unordered_map& global_fallback_nodes) { + // fallback nodes that produce entire graph's nonTensor output + for (auto i : block->outputs()) { + if (!isTensor(i)) { + global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR}); + } + } + + // fallback nodes that consume entire graph's nonTensor input + for (auto i : block->inputs()) { + if (!isTensor(i)) { + for (auto use : i->uses()) { + global_fallback_nodes.insert({use.user, FallbackNodeType::kNON_TENSOR}); + } + } + } +} + void find_all_fallback_nodes( std::unordered_map& initial_fallback_nodes, std::unordered_map& global_fallback_nodes) { @@ -177,7 +191,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) == seg_block.raw_inputs().end() && seg_block.contain_raw_value(mini_graph_input)) { - if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT) + if (!isTensor(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT) continue; seg_block.registerOutput(mini_graph_input); } @@ -200,6 +214,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo } } } + std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) { torch::jit::EliminateDeadCode(seg_block.g()); }); @@ -361,7 +376,6 @@ PartitionedGraph segment_graph( find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size); auto nodes = block->nodes(); - PartitionedGraph segmented_blocks; // segment the nodes @@ -371,7 +385,7 @@ PartitionedGraph segment_graph( if (n->kind() == torch::jit::prim::Constant) { continue; } - + // the outputs of trt subgraph shouldn't be collections if (check_node_fallback(n, global_fallback_nodes)) { in_prog_trt_blk_nodes.push_back(n); @@ -430,7 +444,6 @@ PartitionedGraph segment_graph( in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } - return segmented_blocks; } @@ -440,6 +453,10 @@ PartitionedGraph Partition( const PartitionInfo& partition_info, std::unordered_map& global_fallback_nodes) { LOG_DEBUG(partition_info); + // if there is nonTensor input/output for the entire graph, fallback the node that consumes/produces this nonTensor + // output + fallback_graph_nontensor_in_out(block, global_fallback_nodes); + // segment lowering global graph into blocks LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks"); PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes); diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index 49bbab0b36..192fc10555 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -1,169 +1,197 @@ -#include "core/partitioning/shape_analysis.h" -#include -#include -#include "core/util/prelude.h" -#include "torch/csrc/jit/api/module.h" -#include "torch/csrc/jit/passes/constant_pooling.h" - -namespace torch_tensorrt { -namespace core { -namespace partitioning { - -std::unordered_map generateRandomInputs( - std::unordered_map& inputs, - std::unordered_map>& types) { - // generate random inputs for running pytorch segments - std::unordered_map ivalue_map; - - uint64_t in_i = 0; - for (auto& input : 
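// --- Editor's illustrative sketch (not part of the patch) ---
// Hedged sketch of the fallback_graph_nontensor_in_out rule introduced above: nodes
// that produce a non-Tensor graph output, or consume a non-Tensor graph input, are
// marked so partitioning keeps them running in Torch. The Value struct below is a toy
// stand-in for torch::jit::Value and its producer/user relationships.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Value {
  bool is_tensor;
  std::string producer;               // name of the node producing this value
  std::vector<std::string> users;     // names of nodes consuming this value
};

std::set<std::string> mark_nontensor_boundaries(const std::vector<Value>& graph_inputs,
                                                const std::vector<Value>& graph_outputs) {
  std::set<std::string> fallback;     // node names forced to stay in Torch
  for (const auto& out : graph_outputs)
    if (!out.is_tensor) fallback.insert(out.producer);
  for (const auto& in : graph_inputs)
    if (!in.is_tensor)
      for (const auto& user : in.users) fallback.insert(user);
  return fallback;
}

int main() {
  std::vector<Value> ins{{false, "", {"prim::TupleUnpack_0"}}};
  std::vector<Value> outs{{false, "prim::TupleConstruct_1", {}}};
  std::cout << mark_nontensor_boundaries(ins, outs).size() << "\n";  // 2 fallback nodes
}
// --- end sketch ---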
inputs) { - auto cur_shape = input.second.input_shape; - std::vector shape; - shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims); - auto type_opt = types[input.first]; - auto type = at::kFloat; - if (type_opt) { - type = type_opt.value(); - } else { - LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32"); - } - auto in = at::randint(5, shape, {at::kCUDA}).to(type); - ivalue_map[input.first] = in.clone(); - in_i++; - } - return ivalue_map; -} - -void getSegmentsOutputByRunning( - SegmentedBlock& seg_block, - std::unordered_map& ivalues_maps, - const PartitionInfo& partition_info) { - // create a module to run the graph - auto g = seg_block.g(); - auto copy_g = g->copy(); - - // create tuple for multiple outputs - if (seg_block.raw_outputs().size() > 1) { - auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs())); - for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) { - copy_g->eraseOutput(idx); - } - - copy_g->registerOutput(new_output_node->outputs()[0]); - } - - torch::jit::script::Module cur_mod(c10::QualifiedName("module")); - - auto self = copy_g->insertInput(0, "self_1"); - self->setType(cur_mod.type()); - - auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), copy_g); - auto schema = util::GenerateGraphSchema(cur_method->name(), copy_g); - cur_mod.type()->addMethod(cur_method); - cur_method->setSchema(schema); - - std::vector jit_inputs_ivalues; - - // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments - for (auto& input : seg_block.raw_inputs()) { - TORCHTRT_CHECK( - ivalues_maps.count(input), - "Could not find torch::jit::Value* " << input->debugName() << " produced from " - << util::node_info(input->node()) - << " in lowering graph for mini graph input.\n"); - if (input->node()->kind() == torch::jit::prim::Param) { - jit_inputs_ivalues.push_back(ivalues_maps[input]); - } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor()); - } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toInt()); - } else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toBool()); - } else if (input->type()->kind() == torch::jit::TypeKind::ListType) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toList()); - } else if (input->type()->kind() == torch::jit::TypeKind::TupleType) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple()); - } else if (input->type()->kind() == torch::jit::TypeKind::NumberType) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar()); - } else if (input->type()->kind() == torch::jit::TypeKind::DictType) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict()); - } else if (input->type()->kind() == torch::jit::TypeKind::DeviceObjType) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toDevice()); - } else { - TORCHTRT_THROW_ERROR( - "Expected to find type " << input->type()->str() << " for value " << input->debugName() - << " but get nothing. 
"); - } - } - - // run segments to get outputs for later segments input shape, and other arguments such as Int - std::vector jit_results; - torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues); - - if (jit_results_ivalues.isTuple()) { - auto results = jit_results_ivalues.toTuple()->elements(); - for (auto r : results) { - jit_results.push_back(r); - } - } else { - jit_results.push_back(jit_results_ivalues); - } - - size_t idx = 0; - for (auto& output : seg_block.raw_outputs()) { - ivalues_maps[output] = jit_results[idx++]; - } - - // set input shape for each segmented block so we wil use it in conversion process - std::vector input_shapes; - std::vector input_types; - for (auto& i : seg_block.raw_inputs()) { - if (ivalues_maps[i].isTensor()) { - // set the input_shape and data_type - // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for - // shape inference - auto cur_ivalue = ivalues_maps[i]; - at::ScalarType t = cur_ivalue.toTensor().scalar_type(); - if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) { - TORCHTRT_THROW_ERROR( - "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled"); - } else if (partition_info.truncate_long_and_double && t == at::kLong) { - cur_ivalue = cur_ivalue.toTensor().to(at::kInt); - LOG_WARNING("Truncating graph input type from at::kLong to at::kInt"); - } else if (partition_info.truncate_long_and_double && t == at::kDouble) { - cur_ivalue = cur_ivalue.toTensor().to(at::kFloat); - LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat"); - } - c10::optional dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype()); - if (dtype == c10::nullopt) { - TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype()); - } - if (cur_ivalue.toTensor().sizes().size() == 0) { - // handle Scalar types, which has sizes of [] - input_shapes.push_back(util::toVec(util::toDims(c10::List({1})))); - } else { - input_shapes.push_back(util::toVec(util::toDims(cur_ivalue.toTensor().sizes()))); - } - input_types.push_back(cur_ivalue.toTensor().scalar_type()); - } - } - - seg_block.register_inshapes(input_shapes); - seg_block.register_intypes(input_types); -} - -void runShapeAnalysis( - std::vector& segmented_blocks, - std::unordered_map& example_tensor_map, - const PartitionInfo& partition_info) { - // register every segment's input shape, and it's running output IValues - for (auto& seg_block : segmented_blocks) { - torch::jit::ConstantPooling(seg_block.g()); - getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info); - } - return; -} - -} // namespace partitioning -} // namespace core -} // namespace torch_tensorrt +#include "core/partitioning/shape_analysis.h" +#include +#include "core/util/prelude.h" +#include "torch/csrc/jit/api/module.h" +#include "torch/csrc/jit/passes/constant_pooling.h" + +namespace torch_tensorrt { +namespace core { +namespace partitioning { + +at::Tensor generateSingleInput(ir::Input& input, c10::optional& type_opt) { + auto cur_shape = input.input_shape; + std::vector shape; + shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims); + // auto type_opt = types[input.first][i]; + auto type = at::kFloat; + if (type_opt) { + type = type_opt.value(); + } else { + LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32"); + } + auto 
in = at::randint(5, shape, {at::kCUDA}).to(type); + // ivalue_map[input.first] = in.clone(); + return in; +} + +std::unordered_map generateRandomInputs( + std::unordered_map>& inputs, + std::unordered_map>>& types) { + // generate random inputs for running pytorch segments + std::unordered_map ivalue_map; + + for (auto& input : inputs) { + if (input.first->type()->kind() == torch::jit::TypeKind::ListType) { + // create list + std::vector list; + c10::TypePtr elementType = c10::TensorType::get(); + auto generic_list = c10::impl::GenericList(elementType); + for (size_t i = 0; i < input.second.size(); i++) { + auto in = generateSingleInput(input.second[i], types[input.first][i]); + generic_list.push_back(in.clone()); + } + ivalue_map[input.first] = c10::IValue(generic_list); + } else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) { + // create tuple + std::vector list; + for (size_t i = 0; i < input.second.size(); i++) { + auto in = generateSingleInput(input.second[i], types[input.first][i]); + list.push_back(in.clone()); + } + auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr + ivalue_map[input.first] = c10::IValue(tuple); + } else { + auto in = generateSingleInput(input.second[0], types[input.first][0]); + ivalue_map[input.first] = in.clone(); + } + } + return ivalue_map; +} + +void getSegmentsOutputByRunning( + SegmentedBlock& seg_block, + std::unordered_map& ivalues_maps, + const PartitionInfo& partition_info) { + // create a module to run the graph + auto g = seg_block.g(); + auto copy_g = g->copy(); + + // create tuple for multiple outputs + if (seg_block.raw_outputs().size() > 1) { + auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs())); + for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) { + copy_g->eraseOutput(idx); + } + + copy_g->registerOutput(new_output_node->outputs()[0]); + } + + torch::jit::script::Module cur_mod(c10::QualifiedName("module")); + + auto self = copy_g->insertInput(0, "self_1"); + self->setType(cur_mod.type()); + + auto cur_method = cur_mod._ivalue()->compilation_unit()->create_function(c10::QualifiedName("forward"), copy_g); + auto schema = util::GenerateGraphSchema(cur_method->name(), copy_g); + cur_mod.type()->addMethod(cur_method); + cur_method->setSchema(schema); + + std::vector jit_inputs_ivalues; + + // set inputs ivalues, now supports Tensor/Int to pass argumentes between different segments + for (auto& input : seg_block.raw_inputs()) { + TORCHTRT_CHECK( + ivalues_maps.count(input), + "Could not find torch::jit::Value* " << input->debugName() << " produced from " + << util::node_info(input->node()) + << " in lowering graph for mini graph input.\n"); + if (input->node()->kind() == torch::jit::prim::Param) { + jit_inputs_ivalues.push_back(ivalues_maps[input]); + } else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toTensor()); + } else if (input->type()->isSubtypeOf(torch::jit::IntType::get())) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toInt()); + } else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toBool()); + } else if (input->type()->kind() == torch::jit::TypeKind::ListType) { + // create list + jit_inputs_ivalues.push_back(ivalues_maps[input].toList()); + ; + } else if (input->type()->kind() == torch::jit::TypeKind::TupleType) { + // create tuple + jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple()); + } else if 
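// --- Editor's illustrative sketch (not part of the patch) ---
// Simplified example, with toy types instead of c10::IValue and at::Tensor, of how
// generateRandomInputs above now materializes one example value per graph input:
// every element spec of a list or tuple input gets its own random tensor, while a
// plain tensor input gets a single one. All names here are illustrative only.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

enum class Kind { Tensor, List, Tuple };
struct Spec { std::vector<int64_t> shape; };
using FakeTensor = std::vector<int64_t>;                   // stand-in for at::Tensor

FakeTensor make_random(const Spec& s) { return s.shape; }  // pretend this calls at::randint

struct Example { Kind kind; std::vector<FakeTensor> values; };

std::map<std::string, Example> generate_random_inputs(
    const std::map<std::string, std::pair<Kind, std::vector<Spec>>>& inputs) {
  std::map<std::string, Example> out;
  for (const auto& entry : inputs) {
    Example ex{entry.second.first, {}};
    for (const auto& spec : entry.second.second)
      ex.values.push_back(make_random(spec));              // one random tensor per element spec
    out[entry.first] = ex;
  }
  return out;
}

int main() {
  std::map<std::string, std::pair<Kind, std::vector<Spec>>> inputs;
  inputs.emplace("x", std::make_pair(Kind::Tuple,
                                     std::vector<Spec>{Spec{{1, 3}}, Spec{{1, 3}}}));
  auto m = generate_random_inputs(inputs);
  std::cout << m["x"].values.size() << "\n";               // 2: one tensor per tuple element
}
// --- end sketch ---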
(input->type()->kind() == torch::jit::TypeKind::NumberType) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar()); + } else if (input->type()->kind() == torch::jit::TypeKind::DictType) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toGenericDict()); + } else if (input->type()->kind() == torch::jit::TypeKind::DeviceObjType) { + jit_inputs_ivalues.push_back(ivalues_maps[input].toDevice()); + } else { + TORCHTRT_THROW_ERROR( + "Expected to find type " << input->type()->str() << " for value " << input->debugName() + << " but get nothing. "); + } + } + + // run segments to get outputs for later segments input shape, and other arguments such as Int + std::vector jit_results; + torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues); + + if (jit_results_ivalues.isTuple()) { + auto results = jit_results_ivalues.toTuple()->elements(); + for (auto r : results) { + jit_results.push_back(r); + } + } else { + jit_results.push_back(jit_results_ivalues); + } + + size_t idx = 0; + for (auto& output : seg_block.raw_outputs()) { + ivalues_maps[output] = jit_results[idx++]; + } + + // set input shape for each segmented block so we wil use it in conversion process + std::vector input_shapes; + std::vector input_types; + for (auto& i : seg_block.raw_inputs()) { + if (ivalues_maps[i].isTensor()) { + // set the input_shape and data_type + // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for + // shape inference + auto cur_ivalue = ivalues_maps[i]; + at::ScalarType t = cur_ivalue.toTensor().scalar_type(); + if (!partition_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) { + TORCHTRT_THROW_ERROR( + "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled"); + } else if (partition_info.truncate_long_and_double && t == at::kLong) { + cur_ivalue = cur_ivalue.toTensor().to(at::kInt); + LOG_WARNING("Truncating graph input type from at::kLong to at::kInt"); + } else if (partition_info.truncate_long_and_double && t == at::kDouble) { + cur_ivalue = cur_ivalue.toTensor().to(at::kFloat); + LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat"); + } + c10::optional dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype()); + if (dtype == c10::nullopt) { + TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype()); + } + if (cur_ivalue.toTensor().sizes().size() == 0) { + // handle Scalar types, which has sizes of [] + input_shapes.push_back(util::toVec(util::toDims(c10::List({1})))); + } else { + input_shapes.push_back(util::toVec(util::toDims(cur_ivalue.toTensor().sizes()))); + } + input_types.push_back(cur_ivalue.toTensor().scalar_type()); + } + // TODO: tuple and list inputs in subgraph + } + + seg_block.register_inshapes(input_shapes); + seg_block.register_intypes(input_types); +} + +void runShapeAnalysis( + std::vector& segmented_blocks, + std::unordered_map& example_tensor_map, + const PartitionInfo& partition_info) { + // register every segment's input shape, and it's running output IValues + for (auto& seg_block : segmented_blocks) { + torch::jit::ConstantPooling(seg_block.g()); + getSegmentsOutputByRunning(seg_block, example_tensor_map, partition_info); + } + return; +} + +} // namespace partitioning +} // namespace core +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h index 
0626490222..e9c51fc62d 100644 --- a/core/partitioning/shape_analysis.h +++ b/core/partitioning/shape_analysis.h @@ -7,8 +7,8 @@ namespace core { namespace partitioning { std::unordered_map generateRandomInputs( - std::unordered_map& input_ranges, - std::unordered_map>& input_types); + std::unordered_map>& input_ranges, + std::unordered_map>>& input_types); void runShapeAnalysis( std::vector& segmented_blocks, diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h index 1dca94d9dd..45497a13a3 100644 --- a/cpp/include/torch_tensorrt/torch_tensorrt.h +++ b/cpp/include/torch_tensorrt/torch_tensorrt.h @@ -14,6 +14,7 @@ #include #include #include +#include "torch/custom_class.h" #include "torch_tensorrt/macros.h" @@ -364,7 +365,7 @@ class TensorFormat { * signifying a static input shape or a set of three input shapes representing * the min, optiminal and max input shapes allowed for the engine. */ -struct Input { +struct TORCHTRT_API Input : torch::CustomClassHolder { /// Minimum acceptable input size into the engine std::vector min_shape; /// Optimal input size into the engine (size optimized for given kernels accept any size in min max range) @@ -379,6 +380,7 @@ struct Input { /// Expected tensor format for the input TensorFormat format; + Input() {} /** * @brief Construct a new Input spec object for static input size from * vector, optional arguments allow the user to configure expected input shape @@ -513,6 +515,16 @@ struct Input { bool input_is_dynamic; }; +/** + * @brief A struct to hold complex inputs + * + * This struct can either hold a complex inputs of shape or a flattened one, + */ +struct TORCHTRT_API GraphInputs { + torch::jit::IValue input_signature; // nested Input, full input spec + std::vector inputs; // flatten input spec +}; + /** * @brief Get the build information for the library including the dependency * versions @@ -558,7 +570,7 @@ struct CompileSpec { TORCHTRT_API CompileSpec(std::vector> fixed_sizes); /** - * @brief Construct a new Extra Info object + * @brief Construct a new Compile Spec object * Convienence constructor to set fixed input size from c10::ArrayRef's (the * output of tensor.sizes()) describing size of input tensors. Each entry in * the vector represents a input and should be provided in call order. @@ -572,7 +584,7 @@ struct CompileSpec { TORCHTRT_API CompileSpec(std::vector> fixed_sizes); /** - * @brief Construct a new Extra Info object from input ranges. + * @brief Construct a new Compile Spec object from input ranges. * Each entry in the vector represents a input and should be provided in call * order. * @@ -580,18 +592,21 @@ struct CompileSpec { * * @param inputs */ - CompileSpec(std::vector inputs) : inputs(std::move(inputs)) {} - - // Defaults should reflect TensorRT defaults for BuilderConfig + CompileSpec(std::vector inputs); /** - * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max - * sizes Users can also specify expected input type as well as tensor memory format + * @brief Construct a new Compile Spec object from IValue which represents the nesting of input tensors for a module. 
* - * Order in vector should match call order for the function + * @param input_signature */ - std::vector inputs; + CompileSpec(torch::jit::IValue input_signature); + // Defaults should reflect TensorRT defaults for BuilderConfig + /** + * @brief Specifications for inputs to the engine, can store a IValue which has stored complex Input + * or a flatened Input + */ + GraphInputs graph_inputs; /** * @brief The set of precisions TensorRT is allowed to use for kernels during compilation * diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index 2881887aea..8c7bb8b403 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -18,18 +18,85 @@ torchtrt::core::runtime::CudaDevice to_internal_cuda_device(Device device); namespace torchscript { CompileSpec::CompileSpec(std::vector> fixed_sizes) { for (auto in : fixed_sizes) { - inputs.push_back(Input(in)); + graph_inputs.inputs.push_back(Input(in)); } } CompileSpec::CompileSpec(std::vector> fixed_sizes) { for (auto in : fixed_sizes) { - inputs.push_back(Input(in)); + graph_inputs.inputs.push_back(Input(in)); + } +} + +CompileSpec::CompileSpec(std::vector inputs) { + graph_inputs.inputs = std::move(inputs); +} + +CompileSpec::CompileSpec(torch::jit::IValue input_signature) { + graph_inputs.input_signature = input_signature; +} + +void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) { + if (input_ivalue.isTuple()) { + auto input_tuple = input_ivalue.toTuple(); + std::vector converted_elements; + for (auto item : input_tuple->elements()) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements); + converted_ivalue = torch::jit::IValue(tuple_ptr); + } + } else if (input_ivalue.isList()) { + auto input_list = input_ivalue.toList().vec(); + c10::TypePtr type = input_list[0].type(); + auto converted_elements = c10::impl::GenericList(type); + for (auto item : input_list) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + } + converted_ivalue = torch::jit::IValue(converted_elements); + } else if (input_ivalue.isCustomClass()) { + torchtrt::core::ir::Input cur_input = to_internal_input(*(input_ivalue.toCustomClass())); + converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(cur_input))); + } +} + +torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) { + if (external.graph_inputs.inputs.size() > 0) { + torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs)); + return internal; + } else { + torch::jit::IValue converted_input_signature; + LOG_WARNING( "Input signature parsing is an experimental feature, behavior and APIs may change"); + to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature); + torchtrt::core::CompileSpec internal(converted_input_signature); + + TORCHTRT_CHECK(!external.require_full_compilation, \ + "Grouped inputs currently requires partial compilation to be enabled, \ + this restriction will be relaxed in a future release"); + + LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature"); + LOG_DEBUG("Adding the following ops to torch_executed_ops:" \ + << std::endl << " - aten::__getitem__" \ + << std::endl << " - prim::ListConstruct" \ + << std::endl << " - prim::ListUnpack" \ + << std::endl << " - 
prim::TupleIndex" \ + << std::endl << " - prim::TupleConstruct" \ + << std::endl << " - prim::TupleUnpack"); + external.torch_executed_ops.push_back("aten::__getitem__"); + external.torch_executed_ops.push_back("prim::ListConstruct"); + external.torch_executed_ops.push_back("prim::ListUnpack"); + external.torch_executed_ops.push_back("prim::TupleIndex"); + external.torch_executed_ops.push_back("prim::TupleConstruct"); + external.torch_executed_ops.push_back("prim::TupleUnpack"); + return internal; } } torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) { - torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.inputs)); + torchtrt::core::CompileSpec internal = init_compile_spec(external); for (auto p : external.enabled_precisions) { internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p)); diff --git a/cpp/src/torch_tensorrt.cpp b/cpp/src/torch_tensorrt.cpp index 42b44833de..22855aeb03 100644 --- a/cpp/src/torch_tensorrt.cpp +++ b/cpp/src/torch_tensorrt.cpp @@ -52,4 +52,6 @@ void set_device(const int gpu_id) { // Want to export a much simpler (non CUDA header dependent) API torch_tensorrt::core::set_device(gpu_id); } + +static auto tensorrt_input_container = torch::class_("_torch_tensorrt", "Input").def(torch::init<>()); } // namespace torch_tensorrt diff --git a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp index ba2e168ed3..9db567ca86 100644 --- a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp @@ -23,6 +23,19 @@ void RegisterTRTCompileSpec() { ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, input_is_dynamic); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, explicit_set_dtype); + static auto TORCHTRT_UNUSED TRTInputSignatureTSRegistration = + torch::class_("tensorrt", "_InputSignature") + .def(torch::init<>()) + .def("_set_signature_ivalue_torchbind", + [](const c10::intrusive_ptr& self, + torch::jit::IValue ival) { + self->signature_ivalue = ival; + }) + .def("__str__", &torch_tensorrt::pyapi::InputSignature::to_str); + + ADD_FIELD_GET_SET_REGISTRATION( + TRTInputSignatureTSRegistration, torch_tensorrt::pyapi::InputSignature, signature_ivalue); + static auto TORCHTRT_UNUSED TRTDeviceTSRegistration = torch::class_("tensorrt", "_Device") .def(torch::init<>()) @@ -49,6 +62,7 @@ void RegisterTRTCompileSpec() { torch::class_("tensorrt", "CompileSpec") .def(torch::init<>()) .def("_append_input", &torch_tensorrt::pyapi::CompileSpec::appendInput) + .def("_set_input_signature", &torch_tensorrt::pyapi::CompileSpec::setInputSignature) .def("_set_precisions", &torch_tensorrt::pyapi::CompileSpec::setPrecisions) .def("_set_device", &torch_tensorrt::pyapi::CompileSpec::setDeviceIntrusive) .def("_set_torch_fallback", &torch_tensorrt::pyapi::CompileSpec::setTorchFallbackIntrusive) diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 775c71dea5..96fef793fd 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -104,6 +104,56 @@ std::string Input::to_str() { return ss.str(); } +std::string sig_to_str(torch::jit::IValue input_sig) { + if (input_sig.isTuple()) { + auto input_tuple = input_sig.toTuple(); + std::vector children; + for (auto item : input_tuple->elements()) { + auto child = sig_to_str(item); + 
children.push_back(child); + } + std::stringstream ss; + ss << "("; + for (auto i : children) { + ss << i << ", "; + } + ss << ")"; + return ss.str(); + } else if (input_sig.isList()) { + auto input_list = input_sig.toList().vec(); + std::vector children; + for (auto item : input_list) { + auto child = sig_to_str(item); + children.push_back(child); + } + std::stringstream ss; + ss << "["; + for (auto i : children) { + ss << i << ", "; + } + ss << "]"; + return ss.str(); + } else if (input_sig.isCustomClass()) { + auto cur_input = input_sig.toCustomClass(); + return cur_input->to_str(); + } else if (input_sig.isPyObject()) { + auto py_object_holder = input_sig.toPyObjectHolder(); + auto infer_type = py_object_holder->tryToInferType(); + auto type = infer_type.type(); + torch::jit::IValue ival = py_object_holder->toIValue(type); + torch::jit::IValue converted_item; + return sig_to_str(ival); + } else { + LOG_ERROR("Unknown input spec type"); + return ""; + } +} + +std::string InputSignature::to_str() { + std::stringstream ss; + return sig_to_str(signature_ivalue); +} + std::string to_str(DeviceType value) { switch (value) { case DeviceType::kDLA: @@ -184,13 +234,63 @@ std::string TorchFallback::to_str() { return ss.str(); } -core::CompileSpec CompileSpec::toInternalCompileSpec() { - std::vector internal_inputs; - for (auto i : inputs) { - internal_inputs.push_back(i.toInternalInput()); +void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) { + if (input_ivalue.isTuple()) { + auto input_tuple = input_ivalue.toTuple(); + std::vector converted_elements; + for (auto item : input_tuple->elements()) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements); + converted_ivalue = torch::jit::IValue(tuple_ptr); + } + } else if (input_ivalue.isList()) { + auto input_list = input_ivalue.toList().vec(); + c10::TypePtr type = input_list[0].type(); + auto converted_elements = c10::impl::GenericList(type); + for (auto item : input_list) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + } + converted_ivalue = torch::jit::IValue(converted_elements); + } else if (input_ivalue.isCustomClass()) { + core::ir::Input cur_input = (*(input_ivalue.toCustomClass())).toInternalInput(); + converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(cur_input))); + } else if (input_ivalue.isPyObject()) { + auto py_object_holder = input_ivalue.toPyObjectHolder(); + auto infer_type = py_object_holder->tryToInferType(); + auto type = infer_type.type(); + torch::jit::IValue ival = py_object_holder->toIValue(type); + torch::jit::IValue converted_item; + to_internal_input_signature(ival, converted_item); + converted_ivalue = torch::jit::IValue(converted_item); + } else { + LOG_ERROR("Unknown input spec type"); + } +} + +core::CompileSpec init_compile_spec(CompileSpec external) { + if (external.inputs.size() > 0) { + LOG_DEBUG("init_compile_spec with input vector"); + std::vector internal_inputs; + for (auto i : external.inputs) { + internal_inputs.push_back(i.toInternalInput()); + } + core::CompileSpec internal(internal_inputs); + return internal; + } else { + LOG_DEBUG("init_compile_spec with input signature"); + torch::jit::IValue converted_input_signature; + to_internal_input_signature(external.input_signature.signature_ivalue, 
converted_input_signature); + core::CompileSpec internal(converted_input_signature); + return internal; } +} - auto info = core::CompileSpec(internal_inputs); +core::CompileSpec CompileSpec::toInternalCompileSpec() { + core::CompileSpec info = init_compile_spec(*this); for (auto p : enabled_precisions) { info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p)); @@ -243,16 +343,20 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() { std::string CompileSpec::stringify() { std::stringstream ss; ss << "TensorRT Compile Spec: {" << std::endl; - ss << " \"Inputs\": [" << std::endl; - for (auto i : inputs) { - ss << i.to_str(); + if (inputs.size() > 0) { + ss << " \"Inputs\": [" << std::endl; + for (auto i : inputs) { + ss << i.to_str(); + } + ss << " ]" << std::endl; + } else { + ss << " \"Input Signature\": " << input_signature.to_str() << std::endl; } - ss << " ]" << std::endl; - ss << " \"Enabled Precision\": [" << std::endl; + ss << " \"Enabled Precision\": ["; for (auto p : enabled_precisions) { - ss << to_str(p); + ss << to_str(p) << ", "; } - ss << " ]" << std::endl; + ss << "]" << std::endl; ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl; ss << " \"Sparsity\": " << sparse_weights << std::endl; ss << " \"Refit\": " << refit << std::endl; diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index b615022bd0..be2fab3b8e 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -57,6 +57,12 @@ struct Input : torch::CustomClassHolder { std::string to_str(); }; +struct InputSignature : torch::CustomClassHolder { + torch::jit::IValue signature_ivalue; // nested Input, full input spec + ADD_FIELD_GET_SET(signature_ivalue, torch::jit::IValue); + std::string to_str(); +}; + enum DeviceType : int8_t { kGPU, kDLA, @@ -119,6 +125,10 @@ struct CompileSpec : torch::CustomClassHolder { inputs.push_back(*ir); } + void setInputSignature(const c10::intrusive_ptr& is) { + input_signature = *is; + } + void setPrecisions(const std::vector& precisions_raw) { for (auto p : precisions_raw) { TORCHTRT_CHECK(p >= 0 && p <= static_cast(DataType::kBool), "Invalid enum value for field"); @@ -158,6 +168,7 @@ struct CompileSpec : torch::CustomClassHolder { ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*); std::vector inputs; + InputSignature input_signature; nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr; std::set enabled_precisions = {}; bool sparse_weights = false; diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 74a8b72711..6b1ffd4ccf 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -1,6 +1,7 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" +#include "ATen/core/jit_type.h" #include "Python.h" #include "core/compiler.h" #include "core/conversion/conversion.h" @@ -178,6 +179,16 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("dtype", &Input::dtype) .def_readwrite("format", &Input::format); + py::class_(m, "InputSignature") + .def(pybind11::init([](py::object py_obj) { + InputSignature input_signature; + input_signature.signature_ivalue = + torch::jit::toIValue(std::move(py_obj), c10::PyObjectType::get(), c10::nullopt); + return input_signature; + })) + .def("__str__", &InputSignature::to_str) + .def_readwrite("_signature_ivalue", &InputSignature::signature_ivalue); + py::enum_(m, "dtype", "Enum to specifiy operating precision for engine 
execution") .value("float", DataType::kFloat, "32 bit floating point number") .value("float32", DataType::kFloat, "32 bit floating point number") @@ -292,6 +303,7 @@ PYBIND11_MODULE(_C, m) { .def("__str__", &torch_tensorrt::pyapi::CompileSpec::stringify) .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator") .def_readwrite("inputs", &CompileSpec::inputs) + .def_readwrite("input_signature", &CompileSpec::input_signature) .def_readwrite("enabled_precisions", &CompileSpec::enabled_precisions) .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator) .def_readwrite("refit", &CompileSpec::refit) diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 4c7b8b5b5d..ac3465b1c4 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -5,9 +5,22 @@ from torch_tensorrt import _enums from torch_tensorrt._Input import Input from torch_tensorrt._Device import Device - +from torch_tensorrt.logging import Level, log +from typing import Tuple, List, Dict import warnings +from copy import deepcopy + +def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: + clone = torch.classes.tensorrt._Input() + clone._set_min(i.min) + clone._set_opt(i.opt) + clone._set_max(i.max) + clone._set_dtype(i.dtype) + clone._set_format(i.format) + clone._set_input_is_dynamic(i.input_is_dynamic) + clone._set_explicit_set_dtype(i._explicit_set_dtype) + return clone def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): @@ -156,15 +169,32 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback: return info +def _parse_input_signature(input_signature: Any): + if isinstance(input_signature, tuple): + input_list = [] + for item in input_signature: + input = _parse_input_signature(item) + input_list.append(input) + return tuple(input_list) + elif isinstance(input_signature, list): + input_list = [] + for item in input_signature: + input = _parse_input_signature(item) + input_list.append(input) + return input_list + elif isinstance(input_signature, Input) or isinstance(input_signature, torch.Tensor): + i = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature + clone = _internal_input_to_torch_class_input(i._to_internal()) + return clone + else: + raise KeyError("Input signature contains an unsupported type {}".format(type(input_signature))) -def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: +def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: + # TODO: Remove deep copy once collections does not need partial compilation + compile_spec = deepcopy(compile_spec_) info = _ts_C.CompileSpec() - if "inputs" not in compile_spec: - raise KeyError( - "Module input definitions are requried to compile module. 
Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec" - ) - if "inputs" in compile_spec: + if len(compile_spec["inputs"]) > 0: if not all([isinstance(i, torch.Tensor) or isinstance(i, Input) for i in compile_spec["inputs"]]): raise KeyError("Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {}".format( [type(i) for i in compile_spec["inputs"]])) @@ -172,7 +202,34 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: inputs = [Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]] info.inputs = [i._to_internal() for i in inputs] - assert (len(info.inputs) > 0), "Require at least one input definition to compile model" + elif compile_spec["input_signature"] is not None: + log(Level.Warning, "Input signature parsing is an experimental feature, behavior and APIs may change") + signature = _parse_input_signature(compile_spec["input_signature"]) + info.input_signature = _C.InputSignature(signature) # py_object + + if not compile_spec["torch_fallback"]["enabled"]: + raise ValueError("Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release") + + log(Level.Debug, "Grouped inputs currently requires additional settings to enable the feature") + log(Level.Debug, """Adding the following ops to torch_executed_ops: + - aten::__getitem__ + - prim::ListConstruct + - prim::ListUnpack + - prim::TupleIndex + - prim::TupleConstruct + - prim::TupleUnpack +""") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("aten::__getitem__") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListConstruct") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleConstruct") + compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleUnpack") + + else: + raise KeyError( + "Module input definitions are requried to compile module. 
Provide a list of torch_tensorrt.Input keyed to \"inputs\" in the compile spec" + ) if "enabled_precisions" in compile_spec: info.enabled_precisions = _parse_enabled_precisions(compile_spec["enabled_precisions"]) @@ -230,10 +287,13 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: if "torch_fallback" in compile_spec: info.torch_fallback = _parse_torch_fallback(compile_spec["torch_fallback"]) + log(Level.Debug, str(info)) + return info def TensorRTCompileSpec(inputs=[], + input_signature=None, device=Device._current_device(), disable_tf32=False, sparse_weights=False, @@ -288,6 +348,7 @@ def TensorRTCompileSpec(inputs=[], compile_spec = { "inputs": inputs, + #"input_signature": input_signature, "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas @@ -309,15 +370,11 @@ def TensorRTCompileSpec(inputs=[], backend_spec = torch.classes.tensorrt.CompileSpec() + if input_signature is not None: + raise ValueError("Input signature parsing is not currently supported in the TorchScript backend integration") + for i in parsed_spec.inputs: - clone = torch.classes.tensorrt._Input() - clone._set_min(i.min) - clone._set_opt(i.opt) - clone._set_max(i.max) - clone._set_dtype(i.dtype) - clone._set_format(i.format) - clone._set_input_is_dynamic(i.input_is_dynamic) - clone._set_explicit_set_dtype(i._explicit_set_dtype) + clone = _internal_input_to_torch_class_input(i) backend_spec._append_input(clone) d = torch.classes.tensorrt._Device() diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 83704a4b6c..cc5f4b24d1 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -11,6 +11,7 @@ def compile(module: torch.jit.ScriptModule, inputs=[], + input_signature=None, device=Device._current_device(), disable_tf32=False, sparse_weights=False, @@ -57,6 +58,19 @@ def compile(module: torch.jit.ScriptModule, torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] + input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using + torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** :: + + input_signature=([ + torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 + torch_tensorrt.Input( + min_shape=(1, 224, 224, 3), + opt_shape=(1, 512, 512, 3), + max_shape=(1, 1024, 1024, 3), + dtype=torch.int32 + format=torch.channel_last + ), # Dynamic input shape for input #2 + ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3 device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) @@ -89,11 +103,11 @@ def compile(module: torch.jit.ScriptModule, if require_full_compilation and (len(torch_executed_modules) > 0 or len(torch_executed_ops) > 0): raise ValueError( - "require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. 
Found: torch_executed_ops: " - + torch_executed_ops + ", torch_executed_modules: " + torch_executed_modules) + f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}") spec = { "inputs": inputs, + "input_signature": input_signature, "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format "sparse_weights": sparse_weights, #Enable sparsity for convolution and fully connected layers. @@ -161,6 +175,20 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] + input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using + torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** :: + + input_signature=([ + torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 + torch_tensorrt.Input( + min_shape=(1, 224, 224, 3), + opt_shape=(1, 512, 512, 3), + max_shape=(1, 1024, 1024, 3), + dtype=torch.int32 + format=torch.channel_last + ), # Dynamic input shape for input #2 + ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3 + device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp index 994fb25811..939c9b7394 100644 --- a/tests/core/conversion/converters/test_element_wise.cpp +++ b/tests/core/conversion/converters/test_element_wise.cpp @@ -12,7 +12,9 @@ void pointwise_test_helper( std::vector shape1 = {5}, std::vector shape2 = {5}, bool negative_input = false, - bool int_tensors = false) { + bool int_tensors = false, + bool float_int_tensors = false, + bool int_float_tensors = false) { auto g = std::make_shared(); torch::jit::parseIR(graph_ir, g.get()); @@ -27,11 +29,24 @@ void pointwise_test_helper( if (!singleInput) { torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA})); } + + TORCHTRT_CHECK(!((int_tensors && (float_int_tensors || int_float_tensors)) || (float_int_tensors && int_float_tensors)), + "Invalid test configuration, only one of int_tensors, float_int_tensors, int_float_tensors can be true"); + if(int_tensors){ for(size_t i = 0UL; i < torch_inputs.size(); ++i){ torch_inputs[i] = torch_inputs[i].to(at::kInt); } + } else if(float_int_tensors) { + TORCHTRT_CHECK(!singleInput, "float_int_tensors tests require two inputs"); + torch_inputs[0] = torch_inputs[0].to(at::kFloat); + torch_inputs[1] = torch_inputs[1].to(at::kInt); + } else if (int_float_tensors) { + TORCHTRT_CHECK(!singleInput, "int_float_tensors tests require two inputs"); + torch_inputs[0] = torch_inputs[0].to(at::kInt); + torch_inputs[1] = torch_inputs[1].to(at::kFloat); } + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, torch_inputs); @@ -62,6 +77,8 @@ TEST(Converters, ATenAddConvertsCorrectly) { 
pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenAddWithAlphaConvertsCorrectly) { @@ -75,9 +92,11 @@ TEST(Converters, ATenAddWithAlphaConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } -TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) { +TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor, %1 : Tensor): %2 : float = prim::Constant[value=7.6]() @@ -109,6 +128,8 @@ TEST(Converters, ATenSubConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenMulConvertsCorrectly) { @@ -121,6 +142,8 @@ TEST(Converters, ATenMulConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenMulWithScalarConvertsCorrectly) { @@ -151,6 +174,8 @@ TEST(Converters, ATenDivConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenDivWithScalarConvertsCorrectly) { @@ -173,6 +198,8 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}, true); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) { @@ -186,6 +213,8 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}, true); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) { @@ -211,6 +240,8 @@ TEST(Converters, ATenPowTensorConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, 
true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenPowScalarConvertsCorrectly) { @@ -251,6 +282,8 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) { pointwise_test_helper(graph, false, false, {4}, {3, 4}); pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}); pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true); + pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true); } TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) { diff --git a/tests/core/conversion/evaluators/test_prim_evaluators.cpp b/tests/core/conversion/evaluators/test_prim_evaluators.cpp index 0ff250f0e9..508d4eb1b0 100644 --- a/tests/core/conversion/evaluators/test_prim_evaluators.cpp +++ b/tests/core/conversion/evaluators/test_prim_evaluators.cpp @@ -51,5 +51,112 @@ TEST(Evaluators, NumToTensorEvaluatesCorrectly) { auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, PrimTupleConstruct1EvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=3]() + %tc : (int) = prim::TupleConstruct(%1) + return (%tc))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, PrimTupleConstruct2EvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=3]() + %2 : int = prim::Constant[value=4]() + %tc : (int, int) = prim::TupleConstruct(%1, %2) + return (%tc))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, PrimTupleConstruct3EvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=3]() + %2 : int = prim::Constant[value=4]() + %3 : int = prim::Constant[value=4]() + %tc : (int, int, int) = prim::TupleConstruct(%1, %2, %3) + return (%tc))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, PrimTupleConstruct4EvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=3]() + %2 : int = prim::Constant[value=4]() + %3 : int = prim::Constant[value=3]() + %4 : int = prim::Constant[value=4]() + %tc : (int, int, int, int) = prim::TupleConstruct(%1, %2, %3, %4) + return (%tc))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == 
trt_results[0]); +} + +TEST(Evaluators, PrimTupleUnpackEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %1 : int = prim::Constant[value=3]() + %2 : int = prim::Constant[value=4]() + %tc : (int, int) = prim::TupleConstruct(%1, %2) + %tu.1 : int, %tu.2 : int = prim::TupleUnpack(%tc) + return (%tu.1, %tu.2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + + ASSERT_TRUE(jit_results[0] == trt_results[0]); +} + +TEST(Evaluators, PrimTupleIndexEvaluatesCorrectly) { + const auto graph = R"IR( + graph(): + %0 : int = prim::Constant[value=1]() + %1 : int = prim::Constant[value=3]() + %2 : int = prim::Constant[value=4]() + %tc : (int, int) = prim::TupleConstruct(%1, %2) + %ti : int = prim::TupleIndex(%tc, %0) + return (%ti))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {}); + auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {}); + ASSERT_TRUE(jit_results[0] == trt_results[0]); } \ No newline at end of file diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index fea202fc65..1f3ee3b051 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -116,11 +116,11 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({16, 3, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({16})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::unordered_map fallback_nodes; @@ -175,11 +175,11 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({16, 6, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({16})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::unordered_map fallback_nodes; @@ -367,11 +367,11 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) { inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); inputs.push_back(torch_tensorrt::core::ir::Input({4, 4})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - 
input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::unordered_map fallback_nodes; diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp index 7bcabc0d51..151a6e75ad 100644 --- a/tests/core/partitioning/test_shape_analysis.cpp +++ b/tests/core/partitioning/test_shape_analysis.cpp @@ -59,11 +59,11 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({8, 16, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({8})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::unordered_map fallback_nodes; @@ -110,11 +110,11 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({16, 32, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({16})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::unordered_map fallback_nodes; diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD index 3d69afba95..8e479e2e0a 100644 --- a/tests/cpp/BUILD +++ b/tests/cpp/BUILD @@ -18,7 +18,8 @@ test_suite( ":test_multiple_registered_engines", ":test_serialization", ":test_module_fallback", - ":test_example_tensors" + ":test_example_tensors", + ":test_collections" ], ) @@ -32,7 +33,8 @@ test_suite( ":test_multiple_registered_engines", ":test_serialization", ":test_module_fallback", - ":test_example_tensors" + ":test_example_tensors", + ":test_collections" ], ) @@ -122,6 +124,20 @@ cc_test( }) ) +cc_test( + name = "test_collections", + srcs = ["test_collections.cpp"], + data = [ + "//tests/modules:jit_models", + ], + deps = [ + "//tests/util", + "@googletest//:gtest_main", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }) +) cc_test( name = "test_compiled_modules", srcs = ["test_compiled_modules.cpp"], diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp new file mode 100644 index 0000000000..31495a47a7 --- /dev/null +++ b/tests/cpp/test_collections.cpp @@ -0,0 +1,320 @@ +#include +#include +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/script.h" +#include "torch_tensorrt/torch_tensorrt.h" + +TEST(CppAPITests, TestCollectionStandardTensorInput) { + std::string path = "tests/modules/standard_tensor_input_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 
512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + auto out = mod.forward(inputs_); + + std::vector input_range; + input_range.push_back({in0.sizes(), torch::kF16}); + input_range.push_back({in0.sizes(), torch::kF16}); + torch_tensorrt::ts::CompileSpec compile_settings(input_range); + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(inputs_); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); +} + +TEST(CppAPITests, TestCollectionTupleInput) { + std::string path = "tests/modules/tuple_input_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector complex_inputs, complex_inputs_list; + std::tuple input_tuple(in0, in0); + + complex_inputs.push_back(input_tuple); + + auto out = mod.forward(complex_inputs); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + std::tuple input_shape_tuple(input_shape_ivalue, input_shape_ivalue); + + torch::jit::IValue complex_input_shape(input_shape_tuple); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); +} + +TEST(CppAPITests, TestCollectionListInput) { + std::string path = "tests/modules/list_input_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
+ mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + std::vector complex_inputs; + auto input_list = c10::impl::GenericList(c10::TensorType::get()); + input_list.push_back(inputs_[0]); + input_list.push_back(inputs_[0]); + + torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list); + + complex_inputs.push_back(input_list_ivalue); + + auto out = mod.forward(complex_inputs); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + c10::TypePtr elementType = input_shape_ivalue.type(); + auto list = c10::impl::GenericList(elementType); + list.push_back(input_shape_ivalue); + list.push_back(input_shape_ivalue); + + torch::jit::IValue complex_input_shape(list); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + //compile_settings.torch_executed_ops.push_back("aten::__getitem__"); + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + LOG_DEBUG("Finish compile"); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); +} + +TEST(CppAPITests, TestCollectionTupleInputOutput) { + std::string path = "tests/modules/tuple_input_output_scripted.jit.pt"; + + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
+ mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector complex_inputs, complex_inputs_list; + std::tuple input_tuple(in0, in0); + + complex_inputs.push_back(input_tuple); + + auto out = mod.forward(complex_inputs); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + std::tuple input_shape_tuple(input_shape_ivalue, input_shape_ivalue); + + torch::jit::IValue complex_input_shape(input_shape_tuple); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + // torch::jit::IValue complex_input_shape(list); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + + // compile_settings.torch_executed_ops.push_back("prim::TupleConstruct"); + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5)); +} + +TEST(CppAPITests, TestCollectionListInputOutput) { + std::string path = "tests/modules/list_input_output_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
+ mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + std::vector complex_inputs; + auto input_list = c10::impl::GenericList(c10::TensorType::get()); + input_list.push_back(inputs_[0]); + input_list.push_back(inputs_[0]); + + torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list); + + complex_inputs.push_back(input_list_ivalue); + + auto out = mod.forward(complex_inputs); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + c10::TypePtr elementType = input_shape_ivalue.type(); + auto list = c10::impl::GenericList(elementType); + list.push_back(input_shape_ivalue); + list.push_back(input_shape_ivalue); + + torch::jit::IValue complex_input_shape(list); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + out.toList().vec()[0].toTensor(), trt_out.toList().vec()[0].toTensor(), 1e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + out.toList().vec()[1].toTensor(), trt_out.toList().vec()[1].toTensor(), 1e-5)); +} + +TEST(CppAPITests, TestCollectionComplexModel) { + std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
+ mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + std::vector complex_inputs; + auto input_list = c10::impl::GenericList(c10::TensorType::get()); + input_list.push_back(inputs_[0]); + input_list.push_back(inputs_[0]); + + torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list); + + complex_inputs.push_back(input_list_ivalue); + + auto out = mod.forward(complex_inputs); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + c10::TypePtr elementType = input_shape_ivalue.type(); + auto list = c10::impl::GenericList(elementType); + list.push_back(input_shape_ivalue); + list.push_back(input_shape_ivalue); + + torch::jit::IValue complex_input_shape(list); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual( + out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5)); +} \ No newline at end of file diff --git a/tests/cpp/test_example_tensors.cpp b/tests/cpp/test_example_tensors.cpp index 6561cd16a0..256e6f1b59 100644 --- a/tests/cpp/test_example_tensors.cpp +++ b/tests/cpp/test_example_tensors.cpp @@ -9,7 +9,8 @@ TEST_P(CppAPITests, InputsFromTensors) { trt_inputs_ivalues.push_back(in.clone()); } - auto spec = torch_tensorrt::ts::CompileSpec({trt_inputs_ivalues[0].toTensor()}); + auto inputs = std::vector{trt_inputs_ivalues[0].toTensor()}; + auto spec = torch_tensorrt::ts::CompileSpec(inputs); auto trt_mod = torch_tensorrt::ts::compile(mod, spec); torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index 20d501045f..a92e01e7a4 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -2,6 +2,7 @@ import torch.nn as nn from transformers import BertModel, BertTokenizer, BertConfig import torch.nn.functional as F +from typing import Tuple, List, Dict # Sample Pool Model (for testing plugin serialization) @@ -100,6 +101,67 @@ def forward(self, x, y): z = torch.cat(mod_list) return z +# Collection input/output models +class StandardTensorInput(nn.Module): + def __init__(self): + super(StandardTensorInput, self).__init__() + + def forward(self, x, y): + r = x + y + return r + +class TupleInput(nn.Module): + def __init__(self): + super(TupleInput, self).__init__() + + def forward(self, z: Tuple[torch.Tensor, torch.Tensor]): + r = z[0] + z[1] + return r + +class ListInput(nn.Module): + def __init__(self): + super(ListInput, self).__init__() + + def forward(self, z: List[torch.Tensor]): + r = z[0] + z[1] + return r + +class 
TupleInputOutput(nn.Module): + def __init__(self): + super(TupleInputOutput, self).__init__() + + def forward(self, z: Tuple[torch.Tensor, torch.Tensor]): + r1 = z[0] + z[1] + r2 = z[0] - z[1] + r1 = r1 * 10 + r = (r1, r2) + return r + +class ListInputOutput(nn.Module): + def __init__(self): + super(ListInputOutput, self).__init__() + + def forward(self, z: List[torch.Tensor]): + r1 = z[0] + z[1] + r2 = z[0] - z[1] + r = [r1, r2] + return r + +class ListInputTupleOutput(nn.Module): + def __init__(self): + super(ListInputTupleOutput, self).__init__() + self.list_model = ListInputOutput() + self.tuple_model = TupleInputOutput() + + def forward(self, z: List[torch.Tensor]): + r1 = z[0] + z[1] + r2 = z[0] - z[1] + r3 = (r1, r2) + r4 = [r2, r1] + tuple_out = self.tuple_model(r3) + list_out = self.list_model(r4) + r = (tuple_out[1], list_out[0]) + return r def BertModule(): model_name = "bert-base-uncased" diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 48e6b519cb..7d3e03e395 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -104,6 +104,30 @@ "model": cm.FallbackInplaceOPIf(), "path": "script" }, + "standard_tensor_input": { + "model": cm.StandardTensorInput(), + "path": "script" + }, + "tuple_input": { + "model": cm.TupleInput(), + "path": "script" + }, + "list_input": { + "model": cm.ListInput(), + "path": "script" + }, + "tuple_input_output": { + "model": cm.TupleInputOutput(), + "path": "script" + }, + "list_input_output": { + "model": cm.ListInputOutput(), + "path": "script" + }, + "list_input_tuple_output": { + "model": cm.ListInputTupleOutput(), + "path": "script" + }, "bert_base_uncased": { "model": cm.BertModule(), "path": "trace" @@ -193,5 +217,5 @@ def main(): f.write(record) f.truncate() - -main() +if __name__ == "__main__": + main() diff --git a/tests/modules/requirements.txt b/tests/modules/requirements.txt index d4b5105850..00acec5861 100644 --- a/tests/modules/requirements.txt +++ b/tests/modules/requirements.txt @@ -1,2 +1,3 @@ +torchvision timm==v0.4.12 transformers==4.17.0 diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py new file mode 100644 index 0000000000..154145e681 --- /dev/null +++ b/tests/py/api/test_collections.py @@ -0,0 +1,143 @@ +import unittest +import torch_tensorrt as torchtrt +import torch +import torchvision.models as models +import os + +def find_repo_root(max_depth=10): + dir_path = os.path.dirname(os.path.realpath(__file__)) + for i in range(max_depth): + files = os.listdir(dir_path) + if "WORKSPACE" in files: + return dir_path + else: + dir_path = os.path.dirname(dir_path) + + raise RuntimeError("Could not find repo root") + +MODULE_DIR = find_repo_root() + "/tests/modules" + +class TestStandardTensorInput(unittest.TestCase): + + + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt").eval().to("cuda") + + compile_spec = { + "inputs": [torchtrt.Input(self.input.shape), + torchtrt.Input(self.input.shape)], + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float} + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + same = (trt_mod(self.input, self.input) - self.model(self.input, self.input)).abs().max() + self.assertTrue(same < 2e-2) + +class TestTupleInput(unittest.TestCase): + + + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = torch.jit.load(MODULE_DIR + "/tuple_input_scripted.jit.pt").eval().to("cuda") 
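+ # The nested "input_signature" below mirrors the module's forward signature: the outer tuple is
+ # the positional argument list, and the inner tuple stands for the single Tuple[Tensor, Tensor] argument.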
+ + compile_spec = { + "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1 + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + same = (trt_mod((self.input, self.input)) - self.model((self.input, self.input))).abs().max() + self.assertTrue(same < 2e-2) + +class TestListInput(unittest.TestCase): + + + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = torch.jit.load(MODULE_DIR + "/list_input_scripted.jit.pt").eval().to("cuda") + + + compile_spec = { + "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1 + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + same = (trt_mod([self.input, self.input]) - self.model([self.input, self.input])).abs().max() + self.assertTrue(same < 2e-2) + +class TestTupleInputOutput(unittest.TestCase): + + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt").eval().to("cuda") + + + compile_spec = { + "input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1 + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + trt_out = trt_mod((self.input, self.input)) + pyt_out = self.model((self.input, self.input)) + results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)] + self.assertTrue(all(results)) + +class TestListInputOutput(unittest.TestCase): + + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt").eval().to("cuda") + + + compile_spec = { + "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1 + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + trt_out = trt_mod((self.input, self.input)) + pyt_out = self.model((self.input, self.input)) + results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)] + self.assertTrue(all(results)) + + +class TestListInputTupleOutput(unittest.TestCase): + + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt").eval().to("cuda") + + + compile_spec = { + "input_signature": ([torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],), + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + "min_block_size": 1 + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + trt_out = trt_mod((self.input, self.input)) + pyt_out = self.model((self.input, self.input)) + results = [(t - p).abs().max() < 2e-2 for (t, p) in zip(trt_out, pyt_out)] + self.assertTrue(all(results)) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/py/model_test_case.py b/tests/py/model_test_case.py index e529f05013..1c772c1faf 100644 --- a/tests/py/model_test_case.py +++ b/tests/py/model_test_case.py @@ -1,7 +1,9 @@ import unittest import torch import torchvision.models as models +import os +REPO_ROOT = 
os.path.abspath(os.getcwd()) + "/../../" class ModelTestCase(unittest.TestCase): diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index 0ea1c76a29..e35531e566 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -1,2 +1 @@ -torchvision==0.13.0+cu113 --f https://download.pytorch.org/whl/torch_stable.html +torchvision diff --git a/tests/util/evaluate_graph.cpp b/tests/util/evaluate_graph.cpp index 7e69b454ef..5a9f10f7b0 100644 --- a/tests/util/evaluate_graph.cpp +++ b/tests/util/evaluate_graph.cpp @@ -28,7 +28,7 @@ std::vector EvaluateGraph(const torch::jit::Block* b, std::v "Test graph contains non evaluatable nodes: " << *n); auto eval = core::conversion::EvaluateNode(ctx, n); if (eval) { - if (eval.value().isTuple()) { + if (eval.value().isTuple() && n->outputs().size() > 1) { auto eval_list = eval.value().toTuple(); for (size_t i = 0; i < eval_list->elements().size(); i++) { auto eval_output = eval_list.get()->elements()[i]; diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp index b0bb920768..1d77550d1d 100644 --- a/tests/util/run_graph_engine.cpp +++ b/tests/util/run_graph_engine.cpp @@ -30,6 +30,7 @@ std::vector toInputsDynamic(std::vector ten, bool d for (auto i : ten) { auto opt = core::util::toVec(i.sizes()); + auto dtype = core::util::ScalarTypeToTRTDataType(i.scalar_type()); if (dynamic_batch) { std::vector min_range(opt); @@ -38,7 +39,7 @@ std::vector toInputsDynamic(std::vector ten, bool d min_range[0] = ceil(opt[0] / 2.0); max_range[0] = 2 * opt[0]; - a.push_back(core::ir::Input(min_range, opt, max_range)); + a.push_back(core::ir::Input(min_range, opt, max_range, dtype)); } else { std::vector min_range(opt); std::vector max_range(opt); @@ -46,7 +47,7 @@ std::vector toInputsDynamic(std::vector ten, bool d min_range[1] = ceil(opt[1] / 2.0); max_range[1] = 2 * opt[1]; - a.push_back(core::ir::Input(min_range, opt, max_range)); + a.push_back(core::ir::Input(min_range, opt, max_range, dtype)); } } diff --git a/tools/linter/utils.py b/tools/linter/utils.py index 1754702f6b..8d4d75cd70 100644 --- a/tools/linter/utils.py +++ b/tools/linter/utils.py @@ -6,7 +6,7 @@ BLACKLISTED_BAZEL_TARGETS = [ "//experiments", "//tools", "//docker", "//third_party", "//bazel-bin", "//bazel-genfiles", "//bazel-out", "//bazel-TRTorch", "//bazel-Torch-TensorRT", "//bazel-torch-tensorrt", "//bazel-workspace", - "//bazel-testlogs", "//py/build", + "//bazel-tensorrt", "bazel-TensorRT", "//bazel-testlogs", "//py/build", "//py/dist", "//py/trtorch.egg-info", "//py/wheelhouse", "//examples", "//docsrc", "//docs" ] @@ -35,4 +35,4 @@ def glob_files(project, file_types): files = [] for t in file_types: files += glob.glob(project + "/**/*" + t, recursive=True) - return files \ No newline at end of file + return files
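
For reference, a minimal usage sketch of the grouped-input "input_signature" API introduced above, distilled from tests/py/api/test_collections.py. It assumes a scripted module whose forward takes a single Tuple[Tensor, Tensor] argument; the module path and shapes are illustrative placeholders, not part of this change. Grouped inputs currently require partial compilation (the collection ops listed earlier are forced to run in Torch), and the TorchScript backend integration (TensorRTCompileSpec) does not yet accept input_signature.

import torch
import torch_tensorrt as torchtrt

# Placeholder: any scripted module whose forward(self, z: Tuple[Tensor, Tensor]) operates on both tensors.
model = torch.jit.load("tuple_input_scripted.jit.pt").eval().to("cuda")
x = torch.randn((1, 3, 224, 224)).to("cuda")

trt_mod = torchtrt.ts.compile(
    model,
    # Outer tuple = forward's argument list; inner tuple = the single Tuple[Tensor, Tensor] argument.
    input_signature=((torchtrt.Input(x.shape), torchtrt.Input(x.shape)),),
    enabled_precisions={torch.float},
    min_block_size=1,  # grouped inputs currently rely on partial compilation
)

trt_out = trt_mod((x, x))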