44
44
from mypy import message_registry , errorcodes as codes
45
45
from mypy .errors import Errors
46
46
from mypy .options import Options
47
- from mypy .reachability import mark_block_unreachable
47
+ from mypy .reachability import infer_reachability_of_if_statement , mark_block_unreachable
48
48
from mypy .util import bytes_to_human_readable_repr
49
49
50
50
try :
@@ -344,9 +344,19 @@ def fail(self,
344
344
msg : str ,
345
345
line : int ,
346
346
column : int ,
347
- blocker : bool = True ) -> None :
347
+ blocker : bool = True ,
348
+ code : codes .ErrorCode = codes .SYNTAX ) -> None :
348
349
if blocker or not self .options .ignore_errors :
349
- self .errors .report (line , column , msg , blocker = blocker , code = codes .SYNTAX )
350
+ self .errors .report (line , column , msg , blocker = blocker , code = code )
351
+
352
+ def fail_merge_overload (self , node : IfStmt ) -> None :
353
+ self .fail (
354
+ "Condition can't be inferred, unable to merge overloads" ,
355
+ line = node .line ,
356
+ column = node .column ,
357
+ blocker = False ,
358
+ code = codes .MISC ,
359
+ )
350
360
351
361
def visit (self , node : Optional [AST ]) -> Any :
352
362
if node is None :
@@ -476,12 +486,93 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
476
486
ret : List [Statement ] = []
477
487
current_overload : List [OverloadPart ] = []
478
488
current_overload_name : Optional [str ] = None
489
+ last_if_stmt : Optional [IfStmt ] = None
490
+ last_if_overload : Optional [Union [Decorator , FuncDef , OverloadedFuncDef ]] = None
491
+ last_if_stmt_overload_name : Optional [str ] = None
492
+ last_if_unknown_truth_value : Optional [IfStmt ] = None
493
+ skipped_if_stmts : List [IfStmt ] = []
479
494
for stmt in stmts :
495
+ if_overload_name : Optional [str ] = None
496
+ if_block_with_overload : Optional [Block ] = None
497
+ if_unknown_truth_value : Optional [IfStmt ] = None
498
+ if (
499
+ isinstance (stmt , IfStmt )
500
+ and len (stmt .body [0 ].body ) == 1
501
+ and (
502
+ isinstance (stmt .body [0 ].body [0 ], (Decorator , OverloadedFuncDef ))
503
+ or current_overload_name is not None
504
+ and isinstance (stmt .body [0 ].body [0 ], FuncDef )
505
+ )
506
+ ):
507
+ # Check IfStmt block to determine if function overloads can be merged
508
+ if_overload_name = self ._check_ifstmt_for_overloads (stmt )
509
+ if if_overload_name is not None :
510
+ if_block_with_overload , if_unknown_truth_value = \
511
+ self ._get_executable_if_block_with_overloads (stmt )
512
+
480
513
if (current_overload_name is not None
481
514
and isinstance (stmt , (Decorator , FuncDef ))
482
515
and stmt .name == current_overload_name ):
516
+ if last_if_stmt is not None :
517
+ skipped_if_stmts .append (last_if_stmt )
518
+ if last_if_overload is not None :
519
+ # Last stmt was an IfStmt with same overload name
520
+ # Add overloads to current_overload
521
+ if isinstance (last_if_overload , OverloadedFuncDef ):
522
+ current_overload .extend (last_if_overload .items )
523
+ else :
524
+ current_overload .append (last_if_overload )
525
+ last_if_stmt , last_if_overload = None , None
526
+ if last_if_unknown_truth_value :
527
+ self .fail_merge_overload (last_if_unknown_truth_value )
528
+ last_if_unknown_truth_value = None
483
529
current_overload .append (stmt )
530
+ elif (
531
+ current_overload_name is not None
532
+ and isinstance (stmt , IfStmt )
533
+ and if_overload_name == current_overload_name
534
+ ):
535
+ # IfStmt only contains stmts relevant to current_overload.
536
+ # Check if stmts are reachable and add them to current_overload,
537
+ # otherwise skip IfStmt to allow subsequent overload
538
+ # or function definitions.
539
+ skipped_if_stmts .append (stmt )
540
+ if if_block_with_overload is None :
541
+ if if_unknown_truth_value is not None :
542
+ self .fail_merge_overload (if_unknown_truth_value )
543
+ continue
544
+ if last_if_overload is not None :
545
+ # Last stmt was an IfStmt with same overload name
546
+ # Add overloads to current_overload
547
+ if isinstance (last_if_overload , OverloadedFuncDef ):
548
+ current_overload .extend (last_if_overload .items )
549
+ else :
550
+ current_overload .append (last_if_overload )
551
+ last_if_stmt , last_if_overload = None , None
552
+ if isinstance (if_block_with_overload .body [0 ], OverloadedFuncDef ):
553
+ current_overload .extend (if_block_with_overload .body [0 ].items )
554
+ else :
555
+ current_overload .append (
556
+ cast (Union [Decorator , FuncDef ], if_block_with_overload .body [0 ])
557
+ )
484
558
else :
559
+ if last_if_stmt is not None :
560
+ ret .append (last_if_stmt )
561
+ last_if_stmt_overload_name = current_overload_name
562
+ last_if_stmt , last_if_overload = None , None
563
+ last_if_unknown_truth_value = None
564
+
565
+ if current_overload and current_overload_name == last_if_stmt_overload_name :
566
+ # Remove last stmt (IfStmt) from ret if the overload names matched
567
+ # Only happens if no executable block had been found in IfStmt
568
+ skipped_if_stmts .append (cast (IfStmt , ret .pop ()))
569
+ if current_overload and skipped_if_stmts :
570
+ # Add bare IfStmt (without overloads) to ret
571
+ # Required for mypy to be able to still check conditions
572
+ for if_stmt in skipped_if_stmts :
573
+ self ._strip_contents_from_if_stmt (if_stmt )
574
+ ret .append (if_stmt )
575
+ skipped_if_stmts = []
485
576
if len (current_overload ) == 1 :
486
577
ret .append (current_overload [0 ])
487
578
elif len (current_overload ) > 1 :
@@ -495,17 +586,119 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
495
586
if isinstance (stmt , Decorator ) and not unnamed_function (stmt .name ):
496
587
current_overload = [stmt ]
497
588
current_overload_name = stmt .name
589
+ elif (
590
+ isinstance (stmt , IfStmt )
591
+ and if_overload_name is not None
592
+ ):
593
+ current_overload = []
594
+ current_overload_name = if_overload_name
595
+ last_if_stmt = stmt
596
+ last_if_stmt_overload_name = None
597
+ if if_block_with_overload is not None :
598
+ last_if_overload = cast (
599
+ Union [Decorator , FuncDef , OverloadedFuncDef ],
600
+ if_block_with_overload .body [0 ]
601
+ )
602
+ last_if_unknown_truth_value = if_unknown_truth_value
498
603
else :
499
604
current_overload = []
500
605
current_overload_name = None
501
606
ret .append (stmt )
502
607
608
+ if current_overload and skipped_if_stmts :
609
+ # Add bare IfStmt (without overloads) to ret
610
+ # Required for mypy to be able to still check conditions
611
+ for if_stmt in skipped_if_stmts :
612
+ self ._strip_contents_from_if_stmt (if_stmt )
613
+ ret .append (if_stmt )
503
614
if len (current_overload ) == 1 :
504
615
ret .append (current_overload [0 ])
505
616
elif len (current_overload ) > 1 :
506
617
ret .append (OverloadedFuncDef (current_overload ))
618
+ elif last_if_stmt is not None :
619
+ ret .append (last_if_stmt )
507
620
return ret
508
621
622
+ def _check_ifstmt_for_overloads (self , stmt : IfStmt ) -> Optional [str ]:
623
+ """Check if IfStmt contains only overloads with the same name.
624
+ Return overload_name if found, None otherwise.
625
+ """
626
+ # Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
627
+ # Multiple overloads have already been merged as OverloadedFuncDef.
628
+ if not (
629
+ len (stmt .body [0 ].body ) == 1
630
+ and isinstance (stmt .body [0 ].body [0 ], (Decorator , FuncDef , OverloadedFuncDef ))
631
+ ):
632
+ return None
633
+
634
+ overload_name = stmt .body [0 ].body [0 ].name
635
+ if stmt .else_body is None :
636
+ return overload_name
637
+
638
+ if isinstance (stmt .else_body , Block ) and len (stmt .else_body .body ) == 1 :
639
+ # For elif: else_body contains an IfStmt itself -> do a recursive check.
640
+ if (
641
+ isinstance (stmt .else_body .body [0 ], (Decorator , FuncDef , OverloadedFuncDef ))
642
+ and stmt .else_body .body [0 ].name == overload_name
643
+ ):
644
+ return overload_name
645
+ if (
646
+ isinstance (stmt .else_body .body [0 ], IfStmt )
647
+ and self ._check_ifstmt_for_overloads (stmt .else_body .body [0 ]) == overload_name
648
+ ):
649
+ return overload_name
650
+
651
+ return None
652
+
653
+ def _get_executable_if_block_with_overloads (
654
+ self , stmt : IfStmt
655
+ ) -> Tuple [Optional [Block ], Optional [IfStmt ]]:
656
+ """Return block from IfStmt that will get executed.
657
+
658
+ Return
659
+ 0 -> A block if sure that alternative blocks are unreachable.
660
+ 1 -> An IfStmt if the reachability of it can't be inferred,
661
+ i.e. the truth value is unknown.
662
+ """
663
+ infer_reachability_of_if_statement (stmt , self .options )
664
+ if (
665
+ stmt .else_body is None
666
+ and stmt .body [0 ].is_unreachable is True
667
+ ):
668
+ # always False condition with no else
669
+ return None , None
670
+ if (
671
+ stmt .else_body is None
672
+ or stmt .body [0 ].is_unreachable is False
673
+ and stmt .else_body .is_unreachable is False
674
+ ):
675
+ # The truth value is unknown, thus not conclusive
676
+ return None , stmt
677
+ if stmt .else_body .is_unreachable is True :
678
+ # else_body will be set unreachable if condition is always True
679
+ return stmt .body [0 ], None
680
+ if stmt .body [0 ].is_unreachable is True :
681
+ # body will be set unreachable if condition is always False
682
+ # else_body can contain an IfStmt itself (for elif) -> do a recursive check
683
+ if isinstance (stmt .else_body .body [0 ], IfStmt ):
684
+ return self ._get_executable_if_block_with_overloads (stmt .else_body .body [0 ])
685
+ return stmt .else_body , None
686
+ return None , stmt
687
+
688
+ def _strip_contents_from_if_stmt (self , stmt : IfStmt ) -> None :
689
+ """Remove contents from IfStmt.
690
+
691
+ Needed to still be able to check the conditions after the contents
692
+ have been merged with the surrounding function overloads.
693
+ """
694
+ if len (stmt .body ) == 1 :
695
+ stmt .body [0 ].body = []
696
+ if stmt .else_body and len (stmt .else_body .body ) == 1 :
697
+ if isinstance (stmt .else_body .body [0 ], IfStmt ):
698
+ self ._strip_contents_from_if_stmt (stmt .else_body .body [0 ])
699
+ else :
700
+ stmt .else_body .body = []
701
+
509
702
def in_method_scope (self ) -> bool :
510
703
return self .class_and_function_stack [- 2 :] == ['C' , 'F' ]
511
704
0 commit comments