@@ -12,26 +12,31 @@ void pointwise_test_helper(
12
12
std::vector<int64_t > shape1 = {5 },
13
13
std::vector<int64_t > shape2 = {5 },
14
14
bool negative_input = false ,
15
- bool int_tensors = false ) {
15
+ at::ScalarType type1 = at::kFloat ,
16
+ at::ScalarType type2 = at::kFloat ) {
16
17
auto g = std::make_shared<torch::jit::Graph>();
17
18
torch::jit::parseIR (graph_ir, g.get ());
18
19
19
20
// singleInput case is enabled when elementwise operation is performed
20
21
// with an input and a constant embedded in graph
21
22
std::vector<at::Tensor> torch_inputs;
22
- if (negative_input) {
23
- torch_inputs.push_back (at::randint (-5 , 5 , shape1, {at::kCUDA }));
24
- } else {
25
- torch_inputs.push_back (at::randint (1 , 5 , shape1, {at::kCUDA }));
23
+ int first_min = negative_input ? -5 : 1 ;
24
+ int first_max = 5 ;
25
+ int second_min = 1 ;
26
+ int second_max = 5 ;
27
+ if (type1 == at::kBool ){
28
+ first_min = 0 ;
29
+ first_max = 1 ;
26
30
}
27
- if (!singleInput) {
28
- torch_inputs.push_back (at::randint (1 , 5 , shape2, {at::kCUDA }));
31
+ if (type2 == at::kBool ){
32
+ second_min = 0 ;
33
+ second_max = 1 ;
29
34
}
30
- if (int_tensors){
31
- for (size_t i = 0UL ; i < torch_inputs.size (); ++i){
32
- torch_inputs[i] = torch_inputs[i].to (at::kInt );
33
- }
35
+ torch_inputs.push_back (at::randint (first_min, first_max, shape1, at::TensorOptions (at::kCUDA ).dtype (type1)));
36
+ if (!singleInput) {
37
+ torch_inputs.push_back (at::randint (second_min, second_max, shape2, at::TensorOptions (at::kCUDA ).dtype (type2)));
34
38
}
39
+
35
40
auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
36
41
auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, torch_inputs);
37
42
@@ -62,6 +67,13 @@ TEST(Converters, ATenAddConvertsCorrectly) {
62
67
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
63
68
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
64
69
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
70
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kInt , at::kInt );
71
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kFloat , at::kInt );
72
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kInt , at::kFloat );
73
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kBool , at::kInt );
74
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kBool , at::kFloat );
75
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kInt , at::kBool );
76
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kFloat , at::kBool );
65
77
}
66
78
67
79
TEST (Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -86,6 +98,17 @@ TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) {
86
98
pointwise_test_helper (graph, false );
87
99
pointwise_test_helper (graph, false , false , {3 , 4 }, {4 });
88
100
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
101
+ pointwise_test_helper (graph, false , false , {3 , 4 , 3 }, {4 , 3 }, false , at::kFloat , at::kInt );
102
+ }
103
+
104
+ TEST (Converters, ATenAddImplicitWithIntAlphaConvertsCorrectly) {
105
+ const auto graph = R"IR(
106
+ graph(%0 : Tensor, %1 : Tensor):
107
+ %2 : int = prim::Constant[value=42]()
108
+ %3 : Tensor = aten::add_(%0, %1, %2)
109
+ return (%3))IR" ;
110
+ pointwise_test_helper (graph, false , false , {2 , 2 }, {2 , 2 }, false , at::kInt , at::kInt );
111
+ pointwise_test_helper (graph, false , false , {3 , 4 , 3 }, {4 , 3 }, false , at::kInt , at::kInt );
89
112
}
90
113
91
114
TEST (Converters, ATenAddWithScalarConvertsCorrectly) {
@@ -138,7 +161,7 @@ TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
138
161
%scalar : int = prim::Constant[value=2]()
139
162
%1 : Tensor = aten::mul(%0, %scalar)
140
163
return (%1))IR" ;
141
- pointwise_test_helper (graph, true , false , {5 }, {5 }, false , true );
164
+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , at:: kInt );
142
165
}
143
166
144
167
TEST (Converters, ATenDivConvertsCorrectly) {
@@ -151,6 +174,8 @@ TEST(Converters, ATenDivConvertsCorrectly) {
151
174
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
152
175
pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 });
153
176
pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 });
177
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kFloat , at::kInt );
178
+ pointwise_test_helper (graph, false , false , {4 , 3 }, {3 , 4 , 3 }, false , at::kInt , at::kFloat );
154
179
}
155
180
156
181
TEST (Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -295,6 +320,8 @@ TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
295
320
pointwise_test_helper (graph, false , false , {3 , 4 }, {4 });
296
321
pointwise_test_helper (graph, false , false , {4 }, {3 , 4 });
297
322
pointwise_test_helper (graph, false , true , {4 , 3 , 3 , 3 }, {4 , 3 , 3 , 3 });
323
+ pointwise_test_helper (graph, false , false , {4 , 3 , 3 , 3 }, {4 , 3 , 3 , 3 }, false , at::kInt , at::kFloat );
324
+ pointwise_test_helper (graph, false , false , {4 , 3 , 3 , 3 }, {4 , 3 , 3 , 3 }, false , at::kInt , at::kInt );
298
325
}
299
326
300
327
TEST (Converters, ATenRsubWithScalarConvertsCorrectly) {
@@ -307,6 +334,46 @@ TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
307
334
pointwise_test_helper (graph, true , false , {4 , 3 , 3 , 3 });
308
335
}
309
336
337
+ TEST (Converters, ATenRsubWithIntScalarConvertsCorrectly) {
338
+ const auto graph = R"IR(
339
+ graph(%0 : Tensor):
340
+ %2 : int = prim::Constant[value=2]()
341
+ %scalar : int = prim::Constant[value=8]()
342
+ %3 : Tensor = aten::rsub(%0, %scalar, %2)
343
+ return (%3))IR" ;
344
+ pointwise_test_helper (graph, true , false , {4 , 3 , 3 , 3 }, {}, false , at::kInt );
345
+ }
346
+
347
+ TEST (Converters, ATenClipMinConvertsCorrectly) {
348
+ const auto graph = R"IR(
349
+ graph(%x.1 : Tensor):
350
+ %2 : float = prim::Constant[value=1.5]()
351
+ %3 : None = prim::Constant()
352
+ %4 : Tensor = aten::clip(%x.1, %2, %3)
353
+ return (%4))IR" ;
354
+ pointwise_test_helper (graph, true );
355
+ }
356
+
357
+ TEST (Converters, ATenClipMaxConvertsCorrectly) {
358
+ const auto graph = R"IR(
359
+ graph(%x.1 : Tensor):
360
+ %2 : float = prim::Constant[value=3.5]()
361
+ %3 : None = prim::Constant()
362
+ %4 : Tensor = aten::clip(%x.1, %3, %2)
363
+ return (%4))IR" ;
364
+ pointwise_test_helper (graph, true );
365
+ }
366
+
367
+ TEST (Converters, ATenClipMinMaxConvertsCorrectly) {
368
+ const auto graph = R"IR(
369
+ graph(%x.1 : Tensor):
370
+ %2 : float = prim::Constant[value=3.5]()
371
+ %3 : float = prim::Constant[value=1.5]()
372
+ %4 : Tensor = aten::clip(%x.1, %3, %2)
373
+ return (%4))IR" ;
374
+ pointwise_test_helper (graph, true );
375
+ }
376
+
310
377
TEST (Converters, ATenClampMinConvertsCorrectly) {
311
378
const auto graph = R"IR(
312
379
graph(%x.1 : Tensor):
@@ -337,6 +404,36 @@ TEST(Converters, ATenClampMinMaxConvertsCorrectly) {
337
404
pointwise_test_helper (graph, true );
338
405
}
339
406
407
+ TEST (Converters, ATenClampIntMinConvertsCorrectly) {
408
+ const auto graph = R"IR(
409
+ graph(%x.1 : Tensor):
410
+ %2 : int = prim::Constant[value=1]()
411
+ %3 : None = prim::Constant()
412
+ %4 : Tensor = aten::clamp(%x.1, %2, %3)
413
+ return (%4))IR" ;
414
+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , at::kInt );
415
+ }
416
+
417
+ TEST (Converters, ATenClampIntMaxConvertsCorrectly) {
418
+ const auto graph = R"IR(
419
+ graph(%x.1 : Tensor):
420
+ %2 : int = prim::Constant[value=3]()
421
+ %3 : None = prim::Constant()
422
+ %4 : Tensor = aten::clamp(%x.1, %3, %2)
423
+ return (%4))IR" ;
424
+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , at::kInt );
425
+ }
426
+
427
+ TEST (Converters, ATenClampIntMinMaxConvertsCorrectly) {
428
+ const auto graph = R"IR(
429
+ graph(%x.1 : Tensor):
430
+ %2 : int = prim::Constant[value=3]()
431
+ %3 : int = prim::Constant[value=1]()
432
+ %4 : Tensor = aten::clamp(%x.1, %3, %2)
433
+ return (%4))IR" ;
434
+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , at::kInt );
435
+ }
436
+
340
437
TEST (Converters, ATenClampMinimumConvertsCorrectly) {
341
438
const auto graph = R"IR(
342
439
graph(%x.1 : Tensor):
@@ -487,4 +584,4 @@ TEST(Converters, ATenRemainderWithScalarConvertsCorrectly) {
487
584
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
488
585
489
586
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
490
- }
587
+ }
0 commit comments