25
25
from typing import (
26
26
List , Optional , Tuple , Dict , Callable , Union , NamedTuple , TypeVar , Iterator ,
27
27
)
28
+ from typing_extensions import TypedDict
28
29
29
30
from mypy .state import strict_optional_set
30
31
from mypy .types import (
31
32
Type , AnyType , TypeOfAny , CallableType , UnionType , NoneType , Instance , TupleType ,
32
33
TypeVarType , FunctionLike ,
33
34
TypeStrVisitor , TypeTranslator ,
34
- is_optional , ProperType , get_proper_type ,
35
+ is_optional , remove_optional , ProperType , get_proper_type ,
35
36
)
36
37
from mypy .build import State , Graph
37
38
from mypy .nodes import (
56
57
import json
57
58
import os
58
59
60
+
61
+ PyAnnotateSignature = TypedDict ('PyAnnotateSignature' ,
62
+ {'return_type' : str , 'arg_types' : List [str ]})
63
+
64
+
59
65
Callsite = NamedTuple (
60
66
'Callsite' ,
61
67
[('path' , str ),
@@ -175,7 +181,7 @@ def suggest(self, function: str) -> str:
175
181
if self .give_json :
176
182
return self .json_suggestion (function , suggestion )
177
183
else :
178
- return suggestion
184
+ return self . format_signature ( suggestion )
179
185
180
186
def suggest_callsites (self , function : str ) -> str :
181
187
"""Find a list of call sites of function."""
@@ -247,7 +253,11 @@ def get_args(self, is_method: bool,
247
253
if default :
248
254
all_arg_types .append (default )
249
255
250
- if all_arg_types :
256
+ if len (all_arg_types ) == 1 and isinstance (get_proper_type (all_arg_types [0 ]), NoneType ):
257
+ types .append (
258
+ [UnionType .make_union ([all_arg_types [0 ],
259
+ AnyType (TypeOfAny .explicit )])])
260
+ elif all_arg_types :
251
261
types .append (generate_type_combinations (all_arg_types ))
252
262
else :
253
263
# If we don't have anything, we'll try Any and object
@@ -311,7 +321,7 @@ def find_best(self, func: FuncDef, guesses: List[CallableType]) -> Tuple[Callabl
311
321
key = lambda s : (count_errors (errors [s ]), self .score_callable (s )))
312
322
return best , count_errors (errors [best ])
313
323
314
- def get_suggestion (self , function : str ) -> str :
324
+ def get_suggestion (self , function : str ) -> PyAnnotateSignature :
315
325
"""Compute a suggestion for a function.
316
326
317
327
Return the type and whether the first argument should be ignored.
@@ -355,7 +365,7 @@ def get_suggestion(self, function: str) -> str:
355
365
if self .no_errors and errors :
356
366
raise SuggestionFailure ("No annotation without errors" )
357
367
358
- return self .format_callable (mod , is_method , best )
368
+ return self .pyannotate_signature (mod , is_method , best )
359
369
360
370
def format_args (self ,
361
371
arg_kinds : List [List [int ]],
@@ -483,7 +493,7 @@ def ensure_loaded(self, state: State) -> MypyFile:
483
493
def builtin_type (self , s : str ) -> Instance :
484
494
return self .manager .semantic_analyzer .builtin_type (s )
485
495
486
- def json_suggestion (self , function : str , suggestion : str ) -> str :
496
+ def json_suggestion (self , function : str , suggestion : PyAnnotateSignature ) -> str :
487
497
"""Produce a json blob for a suggestion suitable for application by pyannotate."""
488
498
mod , func_name , node = self .find_node (function )
489
499
# pyannotate irritatingly drops class names for class and static methods
@@ -496,22 +506,33 @@ def json_suggestion(self, function: str, suggestion: str) -> str:
496
506
path = os .path .abspath (self .graph [mod ].xpath )
497
507
498
508
obj = {
499
- 'type_comments ' : [ suggestion ] ,
509
+ 'signature ' : suggestion ,
500
510
'line' : node .line ,
501
511
'path' : path ,
502
512
'func_name' : func_name ,
503
513
'samples' : 0
504
514
}
505
515
return json .dumps ([obj ], sort_keys = True )
506
516
507
- def format_callable (self ,
508
- cur_module : Optional [str ], is_method : bool , typ : CallableType ) -> str :
509
- """Format a callable type in a way suitable as an annotation... kind of"""
517
+ def pyannotate_signature (
518
+ self ,
519
+ cur_module : Optional [str ],
520
+ is_method : bool ,
521
+ typ : CallableType
522
+ ) -> PyAnnotateSignature :
523
+ """Format a callable type as a pyannotate dict"""
510
524
start = int (is_method )
511
- s = "({}) -> {}" .format (
512
- ", " .join ([self .format_type (cur_module , t ) for t in typ .arg_types [start :]]),
513
- self .format_type (cur_module , typ .ret_type ))
514
- return s
525
+ return {
526
+ 'arg_types' : [self .format_type (cur_module , t ) for t in typ .arg_types [start :]],
527
+ 'return_type' : self .format_type (cur_module , typ .ret_type ),
528
+ }
529
+
530
+ def format_signature (self , sig : PyAnnotateSignature ) -> str :
531
+ """Format a callable type in a way suitable as an annotation... kind of"""
532
+ return "({}) -> {}" .format (
533
+ ", " .join (sig ['arg_types' ]),
534
+ sig ['return_type' ]
535
+ )
515
536
516
537
def format_type (self , cur_module : Optional [str ], typ : Type ) -> str :
517
538
return typ .accept (TypeFormatter (cur_module , self .graph ))
@@ -618,7 +639,9 @@ def visit_instance(self, t: Instance) -> str:
618
639
elif t .args != []:
619
640
obj += '[{}]' .format (self .list_str (t .args ))
620
641
621
- if mod == 'builtins' :
642
+ if mod_obj == ('builtins' , 'unicode' ):
643
+ return 'Text'
644
+ elif mod == 'builtins' :
622
645
return obj
623
646
else :
624
647
delim = '.' if '.' not in obj else ':'
@@ -632,6 +655,12 @@ def visit_tuple_type(self, t: TupleType) -> str:
632
655
s = self .list_str (t .items )
633
656
return 'Tuple[{}]' .format (s )
634
657
658
+ def visit_union_type (self , t : UnionType ) -> str :
659
+ if len (t .items ) == 2 and is_optional (t ):
660
+ return "Optional[{}]" .format (remove_optional (t ).accept (self ))
661
+ else :
662
+ return super ().visit_union_type (t )
663
+
635
664
def visit_callable_type (self , t : CallableType ) -> str :
636
665
# TODO: use extended callables?
637
666
if is_tricky_callable (t ):
0 commit comments