@@ -9,14 +9,19 @@ void pointwise_test_helper(
9
9
bool singleInput,
10
10
bool dynamicInput = false ,
11
11
std::vector<int64_t > shape1 = {5 },
12
- std::vector<int64_t > shape2 = {5 }) {
12
+ std::vector<int64_t > shape2 = {5 },
13
+ bool negative_input = false ) {
13
14
auto g = std::make_shared<torch::jit::Graph>();
14
15
torch::jit::parseIR (graph_ir, g.get ());
15
16
16
17
// singleInput case is enabled when elementwise operation is performed
17
18
// with an input and a constant embedded in graph
18
19
std::vector<at::Tensor> torch_inputs;
19
- torch_inputs.push_back (at::randint (1 , 5 , shape1, {at::kCUDA }));
20
+ if (negative_input) {
21
+ torch_inputs.push_back (at::randint (-5 , 5 , shape1, {at::kCUDA }));
22
+ } else {
23
+ torch_inputs.push_back (at::randint (1 , 5 , shape1, {at::kCUDA }));
24
+ }
20
25
if (!singleInput) {
21
26
torch_inputs.push_back (at::randint (1 , 5 , shape2, {at::kCUDA }));
22
27
}
@@ -141,6 +146,45 @@ TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
141
146
pointwise_test_helper (graph, true );
142
147
}
143
148
149
+ TEST (Converters, ATenDivRoundingFloorConvertsCorrectly) {
150
+ const auto graph = R"IR(
151
+ graph(%0 : Tensor, %1 : Tensor):
152
+ %3 : str = prim::Constant[value="floor"]()
153
+ %2 : Tensor = aten::div(%0, %1, %3)
154
+ return (%2))IR" ;
155
+ pointwise_test_helper (graph, false , false , {5 }, {5 }, true );
156
+ pointwise_test_helper (graph, false , false , {3 , 4 }, {4 }, true );
157
+ pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
158
+ pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
159
+ pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
160
+ }
161
+
162
+ TEST (Converters, ATenDivRoundingTruncConvertsCorrectly) {
163
+ const auto graph = R"IR(
164
+ graph(%0 : Tensor, %1 : Tensor):
165
+ %3 : str = prim::Constant[value="trunc"]()
166
+ %2 : Tensor = aten::div(%0, %1, %3)
167
+ return (%2))IR" ;
168
+ pointwise_test_helper (graph, false , false , {5 }, {5 }, true );
169
+ pointwise_test_helper (graph, false , false , {3 , 4 }, {4 }, true );
170
+ pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
171
+ pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
172
+ pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
173
+ }
174
+
175
+ TEST (Converters, ATenDivRoundingNoneConvertsCorrectly) {
176
+ const auto graph = R"IR(
177
+ graph(%0 : Tensor, %1 : Tensor):
178
+ %3 : None = prim::Constant()
179
+ %2 : Tensor = aten::div(%0, %1, %3)
180
+ return (%2))IR" ;
181
+ pointwise_test_helper (graph, false , false , {5 }, {5 }, true );
182
+ pointwise_test_helper (graph, false , false , {3 , 4 }, {4 }, true );
183
+ pointwise_test_helper (graph, false , false , {4 }, {3 , 4 }, true );
184
+ pointwise_test_helper (graph, false , true , {3 , 4 , 3 }, {4 , 3 }, true );
185
+ pointwise_test_helper (graph, false , true , {4 , 3 }, {3 , 4 , 3 }, true );
186
+ }
187
+
144
188
TEST (Converters, ATenPowTensorConvertsCorrectly) {
145
189
const auto graph = R"IR(
146
190
graph(%x.1 : Tensor, %x2.1 : Tensor):
0 commit comments