@@ -450,8 +450,9 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
450450 self .assertEqual (model .graph .node (1 ).op_type , "Original" )
451451
452452 def test_match_optional_input (self ):
453- def none_pattern (op , optional_input , x ):
453+ def none_pattern (op , x ):
454454 # match against a call to Original where the first input may or may not be None
455+ optional_input = pattern .Var ("optional_input" , can_match_none = True )
455456 return op .Original (optional_input , x )
456457
457458 def replacement (op , optional_input , x ):
@@ -478,6 +479,44 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
478479 self .assertEqual (model .graph .node (0 ).op_type , "ReplacedNone" )
479480 self .assertEqual (model .graph .node (1 ).op_type , "ReplacedNotNone" )
480481
482+ def test_mismatched_number_of_inputs (self ):
483+ def var_length_pattern (op ):
484+ # match against a call to Original where the first input may or may not be None
485+ input1 = pattern .Var ("input1" , can_match_none = False )
486+ input2 = pattern .Var ("input2" , can_match_none = True )
487+ return op .Original (input1 , input2 )
488+
489+ def replacement (op , input1 , input2 ):
490+ return op .Replaced (input1 , input2 )
491+
492+ rule = pattern .RewriteRule (var_length_pattern , replacement )
493+
494+ @script ()
495+ def test_model (x : FLOAT [1024 ], y : FLOAT [1024 ], z : FLOAT [1024 ]) -> FLOAT [1024 ]:
496+ # Pattern should NOT match following 2 calls, since pattern requires first input to be non-None
497+ t0 = op .Original ()
498+ t1 = op .Original (None , x )
499+
500+ # Pattern should match following 3 calls, since second input can be None
501+ t2 = op .Original (x )
502+ t3 = op .Original (x , None )
503+ t4 = op .Original (x , y )
504+
505+ # Pattern should NOT match following call, since it has more than 2 inputs
506+ t5 = op .Original (x , y , z )
507+ return op .All (t0 , t1 , t2 , t3 , t4 , t5 )
508+
509+ model_proto = test_model .to_model_proto ()
510+ model = ir .serde .deserialize_model (model_proto )
511+
512+ count = rule .apply_to_model (model )
513+ self .assertEqual (count , 3 )
514+ self .assertEqual (len (model .graph ), 7 )
515+ self .assertEqual (
516+ [n .op_type for n in model .graph ],
517+ ["Original" , "Original" , "Replaced" , "Replaced" , "Replaced" , "Original" , "All" ],
518+ )
519+
481520 def test_graph_visitor (self ):
482521 class ReplaceFoo (pattern .RewriteRuleClassBase ):
483522 def __init__ (self ):
0 commit comments