Skip to content

Commit d2d1a9d

Browse files
committed
[suggest] Support refining existing type annotations
1 parent 22a5a4f commit d2d1a9d

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

mypy/suggestions.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from mypy.checkexpr import has_any_type
5252

5353
from mypy.join import join_type_list
54+
from mypy.meet import meet_types
5455
from mypy.sametypes import is_same_type
5556
from mypy.typeops import make_simplified_union
5657

@@ -240,6 +241,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
240241
AnyType(TypeOfAny.special_form),
241242
self.builtin_type('builtins.function'))
242243

244+
def get_starting_type(self, fdef: FuncDef) -> CallableType:
245+
if isinstance(fdef.type, CallableType):
246+
return fdef.type
247+
else:
248+
return self.get_trivial_type(fdef)
249+
243250
def get_args(self, is_method: bool,
244251
base: CallableType, defaults: List[Optional[Type]],
245252
callsites: List[Callsite]) -> List[List[Type]]:
@@ -263,7 +270,8 @@ def get_args(self, is_method: bool,
263270
if default:
264271
all_arg_types.append(default)
265272

266-
if len(all_arg_types) == 1 and isinstance(get_proper_type(all_arg_types[0]), NoneType):
273+
if (all_arg_types
274+
and all(isinstance(get_proper_type(tp), NoneType) for tp in all_arg_types)):
267275
types.append(
268276
[UnionType.make_union([all_arg_types[0],
269277
AnyType(TypeOfAny.explicit)])])
@@ -294,11 +302,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
294302
"""
295303
options = self.get_args(is_method, base, defaults, callsites)
296304
options = [self.add_adjustments(tps) for tps in options]
297-
return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)]
305+
return [merge_callables(base, base.copy_modified(arg_types=list(x)))
306+
for x in itertools.product(*options)]
298307

299308
def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]:
300309
"""Find all call sites of a function."""
301-
new_type = self.get_trivial_type(func)
310+
new_type = self.get_starting_type(func)
302311

303312
collector_plugin = SuggestionPlugin(func.fullname())
304313

@@ -350,7 +359,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
350359
with strict_optional_set(graph[mod].options.strict_optional):
351360
guesses = self.get_guesses(
352361
is_method,
353-
self.get_trivial_type(node),
362+
self.get_starting_type(node),
354363
self.get_default_arg_types(graph[mod], node),
355364
callsites)
356365
guesses = self.filter_options(guesses, is_method)
@@ -367,7 +376,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
367376
else:
368377
ret_types = [NoneType()]
369378

370-
guesses = [best.copy_modified(ret_type=t) for t in ret_types]
379+
guesses = [merge_callables(best, best.copy_modified(ret_type=t)) for t in ret_types]
371380
guesses = self.filter_options(guesses, is_method)
372381
best, errors = self.find_best(node, guesses)
373382

@@ -528,8 +537,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
528537
"""
529538
old = func.unanalyzed_type
530539
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
531-
# We don't modify type because it isn't necessary and it
532-
# would mess up the snapshotting.
540+
# We set type to None to ensure that the type always changes during
541+
# reprocessing.
542+
func.type = None
533543
func.unanalyzed_type = typ
534544
try:
535545
res = self.fgmanager.trigger(func.fullname())
@@ -778,6 +788,25 @@ def count_errors(msgs: List[str]) -> int:
778788
T = TypeVar('T')
779789

780790

791+
def merge_callables(t: CallableType, s: CallableType) -> CallableType:
792+
"""Merge two callable types in a way that prefers dropping Anys.
793+
794+
This is implemented by doing a meet on both the arguments and the return type,
795+
since meet(t, Any) == t.
796+
797+
This won't do perfectly with complex compound types (like
798+
callables nested inside), but it does pretty well.
799+
"""
800+
801+
# We don't want to ever squash away optionals while doing this, so set
802+
# strict optional to be true always
803+
with strict_optional_set(True):
804+
arg_types = [] # type: List[Type]
805+
for i in range(len(t.arg_types)):
806+
arg_types.append(meet_types(t.arg_types[i], s.arg_types[i]))
807+
return t.copy_modified(arg_types=arg_types, ret_type=meet_types(t.ret_type, s.ret_type))
808+
809+
781810
def dedup(old: List[T]) -> List[T]:
782811
new = [] # type: List[T]
783812
for x in old:

test-data/unit/fine-grained-suggest.test

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,3 +829,40 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse
829829
==
830830
foo.py:4: error: unexpected EOF while parsing
831831
-- )
832+
833+
[case testSuggestRefine]
834+
# suggest: foo.foo
835+
# suggest: foo.spam
836+
# suggest: foo.eggs
837+
# suggest: foo.take_l
838+
[file foo.py]
839+
from typing import Any, List
840+
841+
def bar():
842+
return 10
843+
844+
def foo(x: int, y):
845+
return x + y
846+
847+
def spam(x: int, y: Any) -> Any:
848+
return x + y
849+
850+
def eggs(x: int) -> List[Any]:
851+
a = [x]
852+
return a
853+
854+
def take_l(x: List[Any]) -> Any:
855+
return x[0]
856+
857+
858+
foo(bar(), 10)
859+
spam(bar(), 20)
860+
test = [10, 20]
861+
take_l(test)
862+
[builtins fixtures/isinstancelist.pyi]
863+
[out]
864+
(int, int) -> int
865+
(int, int) -> int
866+
(int) -> foo.List[int]
867+
(foo.List[int]) -> int
868+
==

0 commit comments

Comments
 (0)