From 01c89d155fe9ab0bb9cca90700d3494d31936b2d Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Mon, 28 Feb 2022 17:58:16 -0800 Subject: [PATCH 1/2] fix(//core): Take user setting in the case we can't determine the inferred type. fixes: #814 Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/compiler.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/core/compiler.cpp b/core/compiler.cpp index f5213b6a4d..1e11cc262a 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -329,6 +329,7 @@ void MapInputsAndDetermineDTypes( } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) { if (!est_type_opt) { LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings"); + first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } else { if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) { std::stringstream ss; From 5d8fb5a5775f2bb4219239500e21c7195d03a5ac Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Mon, 28 Feb 2022 18:29:48 -0800 Subject: [PATCH 2/2] refactor(//core): Unify the dtype work for the partition section Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/compiler.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/compiler.cpp b/core/compiler.cpp index 1e11cc262a..b5ca1cddc4 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -329,7 +329,6 @@ void MapInputsAndDetermineDTypes( } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) { if (!est_type_opt) { LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings"); - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } else { if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) { std::stringstream ss; @@ -345,9 +344,10 @@ void MapInputsAndDetermineDTypes( ss << "- Disable partial compilation by setting require_full_compilation to True"; auto warn_str = ss.str(); LOG_WARNING(warn_str); - // Overwrite type map with user settings - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } + // Overwrite type map with user settings + // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes + first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } } else { // The user defined the type so no changes are necessary