51
51
from mypy .checkexpr import has_any_type
52
52
53
53
from mypy .join import join_type_list
54
+ from mypy .meet import meet_types
54
55
from mypy .sametypes import is_same_type
55
56
from mypy .typeops import make_simplified_union
56
57
@@ -240,6 +241,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
240
241
AnyType (TypeOfAny .special_form ),
241
242
self .builtin_type ('builtins.function' ))
242
243
244
+ def get_starting_type (self , fdef : FuncDef ) -> CallableType :
245
+ if isinstance (fdef .type , CallableType ):
246
+ return fdef .type
247
+ else :
248
+ return self .get_trivial_type (fdef )
249
+
243
250
def get_args (self , is_method : bool ,
244
251
base : CallableType , defaults : List [Optional [Type ]],
245
252
callsites : List [Callsite ]) -> List [List [Type ]]:
@@ -263,7 +270,8 @@ def get_args(self, is_method: bool,
263
270
if default :
264
271
all_arg_types .append (default )
265
272
266
- if len (all_arg_types ) == 1 and isinstance (get_proper_type (all_arg_types [0 ]), NoneType ):
273
+ if (all_arg_types
274
+ and all (isinstance (get_proper_type (tp ), NoneType ) for tp in all_arg_types )):
267
275
types .append (
268
276
[UnionType .make_union ([all_arg_types [0 ],
269
277
AnyType (TypeOfAny .explicit )])])
@@ -294,11 +302,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
294
302
"""
295
303
options = self .get_args (is_method , base , defaults , callsites )
296
304
options = [self .add_adjustments (tps ) for tps in options ]
297
- return [base .copy_modified (arg_types = list (x )) for x in itertools .product (* options )]
305
+ return [merge_callables (base , base .copy_modified (arg_types = list (x )))
306
+ for x in itertools .product (* options )]
298
307
299
308
def get_callsites (self , func : FuncDef ) -> Tuple [List [Callsite ], List [str ]]:
300
309
"""Find all call sites of a function."""
301
- new_type = self .get_trivial_type (func )
310
+ new_type = self .get_starting_type (func )
302
311
303
312
collector_plugin = SuggestionPlugin (func .fullname ())
304
313
@@ -350,7 +359,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
350
359
with strict_optional_set (graph [mod ].options .strict_optional ):
351
360
guesses = self .get_guesses (
352
361
is_method ,
353
- self .get_trivial_type (node ),
362
+ self .get_starting_type (node ),
354
363
self .get_default_arg_types (graph [mod ], node ),
355
364
callsites )
356
365
guesses = self .filter_options (guesses , is_method )
@@ -367,7 +376,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
367
376
else :
368
377
ret_types = [NoneType ()]
369
378
370
- guesses = [best .copy_modified (ret_type = t ) for t in ret_types ]
379
+ guesses = [merge_callables ( best , best .copy_modified (ret_type = t ) ) for t in ret_types ]
371
380
guesses = self .filter_options (guesses , is_method )
372
381
best , errors = self .find_best (node , guesses )
373
382
@@ -528,8 +537,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
528
537
"""
529
538
old = func .unanalyzed_type
530
539
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
531
- # We don't modify type because it isn't necessary and it
532
- # would mess up the snapshotting.
540
+ # We set type to None to ensure that the type always changes during
541
+ # reprocessing.
542
+ func .type = None
533
543
func .unanalyzed_type = typ
534
544
try :
535
545
res = self .fgmanager .trigger (func .fullname ())
@@ -778,6 +788,25 @@ def count_errors(msgs: List[str]) -> int:
778
788
T = TypeVar ('T' )
779
789
780
790
791
+ def merge_callables (t : CallableType , s : CallableType ) -> CallableType :
792
+ """Merge two callable types in a way that prefers dropping Anys.
793
+
794
+ This is implemented by doing a meet on both the arguments and the return type,
795
+ since meet(t, Any) == t.
796
+
797
+ This won't do perfectly with complex compound types (like
798
+ callables nested inside), but it does pretty well.
799
+ """
800
+
801
+ # We don't want to ever squash away optionals while doing this, so set
802
+ # strict optional to be true always
803
+ with strict_optional_set (True ):
804
+ arg_types = [] # type: List[Type]
805
+ for i in range (len (t .arg_types )):
806
+ arg_types .append (meet_types (t .arg_types [i ], s .arg_types [i ]))
807
+ return t .copy_modified (arg_types = arg_types , ret_type = meet_types (t .ret_type , s .ret_type ))
808
+
809
+
781
810
def dedup (old : List [T ]) -> List [T ]:
782
811
new = [] # type: List[T]
783
812
for x in old :
0 commit comments