Skip to content

Commit 38c4468

Browse files
gramalingamCopilot
andauthored
Handle matching against None explicitly (#2460)
Provide a way to indicate that a pattern-variable can match successfully against a None-valued input. Cleanup current handling which was inconsistent in one place. Add test cases. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 413f3df commit 38c4468

File tree

6 files changed

+89
-19
lines changed

6 files changed

+89
-19
lines changed

onnxscript/rewriter/_matcher.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,21 @@ def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> b
149149
match.bind_node(pattern_node, node)
150150

151151
# TODO: Revisit this to handle optional trailing inputs better.
152-
if pattern_node.allow_other_inputs:
153-
if len(node.inputs) < len(pattern_node.inputs):
152+
153+
if len(node.inputs) > len(pattern_node.inputs):
154+
if not pattern_node.allow_other_inputs:
154155
return self.fail(
155-
f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})"
156+
f"Number of inputs ({len(node.inputs)}) is greater than expected ({len(pattern_node.inputs)})"
156157
)
158+
checked_inputs = zip(node.inputs, pattern_node.inputs)
157159
else:
158-
if len(node.inputs) != len(pattern_node.inputs):
159-
return self.fail(
160-
f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"
161-
)
160+
# In ONNX, trailing Nones can be omitted in the inputs of a node. So, we extend actual
161+
# node inputs with None values to match the pattern node inputs length when zipping.
162+
checked_inputs = itertools.zip_longest(
163+
node.inputs, pattern_node.inputs, fillvalue=None
164+
)
162165

