Skip to content

Commit 679ea21

Browse files
authored
Merge pull request #1240 from mfeliz-cruise/michael.feliz/element_wise_casting
[feat] Add automatic type promotion to element-wise ops
2 parents c39cd42 + 5a49077 commit 679ea21

File tree

3 files changed

+94
-61
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ nvinfer1::ITensor* addUnpadding(
5959
}
6060
}
6161

62+
// Returns the TensorRT data type that two element-wise operands should be cast to.
//
// Both TensorRT types are mapped to their torch scalar types, combined with
// at::promote_types, and the result is mapped back to a TensorRT type.
//
// @param type_a TensorRT data type of the first operand.
// @param type_b TensorRT data type of the second operand.
// @return The promoted type expressed as an nvinfer1::DataType.
//
// NOTE(review): how the util:: conversion helpers behave for a type that has no
// counterpart on the other side is not visible here -- confirm before relying on
// this for types beyond those exercised by the tests.
nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
63+
auto torch_type_a = util::TRTDataTypeToScalarType(type_a);
64+
auto torch_type_b = util::TRTDataTypeToScalarType(type_b);
65+
// Delegate the promotion decision to ATen instead of re-implementing a table here.
auto promo_type = at::promote_types(torch_type_a, torch_type_b);
66+
auto trt_promo_type = util::ScalarTypeToTRTDataType(promo_type);
67+
return trt_promo_type;
68+
}
69+
6270
nvinfer1::ILayer* add_elementwise(
6371
ConversionCtx* ctx,
6472
nvinfer1::ElementWiseOperation op,
@@ -78,6 +86,26 @@ nvinfer1::ILayer* add_elementwise(
7886
std::swap(self, other);
7987
swapSelfOther = true;
8088
}
89+
90+
if (self->getType() != other->getType()) {
91+
LOG_DEBUG(
92+
"Type mismatch for inputs in element-wise operation " << name << ": " << self->getType() << ", "
93+
<< other->getType());
94+
auto promo_type = promote_types(self->getType(), other->getType());
95+
if (self->getType() != promo_type) {
96+
LOG_DEBUG(
97+
"Element-wise op type promotion adding cast from " << self->getType() << " to " << promo_type << " for layer "
98+
<< name);
99+
self = castITensor(ctx, self, promo_type);
100+
}
101+
if (other->getType() != promo_type) {
102+
LOG_DEBUG(
103+
"Element-wise op type promotion adding cast from " << other->getType() << " to " << promo_type
104+
<< " for layer " << name);
105+
other = castITensor(ctx, other, promo_type);
106+
}
107+
}
108+
81109
auto selfDim = util::toVec(self->getDimensions());
82110
auto otherDim = util::toVec(other->getDimensions());
83111
if (selfDim.size() != otherDim.size()) {

core/conversion/converters/impl/element_wise.cpp

100644100755
Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
5555
// Should implement self + alpha * other
5656
auto self = args[0].ITensorOrFreeze(ctx);
5757
auto other = args[1].ITensorOrFreeze(ctx);
58-
auto scalar = args[2].unwrapToScalar().to<float>();
58+
auto scalar = args[2].unwrapToScalar();
5959

60-
if (1 != scalar) {
61-
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
60+
if (1 != scalar.to<float>()) {
61+
auto alphaTensor = scalar_to_tensor(ctx, scalar);
6262
auto scaleLayer = add_elementwise(
6363
ctx,
6464
nvinfer1::ElementWiseOperation::kPROD,
@@ -84,10 +84,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
8484
// Should implement self + alpha * other
8585
auto self = args[0].ITensorOrFreeze(ctx);
8686
auto other = args[1].ITensorOrFreeze(ctx);
87-
auto scalar = args[2].unwrapToScalar().to<float>();
87+
auto scalar = args[2].unwrapToScalar();
8888

89-
if (1 != scalar) {
90-
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
89+
if (1 != scalar.to<float>()) {
90+
auto alphaTensor = scalar_to_tensor(ctx, scalar);
9191
auto scaleLayer = add_elementwise(
9292
ctx,
9393
nvinfer1::ElementWiseOperation::kPROD,
@@ -262,12 +262,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
262262
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
263263
// Should implement other - alpha * self
264264
auto self = args[0].ITensorOrFreeze(ctx);
265-
auto otherScalar = args[1].unwrapToScalar().to<float>();
266-
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
267-
auto scalar = args[2].unwrapToScalar().to<float>();
265+
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
266+
auto scalar = args[2].unwrapToScalar();
268267

269-
if (1 != scalar) {
270-
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
268+
if (1 != scalar.to<float>()) {
269+
auto alphaTensor = scalar_to_tensor(ctx, scalar);
271270
auto scaleLayer = add_elementwise(
272271
ctx,
273272
nvinfer1::ElementWiseOperation::kPROD,
@@ -292,10 +291,10 @@ auto element_wise_registrations TORCHTRT_UNUSED =
292291
// Should implement other - alpha * self
293292
auto self = args[0].ITensorOrFreeze(ctx);
294293
auto other = args[1].ITensorOrFreeze(ctx);
295-
auto scalar = args[2].unwrapToScalar().to<float>();
294+
auto scalar = args[2].unwrapToScalar();
296295

297-
if (1 != scalar) {
298-
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
296+
if (1 != scalar.to<float>()) {
297+
auto alphaTensor = scalar_to_tensor(ctx, scalar);
299298
auto scaleLayer = add_elementwise(
300299
ctx,
301300
nvinfer1::ElementWiseOperation::kPROD,
@@ -418,7 +417,6 @@ auto element_wise_registrations TORCHTRT_UNUSED =
418417
// Should implement self * other
419418
auto self = args[0].ITensorOrFreeze(ctx);
420419
auto other = args[1].ITensorOrFreeze(ctx);
421-
422420
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
423421
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
424422

@@ -433,7 +431,6 @@ auto element_wise_registrations TORCHTRT_UNUSED =
433431
// TODO: Remove with functionalization
434432
auto self = args[0].ITensorOrFreeze(ctx);
435433
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
436-
437434
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
438435
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
439436

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,29 @@ void pointwise_test_helper(
1212
std::vector<int64_t> shape1 = {5},
1313
std::vector<int64_t> shape2 = {5},
1414
bool negative_input = false,
15-
bool int_tensors = false,
16-
bool float_int_tensors = false,
17-
bool int_float_tensors = false) {
15+
at::ScalarType type1 = at::kFloat,
16+
at::ScalarType type2 = at::kFloat) {
1817
auto g = std::make_shared<torch::jit::Graph>();
1918
torch::jit::parseIR(graph_ir, g.get());
2019

2120
// singleInput case is enabled when elementwise operation is performed
2221
// with an input and a constant embedded in graph
2322
std::vector<at::Tensor> torch_inputs;
24-
if (negative_input) {
25-
torch_inputs.push_back(at::randint(-5, 5, shape1, {at::kCUDA}));
26-
} else {
27-
torch_inputs.push_back(at::randint(1, 5, shape1, {at::kCUDA}));
23+
int first_min = negative_input ? -5 : 1;
24+
int first_max = 5;
25+
int second_min = 1;
26+
int second_max = 5;
27+
if (type1 == at::kBool) {
28+
first_min = 0;
29+
first_max = 1;
2830
}
29-
if (!singleInput) {
30-
torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
31+
if (type2 == at::kBool) {
32+
second_min = 0;
33+
second_max = 1;
3134
}
32-
33-
TORCHTRT_CHECK(
34-
!((int_tensors && (float_int_tensors || int_float_tensors)) || (float_int_tensors && int_float_tensors)),
35-
"Invalid test configuration, only one of int_tensors, float_int_tensors, int_float_tensors can be true");
36-
37-
if (int_tensors) {
38-
for (size_t i = 0UL; i < torch_inputs.size(); ++i) {
39-
torch_inputs[i] = torch_inputs[i].to(at::kInt);
40-
}
41-
} else if (float_int_tensors) {
42-
TORCHTRT_CHECK(!singleInput, "float_int_tensors tests require two inputs");
43-
torch_inputs[0] = torch_inputs[0].to(at::kFloat);
44-
torch_inputs[1] = torch_inputs[1].to(at::kInt);
45-
} else if (int_float_tensors) {
46-
TORCHTRT_CHECK(!singleInput, "int_float_tensors tests require two inputs");
47-
torch_inputs[0] = torch_inputs[0].to(at::kInt);
48-
torch_inputs[1] = torch_inputs[1].to(at::kFloat);
35+
torch_inputs.push_back(at::randint(first_min, first_max, shape1, at::TensorOptions(at::kCUDA).dtype(type1)));
36+
if (!singleInput) {
37+
torch_inputs.push_back(at::randint(second_min, second_max, shape2, at::TensorOptions(at::kCUDA).dtype(type2)));
4938
}
5039

5140
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
@@ -78,8 +67,6 @@ TEST(Converters, ATenAddConvertsCorrectly) {
7867
pointwise_test_helper(graph, false, false, {4}, {3, 4});
7968
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
8069
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
81-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
82-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
8370
}
8471

8572
TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -93,8 +80,8 @@ TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
9380
pointwise_test_helper(graph, false, false, {4}, {3, 4});
9481
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
9582
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
96-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
97-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
83+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
84+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
9885
}
9986

10087
TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
@@ -106,6 +93,17 @@ TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
10693
pointwise_test_helper(graph, false);
10794
pointwise_test_helper(graph, false, false, {3, 4}, {4});
10895
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
96+
pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kFloat, at::kInt);
97+
}
98+
99+
// Exercises aten::add_ on two tensor inputs with an integer constant (42) passed
// as the third argument, using integer (at::kInt) tensors for both inputs.
// NOTE(review): the test name says "Implicit" while the graph uses the in-place
// op aten::add_ -- confirm the intended naming.
TEST(Converters, ATenAddImplicitWithIntAlphaConvertsCorrectly) {
100+
const auto graph = R"IR(
101+
graph(%0 : Tensor, %1 : Tensor):
102+
%2 : int = prim::Constant[value=42]()
103+
%3 : Tensor = aten::add_(%0, %1, %2)
104+
return (%3))IR";
105+
// Two cases: equal shapes {2, 2}, then differently shaped inputs {3, 4, 3} and {4, 3}.
// The last two arguments select at::kInt for both generated inputs; the leading
// boolean arguments are defined by pointwise_test_helper -- presumably the first
// is singleInput (false here, since the graph takes two tensors); TODO confirm.
pointwise_test_helper(graph, false, false, {2, 2}, {2, 2}, false, at::kInt, at::kInt);
106+
pointwise_test_helper(graph, false, false, {3, 4, 3}, {4, 3}, false, at::kInt, at::kInt);
109107
}
110108

111109
TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
@@ -129,8 +127,8 @@ TEST(Converters, ATenSubConvertsCorrectly) {
129127
pointwise_test_helper(graph, false, false, {4}, {3, 4});
130128
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
131129
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
132-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
133-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
130+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
131+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
134132
}
135133

136134
TEST(Converters, ATenMulConvertsCorrectly) {
@@ -143,8 +141,8 @@ TEST(Converters, ATenMulConvertsCorrectly) {
143141
pointwise_test_helper(graph, false, false, {4}, {3, 4});
144142
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
145143
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
146-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
147-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
144+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
145+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
148146
}
149147

150148
TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
@@ -162,7 +160,7 @@ TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
162160
%scalar : int = prim::Constant[value=2]()
163161
%1 : Tensor = aten::mul(%0, %scalar)
164162
return (%1))IR";
165-
pointwise_test_helper(graph, true, false, {5}, {5}, false, true);
163+
pointwise_test_helper(graph, true, false, {5}, {5}, false, at::kInt);
166164
}
167165

168166
TEST(Converters, ATenDivConvertsCorrectly) {
@@ -175,8 +173,6 @@ TEST(Converters, ATenDivConvertsCorrectly) {
175173
pointwise_test_helper(graph, false, false, {4}, {3, 4});
176174
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
177175
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
178-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
179-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
180176
}
181177

182178
TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -199,8 +195,8 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
199195
pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
200196
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
201197
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
202-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
203-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
198+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
199+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
204200
}
205201

206202
TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -214,8 +210,8 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
214210
pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
215211
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
216212
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
217-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
218-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
213+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
214+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
219215
}
220216

221217
TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -241,8 +237,8 @@ TEST(Converters, ATenPowTensorConvertsCorrectly) {
241237
pointwise_test_helper(graph, false, false, {4}, {3, 4});
242238
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
243239
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
244-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
245-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
240+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
241+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
246242
}
247243

248244
TEST(Converters, ATenPowScalarConvertsCorrectly) {
@@ -283,8 +279,8 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
283279
pointwise_test_helper(graph, false, false, {4}, {3, 4});
284280
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
285281
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
286-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
287-
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
282+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
283+
pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
288284
}
289285

290286
TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
@@ -329,6 +325,8 @@ TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
329325
pointwise_test_helper(graph, false, false, {3, 4}, {4});
330326
pointwise_test_helper(graph, false, false, {4}, {3, 4});
331327
pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
328+
pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kFloat);
329+
pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kInt);
332330
}
333331

334332
TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
@@ -341,6 +339,16 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
341339
pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
342340
}
343341

342+
// Exercises aten::rsub on a single tensor input where both remaining arguments
// are integer constants: %scalar = 8 as the second argument and %2 = 2 as the third.
TEST(Converters, ATenRsubWithIntScalarConvertsCorrectly) {
343+
const auto graph = R"IR(
344+
graph(%0 : Tensor):
345+
%2 : int = prim::Constant[value=2]()
346+
%scalar : int = prim::Constant[value=8]()
347+
%3 : Tensor = aten::rsub(%0, %scalar, %2)
348+
return (%3))IR";
349+
// One input of shape {4, 3, 3, 3} generated as at::kInt; the second shape is
// passed as {} because the graph has only one tensor input. The leading `true`
// is presumably the helper's singleInput flag -- TODO confirm against its signature.
pointwise_test_helper(graph, true, false, {4, 3, 3, 3}, {}, false, at::kInt);
350+
}
351+
344352
TEST(Converters, ATenClampMinConvertsCorrectly) {
345353
const auto graph = R"IR(
346354
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)