Skip to content

Commit 91c9ffa

Browse files
Fix bug in add op partitioning that ignored alpha != 1 (fixes: #11683)
Differential Revision: D76932880 Pull Request resolved: #11777
1 parent d5fe5fa commit 91c9ffa

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,17 @@ def __init__(self, **kwargs):
107107
def supported_precision_types(self) -> List[ConfigPrecisionType]:
108108
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
109109

110+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
    """Return True if this add node is eligible for XNNPACK partitioning.

    In addition to the shared constraints, rejects nodes carrying an
    ``alpha`` keyword that is not (numerically) 1, since the lowered add
    op does not apply alpha scaling.
    """
    if not self.check_common_constraints(node, ep):
        return False
    # Only alpha == 1 (within a tight float tolerance) is supported;
    # anything else must stay unpartitioned so aten.add handles it.
    has_nonunit_alpha = "alpha" in node.kwargs and not np.isclose(
        node.kwargs["alpha"], 1.0, atol=1e-9, rtol=1e-9
    )
    if has_nonunit_alpha:
        why(node, reason="Add node doesn't support alpha != 1")
        return False
    return True
110121

111122
class ReLUConfig(GenericNodePartitionerConfig):
112123
target_name = "relu.default"

backends/xnnpack/test/ops/test_add.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,27 @@ def forward(self, x, z):
240240
.serialize()
241241
.run_method_and_compare_outputs()
242242
)
243+
244+
class AddWithAlpha(torch.nn.Module):
    """Module producing one add with alpha == 1 and one with alpha != 1.

    The first add is expected to be partitioned; the second (alpha=2)
    is expected to be left to the reference aten implementation.
    """

    def forward(self, x, y):
        partitionable = torch.add(x, y, alpha=1)
        unpartitionable = torch.add(x, y, alpha=2)
        return partitionable, unpartitionable
252+
def test_add_with_alpha(self):
    """One add (alpha=1) should be delegated; the other (alpha=2) should not."""
    sample_inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
    (
        Tester(self.AddWithAlpha(), sample_inputs)
        .export()
        .check_count({"torch.ops.aten.add.Tensor": 2})
        .to_edge_transform_and_lower()
        # the alpha != 1 add must survive as an edge-dialect aten add
        .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1})
        # the alpha == 1 add must have been lowered into one delegate call
        .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )

0 commit comments

Comments
 (0)