@@ -12,40 +12,29 @@ void pointwise_test_helper(
12
12
std::vector<int64_t > shape1 = {5 },
13
13
std::vector<int64_t > shape2 = {5 },
14
14
bool negative_input = false ,
15
- bool int_tensors = false ,
16
- bool float_int_tensors = false ,
17
- bool int_float_tensors = false ) {
15
+ at::ScalarType type1 = at::kFloat ,
16
+ at::ScalarType type2 = at::kFloat ) {
18
17
auto g = std::make_shared<torch::jit::Graph>();
19
18
torch::jit::parseIR (graph_ir, g.get ());
20
19
21
20
// singleInput case is enabled when elementwise operation is performed
22
21
// with an input and a constant embedded in graph
23
22
std::vector<at::Tensor> torch_inputs;
24
- if (negative_input) {
25
- torch_inputs.push_back (at::randint (-5 , 5 , shape1, {at::kCUDA }));
26
- } else {
27
- torch_inputs.push_back (at::randint (1 , 5 , shape1, {at::kCUDA }));
23
+ int first_min = negative_input ? -5 : 1 ;
24
+ int first_max = 5 ;
25
+ int second_min = 1 ;
26
+ int second_max = 5 ;
27
+ if (type1 == at::kBool ) {
28
+ first_min = 0 ;
29
+ first_max = 1 ;
28
30
}
29
- if (!singleInput) {
30
- torch_inputs.push_back (at::randint (1 , 5 , shape2, {at::kCUDA }));
31
+ if (type2 == at::kBool ) {
32
+ second_min = 0 ;
33
+ second_max = 1 ;
31
34
}
32
-
33
- TORCHTRT_CHECK (
34
- !((int_tensors && (float_int_tensors || int_float_tensors)) || (float_int_tensors && int_float_tensors)),
35
- " Invalid test configuration, only one of int_tensors, float_int_tensors, int_float_tensors can be true" );
36
-
37
- if (int_tensors) {
38
- for (size_t i = 0UL ; i < torch_inputs.size (); ++i) {
39
- torch_inputs[i] = torch_inputs[i].to (at::kInt );
40
- }
41
- } else if (float_int_tensors) {
42
- TORCHTRT_CHECK (!singleInput, " float_int_tensors tests require two inputs" );
43
- torch_inputs[0 ] = torch_inputs[0 ].to (at::kFloat );
44
- torch_inputs[1 ] = torch_inputs[1 ].to (at::kInt );
45
- } else if (int_float_tensors) {
46
- TORCHTRT_CHECK (!singleInput, " int_float_tensors tests require two inputs" );
47
- torch_inputs[0 ] = torch_inputs[0 ].to (at::kInt );
48
- torch_inputs[1 ] = torch_inputs[1 ].to (at::kFloat );
35
+ torch_inputs.push_back (at::randint (first_min, first_max, shape1, at::TensorOptions (at::kCUDA ).dtype (type1)));
36
+ if (!singleInput) {
37
+ torch_inputs.push_back (at::randint (second_min, second_max, shape2, at::TensorOptions (at::kCUDA ).dtype (type2)));
49
38
}
50
39
51
40
auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
@@ -78,8 +67,6 @@ TEST(Converters, ATenAddConvertsCorrectly) {
78
67
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
79
68
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
80
69
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
81
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
82
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
83
70
}
84
71
85
72
TEST (Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -93,8 +80,8 @@ TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
93
80
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
94
81
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
95
82
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
96
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
97
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
83
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
84
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
98
85
}
99
86
100
87
TEST (Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
@@ -106,6 +93,17 @@ TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
106
93
pointwise_test_helper (graph, false );
107
94
pointwise_test_helper (graph, false , false , {3 , 4 }, {4 });
108
95
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
96
+ pointwise_test_helper (graph, false , false , {3 , 4 , 3 }, {4 , 3 }, false , at::kFloat , at::kInt );
97
+ }
98
+
99
+ TEST (Converters, ATenAddImplicitWithIntAlphaConvertsCorrectly) {
100
+ const auto graph = R"IR(
101
+ graph(%0 : Tensor, %1 : Tensor):
102
+ %2 : int = prim::Constant[value=42]()
103
+ %3 : Tensor = aten::add_(%0, %1, %2)
104
+ return (%3))IR" ;
105
+ pointwise_test_helper (graph, false , false , {2 , 2 }, {2 , 2 }, false , at::kInt , at::kInt );
106
+ pointwise_test_helper (graph, false , false , {3 , 4 , 3 }, {4 , 3 }, false , at::kInt , at::kInt );
109
107
}
110
108
111
109
TEST (Converters, ATenAddWithScalarConvertsCorrectly) {
@@ -129,8 +127,8 @@ TEST(Converters, ATenSubConvertsCorrectly) {
129
127
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
130
128
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
131
129
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
132
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
133
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
130
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
131
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
134
132
}
135
133
136
134
TEST (Converters, ATenMulConvertsCorrectly) {
@@ -143,8 +141,8 @@ TEST(Converters, ATenMulConvertsCorrectly) {
143
141
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
144
142
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
145
143
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
146
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
147
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
144
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
145
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
148
146
}
149
147
150
148
TEST (Converters, ATenMulWithScalarConvertsCorrectly) {
@@ -162,7 +160,7 @@ TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
162
160
%scalar : int = prim::Constant[value=2]()
163
161
%1 : Tensor = aten::mul(%0, %scalar)
164
162
return (%1))IR" ;
165
- pointwise_test_helper (graph, true , false , {5 }, {5 }, false , true );
163
+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , at:: kInt );
166
164
}
167
165
168
166
TEST (Converters, ATenDivConvertsCorrectly) {
@@ -175,8 +173,6 @@ TEST(Converters, ATenDivConvertsCorrectly) {
175
173
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
176
174
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
177
175
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
178
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
179
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
180
176
}
181
177
182
178
TEST (Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -199,8 +195,8 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
199
195
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
200
196
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
201
197
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
202
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
203
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
198
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
199
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
204
200
}
205
201
206
202
TEST (Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -214,8 +210,8 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
214
210
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
215
211
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
216
212
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
217
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
218
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
213
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
214
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
219
215
}
220
216
221
217
TEST (Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -241,8 +237,8 @@ TEST(Converters, ATenPowTensorConvertsCorrectly) {
241
237
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
242
238
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
243
239
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
244
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
245
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
240
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
241
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
246
242
}
247
243
248
244
TEST (Converters, ATenPowScalarConvertsCorrectly) {
@@ -283,8 +279,8 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
283
279
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
284
280
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
285
281
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
286
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , true );
287
- pointwise_test_helper (graph, false , true , {5 }, {5 }, false , false , false , true );
282
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kFloat , at:: kInt );
283
+ pointwise_test_helper (graph, false , true , {5 }, {5 }, false , at:: kInt , at:: kFloat );
288
284
}
289
285
290
286
TEST (Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
@@ -329,6 +325,8 @@ TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
329
325
pointwise_test_helper (graph, false , false , {3 , 4 }, {4 });
330
326
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
331
327
pointwise_test_helper (graph, false , true , {4 , 3 , 3 , 3 }, {4 , 3 , 3 , 3 });
328
+ pointwise_test_helper (graph, false , false , {4 , 3 , 3 , 3 }, {4 , 3 , 3 , 3 }, false , at::kInt , at::kFloat );
329
+ pointwise_test_helper (graph, false , false , {4 , 3 , 3 , 3 }, {4 , 3 , 3 , 3 }, false , at::kInt , at::kInt );
332
330
}
333
331
334
332
TEST (Converters, ATenRsubWithScalarConvertsCorrectly) {
@@ -341,6 +339,16 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
341
339
pointwise_test_helper (graph, true , false , {4 , 3 , 3 , 3 });
342
340
}
343
341
342
+ TEST (Converters, ATenRsubWithIntScalarConvertsCorrectly) {
343
+ const auto graph = R"IR(
344
+ graph(%0 : Tensor):
345
+ %2 : int = prim::Constant[value=2]()
346
+ %scalar : int = prim::Constant[value=8]()
347
+ %3 : Tensor = aten::rsub(%0, %scalar, %2)
348
+ return (%3))IR" ;
349
+ pointwise_test_helper (graph, true , false , {4 , 3 , 3 , 3 }, {}, false , at::kInt );
350
+ }
351
+
344
352
TEST (Converters, ATenClampMinConvertsCorrectly) {
345
353
const auto graph = R"IR(
346
354
graph(%x.1 : Tensor):
0 commit comments