diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 5692143d3c6..e83a4cb7e50 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -146,9 +146,14 @@ class Mult(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x * y + class Minimum(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.minimum(x, y) + for module, op, expected_count in ( (Add, exir_ops.edge.aten.add.Tensor, 2), (Mult, exir_ops.edge.aten.mul.Tensor, 1), + (Minimum, exir_ops.edge.aten.minimum.default, 1), ): for second_arg_dtype in (torch.int64, torch.float, torch.double): int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)