163-
for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs):
166+
for arg_value, arg_pattern in checked_inputs:
164167
# arg_pattern could be a Var, if it's the original arg.
165168
if arg_pattern is None:
166169
if arg_value is None:
@@ -216,6 +219,11 @@ def _match_value(
216219
if pattern_value.tag_var is not None:
217220
self._match.bind(pattern_value.tag_var, i)
218221
return result
222+
# Default case: a plain pattern variable (ValuePattern)
223+
if value is None and not pattern_value.can_match_none:
224+
return self.fail(
225+
f"Mismatch: pattern variable {pattern_value} does not match None."
226+
)
219227
return True
220228

221229
def _match_node_output(

onnxscript/rewriter/_pattern_ir.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,16 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) ->
123123
"""Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern."""
124124
if isinstance(value, AttrPattern):
125125
return value
126-
if type(value) is ValuePattern:
127-
# This is a hack. Currently, when we create pattern-variables, we create them as ValuePattern,
126+
if isinstance(value, Var):
127+
# This is a hack. Currently, when we create pattern-variables, we create them as Var,
128128
# and change them to AttrPattern if/when used in an attribute context. We could use type
129129
# annotations to distinguish between ValuePattern and AttrPattern, but forces users to
130130
# use these type annotations.
131131
# TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.)
132+
if value.can_match_none or value.check_method is not None:
133+
raise ValueError(
134+
"Pattern variables used in attributes must not have can_match_none or check_method set."
135+
)
132136
return AttrPattern(value.name)
133137
if isinstance(value, (int, float, str)):
134138
return AttrConstantPattern(value)
@@ -320,9 +324,12 @@ class ValuePattern:
320324
operations, so that we can write patterns like `x + 1` and `1 + x`.
321325
"""
322326

323-
def __init__(self, name: str | None, *, check: Callable | None = None) -> None:
327+
def __init__(
328+
self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False
329+
) -> None:
324330
self._name = name
325331
self._check = check
332+
self._can_match_none = can_match_none
326333
# Note: uses will be computed only when the full graph-pattern is constructed.
327334
self._uses: list[tuple[NodePattern, int]] = []
328335

@@ -338,6 +345,11 @@ def name(self) -> str | None:
338345
def check_method(self) -> Callable | None:
339346
return self._check
340347

348+
@property
349+
def can_match_none(self) -> bool:
350+
"""Indicates whether this variable can match a None input."""
351+
return self._can_match_none
352+
341353
def producer(self) -> NodePattern | None:
342354
return None
343355

@@ -547,7 +559,17 @@ def producer(self) -> NodePattern:
547559
return self._producer
548560

549561

550-
Var = ValuePattern
562+
class Var(ValuePattern):
563+
"""Represents a pattern-variable."""
564+
565+
def __init__(
566+
self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False
567+
) -> None:
568+
super().__init__(name, check=check, can_match_none=can_match_none)
569+
570+
def clone(self, node_map: dict[NodePattern, NodePattern]) -> Var:
571+
"""Clones the pattern-variable, preserving its name and check method."""
572+
return Var(self.name, check=self.check_method, can_match_none=self.can_match_none)
551573

552574

553575
class AnyValue(ValuePattern):

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def pattern(
3434
qkv_bias,
3535
# mask_index,
3636
past,
37-
attention_bias,
3837
num_heads,
3938
# scale,
4039
start1,
@@ -106,7 +105,7 @@ def pattern(
106105
value_BSD,
107106
qkv_bias,
108107
None, # key_padding_mask
109-
attention_bias,
108+
pattern.Var("attention_bias", can_match_none=True),
110109
past_key,
111110
past_value,
112111
num_heads=num_heads,
@@ -127,7 +126,7 @@ def pattern(
127126
value_BSD,
128127
qkv_bias,
129128
None, # key_padding_mask
130-
attention_bias,
129+
pattern.Var("attention_bias", can_match_none=True),
131130
None, # past_key
132131
None, # past_value
133132
num_heads=num_heads,

onnxscript/rewriter/ort_fusions/fuse_mha_bias.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def pattern(
5252
value_BSD,
5353
None, # bias
5454
None, # key padding mask
55-
mask, # attention mask/bias
56-
past_key,
57-
past_value,
55+
pattern.Var("mask", can_match_none=True), # attention mask/bias
56+
pattern.Var("past_key", can_match_none=True),
57+
pattern.Var("past_value", can_match_none=True),
5858
num_heads=num_heads,
5959
# scale=scale,
6060
_domain="com.microsoft",

onnxscript/rewriter/pattern.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Constant,
1111
OpsetPatternBuilder,
1212
OrValue,
13+
Var,
1314
pattern_builder,
1415
torch_module_op,
1516
)
@@ -41,4 +42,5 @@
4142
"PatternMatcher",
4243
"SimplePatternMatcher",
4344
"torch_module_op",
45+
"Var",
4446
]

onnxscript/rewriter/pattern_test.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,9 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
450450
self.assertEqual(model.graph.node(1).op_type, "Original")
451451

452452
def test_match_optional_input(self):
453-
def none_pattern(op, optional_input, x):
453+
def none_pattern(op, x):
454454
# match against a call to Original where the first input may or may not be None
455+
optional_input = pattern.Var("optional_input", can_match_none=True)
455456
return op.Original(optional_input, x)
456457

457458
def replacement(op, optional_input, x):
@@ -478,6 +479,44 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
478479
self.assertEqual(model.graph.node(0).op_type, "ReplacedNone")
479480
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone")
480481

482+
def test_mismatched_number_of_inputs(self):
483+
def var_length_pattern(op):
484+
# match against a call to Original where the first input may or may not be None
485+
input1 = pattern.Var("input1", can_match_none=False)
486+
input2 = pattern.Var("input2", can_match_none=True)
487+
return op.Original(input1, input2)
488+
489+
def replacement(op, input1, input2):
490+
return op.Replaced(input1, input2)
491+
492+
rule = pattern.RewriteRule(var_length_pattern, replacement)
493+
494+
@script()
495+
def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]:
496+
# Pattern should NOT match following 2 calls, since pattern requires first input to be non-None
497+
t0 = op.Original()
498+
t1 = op.Original(None, x)
499+
500+
# Pattern should match following 3 calls, since second input can be None
501+
t2 = op.Original(x)
502+
t3 = op.Original(x, None)
503+
t4 = op.Original(x, y)
504+
505+
# Pattern should NOT match following call, since it has more than 2 inputs
506+
t5 = op.Original(x, y, z)
507+
return op.All(t0, t1, t2, t3, t4, t5)
508+
509+
model_proto = test_model.to_model_proto()
510+
model = ir.serde.deserialize_model(model_proto)
511+
512+
count = rule.apply_to_model(model)
513+
self.assertEqual(count, 3)
514+
self.assertEqual(len(model.graph), 7)
515+
self.assertEqual(
516+
[n.op_type for n in model.graph],
517+
["Original", "Original", "Replaced", "Replaced", "Replaced", "Original", "All"],
518+
)
519+
481520
def test_graph_visitor(self):
482521
class ReplaceFoo(pattern.RewriteRuleClassBase):
483522
def __init__(self):

0 commit comments

Comments
 (0)