Skip to content

Commit 211d182

Browse files
authored
Format suggestions using the "signature" pyannotate format instead (#7415)
This bypasses pyannotate's parsing and processing of annotations and allows us to just specify them directly. This is useful because pyannotate doesn't support Callable.
1 parent 92a56db commit 211d182

File tree

2 files changed

+53
-24
lines changed

2 files changed

+53
-24
lines changed

mypy/suggestions.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
from typing import (
2626
List, Optional, Tuple, Dict, Callable, Union, NamedTuple, TypeVar, Iterator,
2727
)
28+
from typing_extensions import TypedDict
2829

2930
from mypy.state import strict_optional_set
3031
from mypy.types import (
3132
Type, AnyType, TypeOfAny, CallableType, UnionType, NoneType, Instance, TupleType,
3233
TypeVarType, FunctionLike,
3334
TypeStrVisitor, TypeTranslator,
34-
is_optional, ProperType, get_proper_type,
35+
is_optional, remove_optional, ProperType, get_proper_type,
3536
)
3637
from mypy.build import State, Graph
3738
from mypy.nodes import (
@@ -56,6 +57,11 @@
5657
import json
5758
import os
5859

60+
61+
PyAnnotateSignature = TypedDict('PyAnnotateSignature',
62+
{'return_type': str, 'arg_types': List[str]})
63+
64+
5965
Callsite = NamedTuple(
6066
'Callsite',
6167
[('path', str),
@@ -175,7 +181,7 @@ def suggest(self, function: str) -> str:
175181
if self.give_json:
176182
return self.json_suggestion(function, suggestion)
177183
else:
178-
return suggestion
184+
return self.format_signature(suggestion)
179185

180186
def suggest_callsites(self, function: str) -> str:
181187
"""Find a list of call sites of function."""
@@ -247,7 +253,11 @@ def get_args(self, is_method: bool,
247253
if default:
248254
all_arg_types.append(default)
249255

250-
if all_arg_types:
256+
if len(all_arg_types) == 1 and isinstance(get_proper_type(all_arg_types[0]), NoneType):
257+
types.append(
258+
[UnionType.make_union([all_arg_types[0],
259+
AnyType(TypeOfAny.explicit)])])
260+
elif all_arg_types:
251261
types.append(generate_type_combinations(all_arg_types))
252262
else:
253263
# If we don't have anything, we'll try Any and object
@@ -311,7 +321,7 @@ def find_best(self, func: FuncDef, guesses: List[CallableType]) -> Tuple[Callabl
311321
key=lambda s: (count_errors(errors[s]), self.score_callable(s)))
312322
return best, count_errors(errors[best])
313323

314-
def get_suggestion(self, function: str) -> str:
324+
def get_suggestion(self, function: str) -> PyAnnotateSignature:
315325
"""Compute a suggestion for a function.
316326
317327
Return the type and whether the first argument should be ignored.
@@ -355,7 +365,7 @@ def get_suggestion(self, function: str) -> str:
355365
if self.no_errors and errors:
356366
raise SuggestionFailure("No annotation without errors")
357367

358-
return self.format_callable(mod, is_method, best)
368+
return self.pyannotate_signature(mod, is_method, best)
359369

360370
def format_args(self,
361371
arg_kinds: List[List[int]],
@@ -483,7 +493,7 @@ def ensure_loaded(self, state: State) -> MypyFile:
483493
def builtin_type(self, s: str) -> Instance:
484494
return self.manager.semantic_analyzer.builtin_type(s)
485495

486-
def json_suggestion(self, function: str, suggestion: str) -> str:
496+
def json_suggestion(self, function: str, suggestion: PyAnnotateSignature) -> str:
487497
"""Produce a json blob for a suggestion suitable for application by pyannotate."""
488498
mod, func_name, node = self.find_node(function)
489499
# pyannotate irritatingly drops class names for class and static methods
@@ -496,22 +506,33 @@ def json_suggestion(self, function: str, suggestion: str) -> str:
496506
path = os.path.abspath(self.graph[mod].xpath)
497507

498508
obj = {
499-
'type_comments': [suggestion],
509+
'signature': suggestion,
500510
'line': node.line,
501511
'path': path,
502512
'func_name': func_name,
503513
'samples': 0
504514
}
505515
return json.dumps([obj], sort_keys=True)
506516

507-
def format_callable(self,
508-
cur_module: Optional[str], is_method: bool, typ: CallableType) -> str:
509-
"""Format a callable type in a way suitable as an annotation... kind of"""
517+
def pyannotate_signature(
518+
self,
519+
cur_module: Optional[str],
520+
is_method: bool,
521+
typ: CallableType
522+
) -> PyAnnotateSignature:
523+
"""Format a callable type as a pyannotate dict"""
510524
start = int(is_method)
511-
s = "({}) -> {}".format(
512-
", ".join([self.format_type(cur_module, t) for t in typ.arg_types[start:]]),
513-
self.format_type(cur_module, typ.ret_type))
514-
return s
525+
return {
526+
'arg_types': [self.format_type(cur_module, t) for t in typ.arg_types[start:]],
527+
'return_type': self.format_type(cur_module, typ.ret_type),
528+
}
529+
530+
def format_signature(self, sig: PyAnnotateSignature) -> str:
531+
"""Format a callable type in a way suitable as an annotation... kind of"""
532+
return "({}) -> {}".format(
533+
", ".join(sig['arg_types']),
534+
sig['return_type']
535+
)
515536

516537
def format_type(self, cur_module: Optional[str], typ: Type) -> str:
517538
return typ.accept(TypeFormatter(cur_module, self.graph))
@@ -618,7 +639,9 @@ def visit_instance(self, t: Instance) -> str:
618639
elif t.args != []:
619640
obj += '[{}]'.format(self.list_str(t.args))
620641

621-
if mod == 'builtins':
642+
if mod_obj == ('builtins', 'unicode'):
643+
return 'Text'
644+
elif mod == 'builtins':
622645
return obj
623646
else:
624647
delim = '.' if '.' not in obj else ':'
@@ -632,6 +655,12 @@ def visit_tuple_type(self, t: TupleType) -> str:
632655
s = self.list_str(t.items)
633656
return 'Tuple[{}]'.format(s)
634657

658+
def visit_union_type(self, t: UnionType) -> str:
659+
if len(t.items) == 2 and is_optional(t):
660+
return "Optional[{}]".format(remove_optional(t).accept(self))
661+
else:
662+
return super().visit_union_type(t)
663+
635664
def visit_callable_type(self, t: CallableType) -> str:
636665
# TODO: use extended callables?
637666
if is_tricky_callable(t):

test-data/unit/fine-grained-suggest.test

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def untyped(x) -> None:
6464
foo(x)
6565
[builtins fixtures/isinstancelist.pyi]
6666
[out]
67-
(Union[str, int, None], Union[int, None]) -> None
67+
(Union[str, int, None], Optional[int]) -> None
6868
==
6969

7070
[case testSuggestInferFunc2]
@@ -100,7 +100,7 @@ from foo import bar
100100
bar(None)
101101
[out]
102102
(Any) -> Any
103-
(None) -> None
103+
(Optional[Any]) -> None
104104
==
105105

106106
[case testSuggestInferFuncAny2]
@@ -185,7 +185,7 @@ def foo(s):
185185
from foo import foo
186186
foo('lol')
187187
[out]
188-
(unicode) -> unicode
188+
(Text) -> Text
189189
==
190190

191191
[case testSuggestInferMethod1]
@@ -211,7 +211,7 @@ def bar() -> None:
211211
x.foo(None)
212212
[builtins fixtures/isinstancelist.pyi]
213213
[out]
214-
(Union[str, int, None], Union[int, None]) -> object
214+
(Union[str, int, None], Optional[int]) -> object
215215
==
216216

217217
[case testSuggestInferMethod2]
@@ -238,7 +238,7 @@ def bar() -> None:
238238
a: Union[str, int] = x.foo(None)
239239
[builtins fixtures/isinstancelist.pyi]
240240
[out]
241-
(Union[str, int, None], Union[int, None]) -> Union[int, str]
241+
(Union[str, int, None], Optional[int]) -> Union[int, str]
242242
==
243243

244244
[case testSuggestInferMethod3]
@@ -259,7 +259,7 @@ def bar() -> None:
259259
[builtins fixtures/isinstancelist.pyi]
260260
[out]
261261
==
262-
(Union[str, None]) -> None
262+
(Optional[str]) -> None
263263

264264
[case testSuggestInferFunctionUnreachable]
265265
# flags: --strict-optional
@@ -304,7 +304,7 @@ def bar() -> None:
304304
[builtins fixtures/isinstancelist.pyi]
305305
[out]
306306
==
307-
(Union[str, int, None], Union[int, None]) -> Union[int, str]
307+
(Union[str, int, None], Optional[int]) -> Union[int, str]
308308

309309
[case testSuggestInferNestedMethod]
310310
# flags: --strict-optional
@@ -405,8 +405,8 @@ def baz() -> None:
405405
foo(None)
406406
[builtins fixtures/isinstancelist.pyi]
407407
[out]
408-
(Union[str, int, None], Union[int, None]) -> object
409-
(Union[str, int, None], Union[int, None]) -> None
408+
(Union[str, int, None], Optional[int]) -> object
409+
(Union[str, int, None], Optional[int]) -> None
410410
==
411411

412412
[case testSuggestInferFuncDecorator1]

0 commit comments

Comments
 (0)