
Commit 10e036b

Add automatic type promotion to element-wise ops
Adds automatic type promotion to element-wise ops, matching the default TorchScript behavior. Debug messages for the type mismatch and the inserted cast are written to the log:

```
DEBUG: [Torch-TensorRT] - Type mismatch for inputs in element-wise operation %3 : Tensor = aten::add(%0, %1, %2): Int32, Float32
DEBUG: [Torch-TensorRT] - Element-wise op type promotion adding cast from Int32 to Float32 for layer %3 : Tensor = aten::add(%0, %1, %2)
```

Fixes # (issue)

Type of change:

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

Checklist:

- [ ] My code follows the style guidelines of this project (you can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that relevant reviewers are notified

Signed-off-by: Michael Feliz <[email protected]>
1 parent 95a6732 commit 10e036b

File tree

3 files changed: +854, -814 lines

core/conversion/converters/converter_util.cpp

Lines changed: 22 additions & 0 deletions
```diff
@@ -59,6 +59,14 @@ nvinfer1::ITensor* addUnpadding(
   }
 }
 
+nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
+  auto torch_type_a = util::TRTDataTypeToScalarType(type_a);
+  auto torch_type_b = util::TRTDataTypeToScalarType(type_b);
+  auto promo_type = at::promote_types(torch_type_a, torch_type_b);
+  auto trt_promo_type = util::ScalarTypeToTRTDataType(promo_type);
+  return trt_promo_type;
+}
+
 nvinfer1::ILayer* add_elementwise(
     ConversionCtx* ctx,
     nvinfer1::ElementWiseOperation op,
@@ -78,6 +86,20 @@ nvinfer1::ILayer* add_elementwise(
     std::swap(self, other);
     swapSelfOther = true;
   }
+
+  if (self->getType() != other->getType()) {
+    LOG_DEBUG("Type mismatch for inputs in element-wise operation " << name << ": " << self->getType() << ", " << other->getType());
+    auto promo_type = promote_types(self->getType(), other->getType());
+    if (self->getType() != promo_type) {
+      LOG_DEBUG("Element-wise op type promotion adding cast from " << self->getType() << " to " << promo_type << " for layer " << name);
+      self = castITensor(ctx, self, promo_type);
+    }
+    if (other->getType() != promo_type) {
+      LOG_DEBUG("Element-wise op type promotion adding cast from " << other->getType() << " to " << promo_type << " for layer " << name);
+      other = castITensor(ctx, other, promo_type);
+    }
+  }
+
   auto selfDim = util::toVec(self->getDimensions());
   auto otherDim = util::toVec(other->getDimensions());
   if (selfDim.size() != otherDim.size()) {
```

0 commit comments
