diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index dd4037b64c0..5692143d3c6 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -127,85 +127,44 @@ def setUpClass(cls) -> None: register_additional_test_aten_ops() def test_remove_mixed_type_operators(self) -> None: + def count_nodes_with_target_asserting_arguments_have_dtype( + new_graph_module, target, arg_dtype + ): + count = 0 + for node in new_graph_module.graph.nodes: + if node.op == "call_function" and node.target == target: + count += 1 + for arg in node.args: + self.assertEqual(arg.meta["val"].dtype, arg_dtype) + return count + class Add(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return (x + y) + x - add = Add() - - int_tensor = torch.tensor([[1, 2, 3]]) - float_tensor = torch.tensor([[1.0, 2.0, 3.0]]) - edge_prog = to_edge(export(add, (int_tensor, float_tensor), strict=True)) - - new_prog = edge_prog.transform([RemoveMixedTypeOperators()]) - new_graph_module = new_prog.exported_program().graph_module - self.assertIsNotNone(new_graph_module) - - add_count = 0 - - for node in new_graph_module.graph.nodes: - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.add.Tensor - ): - add_count += 1 - node_args = node.args - for arg in node_args: - self.assertEqual(arg.meta["val"].dtype, torch.float) - - self.assertEqual(add_count, 2) - - double_tensor = torch.tensor([[1.0, 2.0, 3.0]]) - double_tensor = double_tensor.to(torch.double) - - double_prog = to_edge(export(add, (int_tensor, double_tensor), strict=True)) - - double_prog.transform([RemoveMixedTypeOperators()]) - new_graph_module_double = double_prog.exported_program().graph_module - self.assertIsNotNone(new_graph_module_double) - - add_count_double = 0 - - for node in new_graph_module_double.graph.nodes: - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.add.Tensor - ): - add_count_double += 1 - node_args = node.args - for arg in node_args: - self.assertEqual(arg.meta["val"].dtype, torch.double) - - self.assertEqual(add_count_double, 2) - class Mult(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x * y - mult = Mult() - - float_tensor_vert = float_tensor.T - mult_prog = to_edge(export(mult, (int_tensor, float_tensor_vert), strict=True)) - - # graph_module_mult.graph.print_tabular() - - mult_prog = mult_prog.transform([RemoveMixedTypeOperators()]) - new_graph_module_mult = mult_prog.exported_program().graph_module - self.assertIsNotNone(new_graph_module_mult) + for module, op, expected_count in ( + (Add, exir_ops.edge.aten.add.Tensor, 2), + (Mult, exir_ops.edge.aten.mul.Tensor, 1), + ): + for second_arg_dtype in (torch.int64, torch.float, torch.double): + int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64) + float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype) + edge_prog = to_edge( + export(module(), (int_tensor, float_tensor), strict=True) + ) - mult_count = 0 + new_prog = edge_prog.transform([RemoveMixedTypeOperators()]) + new_graph_module = new_prog.exported_program().graph_module + self.assertIsNotNone(new_graph_module) - for node in new_graph_module_mult.graph.nodes: - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.mul.Tensor - ): - mult_count += 1 - node_args = node.args - for arg in node_args: - self.assertEqual(arg.meta["val"].dtype, torch.float) - - self.assertEqual(mult_count, 1) + count = count_nodes_with_target_asserting_arguments_have_dtype( + new_graph_module, op, second_arg_dtype + ) + self.assertEqual(count, expected_count) def test_remove_noop_pass(self) -> None: class Foo(torch.nn.Module):