Skip to content

Commit 77c35f5

Browse files
authored
Arm Backend: temp fix for flaky eq op test (#9794)
This patch provides a temporary workaround for the flaky op_eq test. - Add eq_scalar_rank4_randn test to xfail temporarily, and set strict to false - Change misleading aten_op naming Signed-off-by: Fang-Ching <[email protected]>
1 parent a5e326a commit 77c35f5

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

backends/arm/test/ops/test_eq.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121

2222
class Equal(torch.nn.Module):
23-
aten_op_BI = "torch.ops.aten.eq.Tensor"
24-
aten_op_MI = "torch.ops.aten.eq.Scalar"
23+
aten_op_Tensor = "torch.ops.aten.eq.Tensor"
24+
aten_op_Scalar = "torch.ops.aten.eq.Scalar"
2525
exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
2626

2727
def __init__(self, input, other):
@@ -80,7 +80,7 @@ def get_inputs(self):
8080
@common.parametrize("test_module", test_data_tensor)
8181
def test_eq_tensor_tosa_MI(test_module):
8282
pipeline = TosaPipelineMI[input_t](
83-
test_module, test_module.get_inputs(), Equal.aten_op_BI, Equal.exir_op
83+
test_module, test_module.get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
8484
)
8585
pipeline.run()
8686

@@ -90,7 +90,7 @@ def test_eq_scalar_tosa_MI(test_module):
9090
pipeline = TosaPipelineMI[input_t](
9191
test_module,
9292
test_module.get_inputs(),
93-
Equal.aten_op_MI,
93+
Equal.aten_op_Scalar,
9494
Equal.exir_op,
9595
)
9696
pipeline.run()
@@ -99,7 +99,7 @@ def test_eq_scalar_tosa_MI(test_module):
9999
@common.parametrize("test_module", test_data_tensor | test_data_scalar)
100100
def test_eq_tosa_BI(test_module):
101101
pipeline = TosaPipelineBI[input_t](
102-
test_module, test_module.get_inputs(), Equal.aten_op_BI, Equal.exir_op
102+
test_module, test_module.get_inputs(), Equal.aten_op_Tensor, Equal.exir_op
103103
)
104104
pipeline.run()
105105

@@ -135,15 +135,17 @@ def test_eq_scalar_u55_BI(test_module):
135135
"test_module",
136136
test_data_tensor | test_data_scalar,
137137
xfails={
138-
"eq_tensor_rank4_randn": "4D fails because boolean Tensors can't be subtracted",
138+
"eq_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
139+
"eq_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85",
139140
},
141+
strict=False,
140142
)
141143
@common.XfailIfNoCorstone320
142144
def test_eq_u85_BI(test_module):
143145
pipeline = EthosU85PipelineBI[input_t](
144146
test_module,
145147
test_module.get_inputs(),
146-
Equal.aten_op_BI,
148+
Equal.aten_op_Tensor,
147149
Equal.exir_op,
148150
run_on_fvp=True,
149151
)

0 commit comments

Comments
 (0)