1111
1212from onnxscript .rewriter import MatchingTracer , MatchStatus , RewriteRule , testing
1313from onnxscript .rewriter .rules .common ._min_max_to_clip import (
14- fuse_successive_max_min_rule ,
15- fuse_successive_max_rule ,
16- fuse_successive_min_max_rule ,
17- fuse_successive_min_rule ,
14+ max_max_rule ,
15+ max_min_rule ,
16+ min_max_rule ,
17+ min_min_rule ,
1818 rules ,
1919)
2020
@@ -154,8 +154,8 @@ def test_successful_fuse_successive_min_or_max_constants(self, _, op_type):
154154
155155 @parameterized .expand (
156156 [
157- ("min_nonconst" , "Min" , fuse_successive_min_rule ),
158- ("max_nonconst" , "Max" , fuse_successive_max_rule ),
157+ ("min_nonconst" , "Min" , min_min_rule ),
158+ ("max_nonconst" , "Max" , max_max_rule ),
159159 ]
160160 )
161161 def test_failure_fuse_successive_min_or_max_non_constant (self , _ , op_type , rewrite_rule ):
@@ -239,9 +239,7 @@ def test_failure_min_max_to_clip_invalid_bounds(self):
239239 Y = Max(x1, max)
240240 }
241241 """ )
242- self .run_failed_condition_test (
243- base_model , fuse_successive_min_max_rule , "Invalid bounds:"
244- )
242+ self .run_failed_condition_test (base_model , min_max_rule , "Invalid bounds:" )
245243
246244 def test_failure_fuse_min_max_to_clip_non_constant (self ):
247245 model = ir .from_onnx_text ("""
@@ -254,9 +252,7 @@ def test_failure_fuse_min_max_to_clip_non_constant(self):
254252 Y = Max(x1, max)
255253 }
256254 """ )
257- self .run_failed_condition_test (
258- model , fuse_successive_min_max_rule , "is not a constant."
259- )
255+ self .run_failed_condition_test (model , min_max_rule , "is not a constant." )
260256
261257 def test_failure_min_max_to_clip_need_scalars (self ):
262258 base_model = ir .from_onnx_text ("""
@@ -268,9 +264,7 @@ def test_failure_min_max_to_clip_need_scalars(self):
268264 Y = Max(x1, max)
269265 }
270266 """ )
271- self .run_failed_condition_test (
272- base_model , fuse_successive_min_max_rule , "is not a scalar"
273- )
267+ self .run_failed_condition_test (base_model , min_max_rule , "is not a scalar" )
274268
275269
276270class TestMaxMinToClip (_TestMinMaxToClipBase ):
@@ -334,9 +328,7 @@ def test_failure_fuse_max_min_to_clip_non_constant(self):
334328 Y = Min(x1, min)
335329 }
336330 """ )
337- self .run_failed_condition_test (
338- model , fuse_successive_max_min_rule , "is not a constant."
339- )
331+ self .run_failed_condition_test (model , max_min_rule , "is not a constant." )
340332
341333 def test_failure_max_min_to_clip_need_scalars (self ):
342334 base_model = ir .from_onnx_text ("""
@@ -348,9 +340,7 @@ def test_failure_max_min_to_clip_need_scalars(self):
348340 Y = Min(x1, max)
349341 }
350342 """ )
351- self .run_failed_condition_test (
352- base_model , fuse_successive_max_min_rule , "is not a scalar"
353- )
343+ self .run_failed_condition_test (base_model , max_min_rule , "is not a scalar" )
354344
355345
356346class TestIntegrationMinMaxToClip (_TestMinMaxToClipBase ):
0 commit comments