@@ -496,18 +496,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
496
496
if_overload_name : Optional [str ] = None
497
497
if_block_with_overload : Optional [Block ] = None
498
498
if_unknown_truth_value : Optional [IfStmt ] = None
499
- if (
500
- isinstance (stmt , IfStmt )
501
- and len (stmt .body [0 ].body ) == 1
502
- and seen_unconditional_func_def is False
503
- and (
504
- isinstance (stmt .body [0 ].body [0 ], (Decorator , OverloadedFuncDef ))
505
- or current_overload_name is not None
506
- and isinstance (stmt .body [0 ].body [0 ], FuncDef )
507
- )
508
- ):
499
+ if isinstance (stmt , IfStmt ) and seen_unconditional_func_def is False :
509
500
# Check IfStmt block to determine if function overloads can be merged
510
- if_overload_name = self ._check_ifstmt_for_overloads (stmt )
501
+ if_overload_name = self ._check_ifstmt_for_overloads (stmt , current_overload_name )
511
502
if if_overload_name is not None :
512
503
if_block_with_overload , if_unknown_truth_value = \
513
504
self ._get_executable_if_block_with_overloads (stmt )
@@ -553,8 +544,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
553
544
else :
554
545
current_overload .append (last_if_overload )
555
546
last_if_stmt , last_if_overload = None , None
556
- if isinstance (if_block_with_overload .body [0 ], OverloadedFuncDef ):
557
- current_overload .extend (if_block_with_overload .body [0 ].items )
547
+ if isinstance (if_block_with_overload .body [- 1 ], OverloadedFuncDef ):
548
+ skipped_if_stmts .extend (
549
+ cast (List [IfStmt ], if_block_with_overload .body [:- 1 ])
550
+ )
551
+ current_overload .extend (if_block_with_overload .body [- 1 ].items )
558
552
else :
559
553
current_overload .append (
560
554
cast (Union [Decorator , FuncDef ], if_block_with_overload .body [0 ])
@@ -600,9 +594,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
600
594
last_if_stmt = stmt
601
595
last_if_stmt_overload_name = None
602
596
if if_block_with_overload is not None :
597
+ skipped_if_stmts .extend (
598
+ cast (List [IfStmt ], if_block_with_overload .body [:- 1 ])
599
+ )
603
600
last_if_overload = cast (
604
601
Union [Decorator , FuncDef , OverloadedFuncDef ],
605
- if_block_with_overload .body [0 ]
602
+ if_block_with_overload .body [- 1 ]
606
603
)
607
604
last_if_unknown_truth_value = if_unknown_truth_value
608
605
else :
@@ -620,23 +617,38 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
620
617
ret .append (current_overload [0 ])
621
618
elif len (current_overload ) > 1 :
622
619
ret .append (OverloadedFuncDef (current_overload ))
620
+ elif last_if_overload is not None :
621
+ ret .append (last_if_overload )
623
622
elif last_if_stmt is not None :
624
623
ret .append (last_if_stmt )
625
624
return ret
626
625
627
- def _check_ifstmt_for_overloads (self , stmt : IfStmt ) -> Optional [str ]:
626
+ def _check_ifstmt_for_overloads (
627
+ self , stmt : IfStmt , current_overload_name : Optional [str ] = None
628
+ ) -> Optional [str ]:
628
629
"""Check if IfStmt contains only overloads with the same name.
629
630
Return overload_name if found, None otherwise.
630
631
"""
631
632
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
632
633
# Multiple overloads have already been merged as OverloadedFuncDef.
633
634
if not (
634
635
len (stmt .body [0 ].body ) == 1
635
- and isinstance (stmt .body [0 ].body [0 ], (Decorator , FuncDef , OverloadedFuncDef ))
636
+ and (
637
+ isinstance (stmt .body [0 ].body [0 ], (Decorator , OverloadedFuncDef ))
638
+ or current_overload_name is not None
639
+ and isinstance (stmt .body [0 ].body [0 ], FuncDef )
640
+ )
641
+ or len (stmt .body [0 ].body ) > 1
642
+ and isinstance (stmt .body [0 ].body [- 1 ], OverloadedFuncDef )
643
+ and all (
644
+ self ._is_stripped_if_stmt (if_stmt )
645
+ for if_stmt in stmt .body [0 ].body [:- 1 ]
646
+ )
636
647
):
637
648
return None
638
649
639
- overload_name = stmt .body [0 ].body [0 ].name
650
+ overload_name = cast (
651
+ Union [Decorator , FuncDef , OverloadedFuncDef ], stmt .body [0 ].body [- 1 ]).name
640
652
if stmt .else_body is None :
641
653
return overload_name
642
654
@@ -649,7 +661,9 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
649
661
return overload_name
650
662
if (
651
663
isinstance (stmt .else_body .body [0 ], IfStmt )
652
- and self ._check_ifstmt_for_overloads (stmt .else_body .body [0 ]) == overload_name
664
+ and self ._check_ifstmt_for_overloads (
665
+ stmt .else_body .body [0 ], current_overload_name
666
+ ) == overload_name
653
667
):
654
668
return overload_name
655
669
@@ -704,6 +718,25 @@ def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
704
718
else :
705
719
stmt .else_body .body = []
706
720
721
+ def _is_stripped_if_stmt (self , stmt : Statement ) -> bool :
722
+ """Check stmt to make sure it is a stripped IfStmt.
723
+
724
+ See also: _strip_contents_from_if_stmt
725
+ """
726
+ if not isinstance (stmt , IfStmt ):
727
+ return False
728
+
729
+ if not (len (stmt .body ) == 1 and len (stmt .body [0 ].body ) == 0 ):
730
+ # Body not empty
731
+ return False
732
+
733
+ if not stmt .else_body or len (stmt .else_body .body ) == 0 :
734
+ # No or empty else_body
735
+ return True
736
+
737
+ # For elif, IfStmt are stored recursively in else_body
738
+ return self ._is_stripped_if_stmt (stmt .else_body .body [0 ])
739
+
707
740
def in_method_scope (self ) -> bool :
708
741
return self .class_and_function_stack [- 2 :] == ['C' , 'F' ]
709
742
0 commit comments