Commit 202a7c7
Add automatic type promotion to element-wise ops
Adds automatic type promotion to element-wise ops to match the default TorchScript behavior. Debug messages are emitted for the type mismatch and the resulting cast, and written to the log:

```
DEBUG: [Torch-TensorRT] - Type mismatch for inputs in element-wise operation %3 : Tensor = aten::add(%0, %1, %2): Int32, Float32
DEBUG: [Torch-TensorRT] - Element-wise op type promotion adding cast from Int32 to Float32 for layer %3 : Tensor = aten::add(%0, %1, %2)
```

Fixes # (issue)

Please delete options that are not relevant and/or add your own:

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

- [ ] My code follows the style guidelines of this project (you can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that the relevant reviewers are notified

Signed-off-by: Michael Feliz <[email protected]>
1 parent d0e471f commit 202a7c7
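For reference, a minimal standalone ATen sketch (not part of this commit) of the default eager/TorchScript promotion behavior the converter now matches: adding an Int32 tensor to a Float32 tensor yields a Float32 result.

```cpp
// Illustration only: ATen's default type promotion for a mixed-dtype
// element-wise add, which the new TensorRT-side cast logic mirrors.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto a = at::ones({2, 2}, at::kInt);        // Int32 operand
  auto b = at::ones({2, 2}, at::kFloat);      // Float32 operand
  auto c = a + b;                             // element-wise add promotes the result
  std::cout << c.scalar_type() << std::endl;  // prints "Float"
  return 0;
}
```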

3 files changed: +145 -27 lines changed

core/conversion/converters/converter_util.cpp
22 additions & 0 deletions

```diff
@@ -59,6 +59,14 @@ nvinfer1::ITensor* addUnpadding(
   }
 }

+nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b){
+  auto torch_type_a = util::TRTDataTypeToScalarType(type_a);
+  auto torch_type_b = util::TRTDataTypeToScalarType(type_b);
+  auto promo_type = at::promote_types(torch_type_a, torch_type_b);
+  auto trt_promo_type = util::ScalarTypeToTRTDataType(promo_type);
+  return trt_promo_type;
+}
+
 nvinfer1::ILayer* add_elementwise(
     ConversionCtx* ctx,
     nvinfer1::ElementWiseOperation op,
@@ -71,6 +79,20 @@ nvinfer1::ILayer* add_elementwise(
     std::swap(self, other);
     swapSelfOther = true;
   }
+
+  if(self->getType() != other->getType()){
+    LOG_DEBUG("Type mismatch for inputs in element-wise operation " << name << ": " << self->getType() << ", " << other->getType());
+    auto promo_type = promote_types(self->getType(), other->getType());
+    if(self->getType() != promo_type){
+      LOG_DEBUG("Element-wise op type promotion adding cast from " << self->getType() << " to " << promo_type << " for layer " << name);
+      self = castITensor(ctx, self, promo_type);
+    }
+    if(other->getType() != promo_type){
+      LOG_DEBUG("Element-wise op type promotion adding cast from " << other->getType() << " to " << promo_type << " for layer " << name);
+      other = castITensor(ctx, other, promo_type);
+    }
+  }
+
   auto selfDim = util::toVec(self->getDimensions());
   auto otherDim = util::toVec(other->getDimensions());
   if (selfDim.size() != otherDim.size()) {
```
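For context on the `castITensor` calls above: a cast in the TensorRT network-definition API is commonly realized by inserting an identity layer and forcing its output type. The sketch below illustrates that general pattern under the assumption that the helper works along these lines; it is not the actual implementation from converter_util.cpp.

```cpp
// Hedged sketch of an identity-layer cast in TensorRT; illustrative only.
#include <NvInfer.h>

nvinfer1::ITensor* cast_tensor_sketch(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor* in,
    nvinfer1::DataType promoted) {
  auto* id_layer = net->addIdentity(*in);  // identity layer acts as the cast point
  id_layer->setOutputType(0, promoted);    // request the promoted output dtype
  return id_layer->getOutput(0);
}
```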

core/conversion/converters/impl/element_wise.cpp
13 additions & 14 deletions

```diff
@@ -54,10 +54,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                // Should implement self + alpha * other
                auto self = args[0].ITensorOrFreeze(ctx);
                auto other = args[1].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
+              auto scalar = args[2].unwrapToScalar();

-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = impl::scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -83,10 +83,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                // Should implement self + alpha * other
                auto self = args[0].ITensorOrFreeze(ctx);
                auto other = args[1].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
+              auto scalar = args[2].unwrapToScalar();

-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = impl::scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -257,12 +257,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // Should implement other - alpha * self
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
-              auto scalar = args[2].unwrapToScalar().to<float>();
+              auto other = impl::scalar_to_tensor(ctx, args[1].unwrapToScalar());
+              auto scalar = args[2].unwrapToScalar();

-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = impl::scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
@@ -287,10 +286,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
               // Should implement other - alpha * self
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = args[1].ITensorOrFreeze(ctx);
-              auto scalar = args[2].unwrapToScalar().to<float>();
+              auto scalar = args[2].unwrapToScalar();

-              if (1 != scalar) {
-                auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+              if (1 != scalar.to<float>()) {
+                auto alphaTensor = impl::scalar_to_tensor(ctx, scalar);
                 auto scaleLayer = add_elementwise(
                     ctx,
                     nvinfer1::ElementWiseOperation::kPROD,
```
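The reason for dropping the eager `.to<float>()` unwrap above: building the alpha constant with `torch::tensor({scalar})` from a float always yields a Float32 tensor, which would force promotion of integer operands. A hypothetical dtype-preserving helper in the spirit of `impl::scalar_to_tensor` (whose real implementation is not shown in this diff) might look like the following sketch.

```cpp
// Hypothetical sketch only: build a one-element constant that keeps the Scalar's
// dtype, leaving any needed cast to the element-wise promotion logic.
#include <torch/torch.h>

at::Tensor scalar_to_const_tensor_sketch(const at::Scalar& s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    return torch::tensor({s.to<int>()}, torch::kInt32);  // keep integer alphas as Int32
  }
  return torch::tensor({s.to<float>()}, torch::kFloat32);
}
```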

tests/core/conversion/converters/test_element_wise.cpp
110 additions & 13 deletions

```diff
@@ -12,26 +12,31 @@ void pointwise_test_helper(
     std::vector<int64_t> shape1 = {5},
     std::vector<int64_t> shape2 = {5},
     bool negative_input = false,
-    bool int_tensors = false) {
+    at::ScalarType type1 = at::kFloat,
+    at::ScalarType type2 = at::kFloat) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());

   // singleInput case is enabled when elementwise operation is performed
   // with an input and a constant embedded in graph
   std::vector<at::Tensor> torch_inputs;
-  if (negative_input) {
-    torch_inputs.push_back(at::randint(-5, 5, shape1, {at::kCUDA}));
-  } else {
-    torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
+  int first_min = negative_input ? -5 : 1;
+  int first_max = 5;
+  int second_min = 1;
+  int second_max = 5;
+  if(type1 == at::kBool){
+    first_min = 0;
+    first_max = 1;
   }
-  if (!singleInput) {
-    torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
+  if(type2 == at::kBool){
+    second_min = 0;
+    second_max = 1;
   }
-  if(int_tensors){
-    for(size_t i = 0UL; i < torch_inputs.size(); ++i){
-      torch_inputs[i] = torch_inputs[i].to(at::kInt);
-    }
+  torch_inputs.push_back(at::randint(first_min, first_max, shape1, at::TensorOptions(at::kCUDA).dtype(type1)));
+  if (!singleInput) {
+    torch_inputs.push_back(at::randint(second_min, second_max, shape2, at::TensorOptions(at::kCUDA).dtype(type2)));
   }
+
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, torch_inputs);

@@ -62,6 +67,13 @@ TEST(Converters, ATenAddConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kInt, at::kInt);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kBool, at::kInt);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kBool, at::kFloat);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kInt, at::kBool);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kFloat, at::kBool);
 }

 TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -86,6 +98,17 @@ TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) {
   pointwise_test_helper(graph, false);
   pointwise_test_helper(graph, false, false, {3, 4}, {4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kFloat, at::kInt);
+}
+
+TEST(Converters, ATenAddImplicitWithIntAlphaConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %2 : int = prim::Constant[value=42]()
+        %3 : Tensor = aten::add_(%0, %1, %2)
+        return (%3))IR";
+  pointwise_test_helper(graph, false, false, {2, 2}, {2, 2}, false, at::kInt, at::kInt);
+  pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
@@ -138,7 +161,7 @@ TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
       %scalar : int = prim::Constant[value=2]()
       %1 : Tensor = aten::mul(%0, %scalar)
       return (%1))IR";
-  pointwise_test_helper(graph, true, false, {5}, {5}, false, true);
+  pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
 }

 TEST(Converters, ATenDivConvertsCorrectly) {
@@ -151,6 +174,8 @@ TEST(Converters, ATenDivConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kFloat, at::kInt);
+  pointwise_test_helper(graph, false, false, {4, 3}, {3, 4, 3}, false, at::kInt, at::kFloat);
 }

 TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -295,6 +320,8 @@ TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {3, 4}, {4});
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
+  pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kInt);
 }

 TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
@@ -307,6 +334,46 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
   pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
 }

+TEST(Converters, ATenRsubWithIntScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %2 : int = prim::Constant[value=2]()
+        %scalar : int = prim::Constant[value=8]()
+        %3 : Tensor = aten::rsub(%0, %scalar, %2)
+        return (%3))IR";
+  pointwise_test_helper(graph, true, false, {4, 3, 3, 3}, {}, false, at::kInt);
+}
+
+TEST(Converters, ATenClipMinConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : float = prim::Constant[value=1.5]()
+        %3 : None = prim::Constant()
+        %4 : Tensor = aten::clip(%x.1, %2, %3)
+        return (%4))IR";
+  pointwise_test_helper(graph, true);
+}
+
+TEST(Converters, ATenClipMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : float = prim::Constant[value=3.5]()
+        %3 : None = prim::Constant()
+        %4 : Tensor = aten::clip(%x.1, %3, %2)
+        return (%4))IR";
+  pointwise_test_helper(graph, true);
+}
+
+TEST(Converters, ATenClipMinMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : float = prim::Constant[value=3.5]()
+        %3 : float = prim::Constant[value=1.5]()
+        %4 : Tensor = aten::clip(%x.1, %3, %2)
+        return (%4))IR";
+  pointwise_test_helper(graph, true);
+}
+
 TEST(Converters, ATenClampMinConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
@@ -337,6 +404,36 @@ TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
   pointwise_test_helper(graph, true);
 }

+TEST(Converters, ATenClampIntMinConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=1]()
+        %3 : None = prim::Constant()
+        %4 : Tensor = aten::clamp(%x.1, %2, %3)
+        return (%4))IR";
+  pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
+}
+
+TEST(Converters, ATenClampIntMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : None = prim::Constant()
+        %4 : Tensor = aten::clamp(%x.1, %3, %2)
+        return (%4))IR";
+  pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
+}
+
+TEST(Converters, ATenClampIntMinMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %4 : Tensor = aten::clamp(%x.1, %3, %2)
+        return (%4))IR";
+  pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
+}
+
 TEST(Converters, ATenClampMinimumConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
@@ -487,4 +584,4 @@ TEST(Converters, ATenRemainderWithScalarConvertsCorrectly) {
   auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
-}
+}
```
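As a quick cross-check of the mixed-dtype pairs exercised by the new tests, ATen's promotion table (which the converter now follows) resolves them as shown in this small, purely illustrative sketch.

```cpp
// Illustration only: expected promoted dtypes for the test input combinations above.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  std::cout << at::promote_types(at::kInt, at::kFloat) << std::endl;   // Float
  std::cout << at::promote_types(at::kBool, at::kInt) << std::endl;    // Int
  std::cout << at::promote_types(at::kBool, at::kFloat) << std::endl;  // Float
  return 0;
}
```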

0 commit comments