diff --git a/core/compiler.cpp b/core/compiler.cpp index f5213b6a4d..b5ca1cddc4 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -344,9 +344,10 @@ void MapInputsAndDetermineDTypes( ss << "- Disable partial compilation by setting require_full_compilation to True"; auto warn_str = ss.str(); LOG_WARNING(warn_str); - // Overwrite type map with user settings - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } + // Overwrite type map with user settings + // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes + first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; } } else { // The user defined the type so no changes are necessary