@@ -203,8 +203,9 @@ def visit_type_var(self, left: TypeVarType) -> bool:
203203 def visit_callable_type (self , left : CallableType ) -> bool :
204204 right = self .right
205205 if isinstance (right , CallableType ):
206- return is_callable_subtype (
206+ return is_callable_compatible (
207207 left , right ,
208+ is_compat = is_subtype ,
208209 ignore_pos_arg_names = self .ignore_pos_arg_names )
209210 elif isinstance (right , Overloaded ):
210211 return all (is_subtype (left , item , self .check_type_parameter ,
@@ -310,10 +311,12 @@ def visit_overloaded(self, left: Overloaded) -> bool:
310311 else :
311312 # If this one overlaps with the supertype in any way, but it wasn't
312313 # an exact match, then it's a potential error.
313- if (is_callable_subtype (left_item , right_item , ignore_return = True ,
314- ignore_pos_arg_names = self .ignore_pos_arg_names ) or
315- is_callable_subtype (right_item , left_item , ignore_return = True ,
316- ignore_pos_arg_names = self .ignore_pos_arg_names )):
314+ if (is_callable_compatible (left_item , right_item ,
315+ is_compat = is_subtype , ignore_return = True ,
316+ ignore_pos_arg_names = self .ignore_pos_arg_names ) or
317+ is_callable_compatible (right_item , left_item ,
318+ is_compat = is_subtype , ignore_return = True ,
319+ ignore_pos_arg_names = self .ignore_pos_arg_names )):
317320 # If this is an overload that's already been matched, there's no
318321 # problem.
319322 if left_item not in matched_overloads :
@@ -568,16 +571,54 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]:
568571 return result
569572
570573
571- def is_callable_subtype (left : CallableType , right : CallableType ,
572- ignore_return : bool = False ,
573- ignore_pos_arg_names : bool = False ,
574- use_proper_subtype : bool = False ) -> bool :
575- """Is left a subtype of right?"""
574+ def is_callable_compatible (left : CallableType , right : CallableType ,
575+ * ,
576+ is_compat : Callable [[Type , Type ], bool ],
577+ is_compat_return : Optional [Callable [[Type , Type ], bool ]] = None ,
578+ ignore_return : bool = False ,
579+ ignore_pos_arg_names : bool = False ,
580+ check_args_covariantly : bool = False ) -> bool :
581+ """Is the left compatible with the right, using the provided compatibility check?
576582
577- if use_proper_subtype :
578- is_compat = is_proper_subtype
579- else :
580- is_compat = is_subtype
583+ is_compat:
584+ The check we want to run against the parameters.
585+
586+ is_compat_return:
587+ The check we want to run against the return type.
588+ If None, use the 'is_compat' check.
589+
590+ check_args_covariantly:
591+ If true, check if the left's args is compatible with the right's
592+ instead of the other way around (contravariantly).
593+
594+ This function is mostly used to check if the left is a subtype of the right which
595+ is why the default is to check the args contravariantly. However, it's occasionally
596+ useful to check the args using some other check, so we leave the variance
597+ configurable.
598+
599+ For example, when checking the validity of overloads, it's useful to see if
600+ the first overload alternative has more precise arguments then the second.
601+ We would want to check the arguments covariantly in that case.
602+
603+ Note! The following two function calls are NOT equivalent:
604+
605+ is_callable_compatible(f, g, is_compat=is_subtype, check_args_covariantly=False)
606+ is_callable_compatible(g, f, is_compat=is_subtype, check_args_covariantly=True)
607+
608+ The two calls are similar in that they both check the function arguments in
609+ the same direction: they both run `is_subtype(argument_from_g, argument_from_f)`.
610+
611+ However, the two calls differ in which direction they check things likee
612+ keyword arguments. For example, suppose f and g are defined like so:
613+
614+ def f(x: int, *y: int) -> int: ...
615+ def g(x: int) -> int: ...
616+
617+ In this case, the first call will succeed and the second will fail: f is a
618+ valid stand-in for g but not vice-versa.
619+ """
620+ if is_compat_return is None :
621+ is_compat_return = is_compat
581622
582623 # If either function is implicitly typed, ignore positional arg names too
583624 if left .implicit or right .implicit :
@@ -607,9 +648,12 @@ def is_callable_subtype(left: CallableType, right: CallableType,
607648 left = unified
608649
609650 # Check return types.
610- if not ignore_return and not is_compat (left .ret_type , right .ret_type ):
651+ if not ignore_return and not is_compat_return (left .ret_type , right .ret_type ):
611652 return False
612653
654+ if check_args_covariantly :
655+ is_compat = flip_compat_check (is_compat )
656+
613657 if right .is_ellipsis_args :
614658 return True
615659
@@ -652,7 +696,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
652696 # Right has an infinite series of optional positional arguments
653697 # here. Get all further positional arguments of left, and make sure
654698 # they're more general than their corresponding member in this
655- # series. Also make sure left has its own inifite series of
699+ # series. Also make sure left has its own infinite series of
656700 # optional positional arguments.
657701 if not left .is_var_arg :
658702 return False
@@ -664,7 +708,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
664708 right_by_position = right .argument_by_position (j )
665709 assert right_by_position is not None
666710 if not are_args_compatible (left_by_position , right_by_position ,
667- ignore_pos_arg_names , use_proper_subtype ):
711+ ignore_pos_arg_names , is_compat ):
668712 return False
669713 j += 1
670714 continue
@@ -687,7 +731,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
687731 right_by_name = right .argument_by_name (name )
688732 assert right_by_name is not None
689733 if not are_args_compatible (left_by_name , right_by_name ,
690- ignore_pos_arg_names , use_proper_subtype ):
734+ ignore_pos_arg_names , is_compat ):
691735 return False
692736 continue
693737
@@ -696,7 +740,8 @@ def is_callable_subtype(left: CallableType, right: CallableType,
696740 if left_arg is None :
697741 return False
698742
699- if not are_args_compatible (left_arg , right_arg , ignore_pos_arg_names , use_proper_subtype ):
743+ if not are_args_compatible (left_arg , right_arg ,
744+ ignore_pos_arg_names , is_compat ):
700745 return False
701746
702747 done_with_positional = False
@@ -748,7 +793,7 @@ def are_args_compatible(
748793 left : FormalArgument ,
749794 right : FormalArgument ,
750795 ignore_pos_arg_names : bool ,
751- use_proper_subtype : bool ) -> bool :
796+ is_compat : Callable [[ Type , Type ], bool ] ) -> bool :
752797 # If right has a specific name it wants this argument to be, left must
753798 # have the same.
754799 if right .name is not None and left .name != right .name :
@@ -759,18 +804,20 @@ def are_args_compatible(
759804 if right .pos is not None and left .pos != right .pos :
760805 return False
761806 # Left must have a more general type
762- if use_proper_subtype :
763- if not is_proper_subtype (right .typ , left .typ ):
764- return False
765- else :
766- if not is_subtype (right .typ , left .typ ):
767- return False
807+ if not is_compat (right .typ , left .typ ):
808+ return False
768809 # If right's argument is optional, left's must also be.
769810 if not right .required and left .required :
770811 return False
771812 return True
772813
773814
815+ def flip_compat_check (is_compat : Callable [[Type , Type ], bool ]) -> Callable [[Type , Type ], bool ]:
816+ def new_is_compat (left : Type , right : Type ) -> bool :
817+ return is_compat (right , left )
818+ return new_is_compat
819+
820+
774821def unify_generic_callable (type : CallableType , target : CallableType ,
775822 ignore_return : bool ) -> Optional [CallableType ]:
776823 """Try to unify a generic callable type with another callable type.
@@ -913,10 +960,7 @@ def visit_type_var(self, left: TypeVarType) -> bool:
913960 def visit_callable_type (self , left : CallableType ) -> bool :
914961 right = self .right
915962 if isinstance (right , CallableType ):
916- return is_callable_subtype (
917- left , right ,
918- ignore_pos_arg_names = False ,
919- use_proper_subtype = True )
963+ return is_callable_compatible (left , right , is_compat = is_proper_subtype )
920964 elif isinstance (right , Overloaded ):
921965 return all (is_proper_subtype (left , item )
922966 for item in right .items ())
0 commit comments