@@ -11,7 +11,8 @@ void pointwise_test_helper(
1111 bool dynamicInput = false ,
1212 std::vector<int64_t > shape1 = {5 },
1313 std::vector<int64_t > shape2 = {5 },
14- bool negative_input = false ) {
14+ bool negative_input = false ,
15+ bool int_tensors = false ) {
1516 auto g = std::make_shared<torch::jit::Graph>();
1617 torch::jit::parseIR (graph_ir, g.get ());
1718
@@ -26,6 +27,11 @@ void pointwise_test_helper(
2627 if (!singleInput) {
2728 torch_inputs.push_back (at::randint (1 , 5 , shape2, {at::kCUDA }));
2829 }
30+ if (int_tensors){
31+ for (size_t i = 0UL ; i < torch_inputs.size (); ++i){
32+ torch_inputs[i] = torch_inputs[i].to (at::kInt );
33+ }
34+ }
2935 auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
3036 auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, torch_inputs);
3137
@@ -126,6 +132,15 @@ TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
126132 pointwise_test_helper (graph, true );
127133}
128134
135+ TEST (Converters, ATenMulWithIntScalarConvertsCorrectly) {
136+ const auto graph = R"IR(
137+ graph(%0 : Tensor):
138+ %scalar : int = prim::Constant[value=2]()
139+ %1 : Tensor = aten::mul(%0, %scalar)
140+ return (%1))IR" ;
141+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , true );
142+ }
143+
129144TEST (Converters, ATenDivConvertsCorrectly) {
130145 const auto graph = R"IR(
131146 graph(%0 : Tensor, %1 : Tensor):
0 commit comments