44
44
from mypy .util import split_module_names
45
45
from mypy .typevars import fill_typevars
46
46
from mypy .visitor import ExpressionVisitor
47
- from mypy .funcplugins import get_function_plugin_callbacks , PluginCallback
47
+ from mypy .plugin import Plugin , PluginContext , MethodSignatureHook
48
48
from mypy .typeanal import make_optional_type
49
49
50
50
from mypy import experiments
@@ -105,17 +105,18 @@ class ExpressionChecker(ExpressionVisitor[Type]):
105
105
type_context = None # type: List[Optional[Type]]
106
106
107
107
strfrm_checker = None # type: StringFormatterChecker
108
- function_plugins = None # type: Dict[str, PluginCallback]
108
+ plugin = None # type: Plugin
109
109
110
110
def __init__ (self ,
111
111
chk : 'mypy.checker.TypeChecker' ,
112
- msg : MessageBuilder ) -> None :
112
+ msg : MessageBuilder ,
113
+ plugin : Plugin ) -> None :
113
114
"""Construct an expression type checker."""
114
115
self .chk = chk
115
116
self .msg = msg
117
+ self .plugin = plugin
116
118
self .type_context = [None ]
117
119
self .strfrm_checker = StringFormatterChecker (self , self .chk , self .msg )
118
- self .function_plugins = get_function_plugin_callbacks (self .chk .options .python_version )
119
120
120
121
def visit_name_expr (self , e : NameExpr ) -> Type :
121
122
"""Type check a name expression.
@@ -208,11 +209,33 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
208
209
isinstance (callee_type , CallableType )
209
210
and callee_type .implicit ):
210
211
return self .msg .untyped_function_call (callee_type , e )
212
+ # Figure out the full name of the callee for plugin loopup.
213
+ object_type = None
211
214
if not isinstance (e .callee , RefExpr ):
212
215
fullname = None
213
216
else :
214
217
fullname = e .callee .fullname
215
- ret_type = self .check_call_expr_with_callee_type (callee_type , e , fullname )
218
+ if (fullname is None
219
+ and isinstance (e .callee , MemberExpr )
220
+ and isinstance (callee_type , FunctionLike )):
221
+ # For method calls we include the defining class for the method
222
+ # in the full name (example: 'typing.Mapping.get').
223
+ callee_expr_type = self .chk .type_map .get (e .callee .expr )
224
+ info = None
225
+ # TODO: Support fallbacks of other kinds of types as well?
226
+ if isinstance (callee_expr_type , Instance ):
227
+ info = callee_expr_type .type
228
+ elif isinstance (callee_expr_type , TypedDictType ):
229
+ info = callee_expr_type .fallback .type .get_containing_type_info (e .callee .name )
230
+ if info :
231
+ fullname = '{}.{}' .format (info .fullname (), e .callee .name )
232
+ object_type = callee_expr_type
233
+ # Apply plugin signature hook that may generate a better signature.
234
+ signature_hook = self .plugin .get_method_signature_hook (fullname )
235
+ if signature_hook :
236
+ callee_type = self .apply_method_signature_hook (
237
+ e , callee_type , object_type , signature_hook )
238
+ ret_type = self .check_call_expr_with_callee_type (callee_type , e , fullname , object_type )
216
239
if isinstance (ret_type , UninhabitedType ):
217
240
self .chk .binder .unreachable ()
218
241
if not allow_none_return and isinstance (ret_type , NoneTyp ):
@@ -351,8 +374,10 @@ def apply_function_plugin(self,
351
374
formal_to_actual : List [List [int ]],
352
375
args : List [Expression ],
353
376
num_formals : int ,
354
- fullname : Optional [str ]) -> Type :
355
- """Use special case logic to infer the return type for of a particular named function.
377
+ fullname : Optional [str ],
378
+ object_type : Optional [Type ],
379
+ context : Context ) -> Type :
380
+ """Use special case logic to infer the return type of a specific named function/method.
356
381
357
382
Return the inferred return type.
358
383
"""
@@ -362,41 +387,90 @@ def apply_function_plugin(self,
362
387
for actual in actuals :
363
388
formal_arg_types [formal ].append (arg_types [actual ])
364
389
formal_arg_exprs [formal ].append (args [actual ])
365
- return self .function_plugins [fullname ](
366
- formal_arg_types , formal_arg_exprs , inferred_ret_type , self .chk .named_generic_type )
367
-
368
- def check_call_expr_with_callee_type (self , callee_type : Type ,
369
- e : CallExpr , callable_name : Optional [str ]) -> Type :
390
+ if object_type is None :
391
+ # Apply function plugin
392
+ callback = self .plugin .get_function_hook (fullname )
393
+ assert callback is not None # Assume that caller ensures this
394
+ return callback (formal_arg_types , formal_arg_exprs , inferred_ret_type ,
395
+ self .chk .named_generic_type )
396
+ else :
397
+ # Apply method plugin
398
+ method_callback = self .plugin .get_method_hook (fullname )
399
+ assert method_callback is not None # Assume that caller ensures this
400
+ return method_callback (object_type , formal_arg_types , formal_arg_exprs ,
401
+ inferred_ret_type , self .create_plugin_context (context ))
402
+
403
+ def apply_method_signature_hook (self , e : CallExpr , callee : FunctionLike , object_type : Type ,
404
+ signature_hook : MethodSignatureHook ) -> FunctionLike :
405
+ """Apply a plugin hook that may infer a more precise signature for a method."""
406
+ if isinstance (callee , CallableType ):
407
+ arg_kinds = e .arg_kinds
408
+ arg_names = e .arg_names
409
+ args = e .args
410
+ num_formals = len (callee .arg_kinds )
411
+ formal_to_actual = map_actuals_to_formals (
412
+ arg_kinds , arg_names ,
413
+ callee .arg_kinds , callee .arg_names ,
414
+ lambda i : self .accept (args [i ]))
415
+ formal_arg_exprs = [[] for _ in range (num_formals )] # type: List[List[Expression]]
416
+ for formal , actuals in enumerate (formal_to_actual ):
417
+ for actual in actuals :
418
+ formal_arg_exprs [formal ].append (args [actual ])
419
+ return signature_hook (object_type , formal_arg_exprs , callee ,
420
+ self .chk .named_generic_type )
421
+ else :
422
+ assert isinstance (callee , Overloaded )
423
+ items = []
424
+ for item in callee .items ():
425
+ adjusted = self .apply_method_signature_hook (e , item , object_type , signature_hook )
426
+ assert isinstance (adjusted , CallableType )
427
+ items .append (adjusted )
428
+ return Overloaded (items )
429
+
430
+ def create_plugin_context (self , context : Context ) -> PluginContext :
431
+ return PluginContext (self .chk .named_generic_type , self .msg , context )
432
+
433
+ def check_call_expr_with_callee_type (self ,
434
+ callee_type : Type ,
435
+ e : CallExpr ,
436
+ callable_name : Optional [str ],
437
+ object_type : Optional [Type ]) -> Type :
370
438
"""Type check call expression.
371
439
372
440
The given callee type overrides the type of the callee
373
441
expression.
374
442
"""
375
443
return self .check_call (callee_type , e .args , e .arg_kinds , e ,
376
444
e .arg_names , callable_node = e .callee ,
377
- callable_name = callable_name )[0 ]
445
+ callable_name = callable_name ,
446
+ object_type = object_type )[0 ]
378
447
379
448
def check_call (self , callee : Type , args : List [Expression ],
380
449
arg_kinds : List [int ], context : Context ,
381
450
arg_names : List [str ] = None ,
382
451
callable_node : Expression = None ,
383
452
arg_messages : MessageBuilder = None ,
384
- callable_name : Optional [str ] = None ) -> Tuple [Type , Type ]:
453
+ callable_name : Optional [str ] = None ,
454
+ object_type : Optional [Type ] = None ) -> Tuple [Type , Type ]:
385
455
"""Type check a call.
386
456
387
457
Also infer type arguments if the callee is a generic function.
388
458
389
459
Return (result type, inferred callee type).
390
460
391
461
Arguments:
392
- callee: type of the called value
393
- args: actual argument expressions
394
- arg_kinds: contains nodes.ARG_* constant for each argument in args
395
- describing whether the argument is positional, *arg, etc.
396
- arg_names: names of arguments (optional)
397
- callable_node: associate the inferred callable type to this node,
398
- if specified
399
- arg_messages: TODO
462
+ callee: type of the called value
463
+ args: actual argument expressions
464
+ arg_kinds: contains nodes.ARG_* constant for each argument in args
465
+ describing whether the argument is positional, *arg, etc.
466
+ arg_names: names of arguments (optional)
467
+ callable_node: associate the inferred callable type to this node,
468
+ if specified
469
+ arg_messages: TODO
470
+ callable_name: Fully-qualified name of the function/method to call,
471
+ or None if unavaiable (examples: 'builtins.open', 'typing.Mapping.get')
472
+ object_type: If callable_name refers to a method, the type of the object
473
+ on which the method is being called
400
474
"""
401
475
arg_messages = arg_messages or self .msg
402
476
if isinstance (callee , CallableType ):
@@ -443,10 +517,12 @@ def check_call(self, callee: Type, args: List[Expression],
443
517
if callable_node :
444
518
# Store the inferred callable type.
445
519
self .chk .store_type (callable_node , callee )
446
- if callable_name in self .function_plugins :
520
+
521
+ if ((object_type is None and self .plugin .get_function_hook (callable_name ))
522
+ or (object_type is not None and self .plugin .get_method_hook (callable_name ))):
447
523
ret_type = self .apply_function_plugin (
448
524
arg_types , callee .ret_type , arg_kinds , formal_to_actual ,
449
- args , len (callee .arg_types ), callable_name )
525
+ args , len (callee .arg_types ), callable_name , object_type , context )
450
526
callee = callee .copy_modified (ret_type = ret_type )
451
527
return callee .ret_type , callee
452
528
elif isinstance (callee , Overloaded ):
@@ -461,7 +537,9 @@ def check_call(self, callee: Type, args: List[Expression],
461
537
callee , context ,
462
538
messages = arg_messages )
463
539
return self .check_call (target , args , arg_kinds , context , arg_names ,
464
- arg_messages = arg_messages )
540
+ arg_messages = arg_messages ,
541
+ callable_name = callable_name ,
542
+ object_type = object_type )
465
543
elif isinstance (callee , AnyType ) or not self .chk .in_checked_function ():
466
544
self .infer_arg_types_in_context (None , args )
467
545
return AnyType (), AnyType ()
@@ -1295,8 +1373,16 @@ def check_op_local(self, method: str, base_type: Type, arg: Expression,
1295
1373
method_type = analyze_member_access (method , base_type , context , False , False , True ,
1296
1374
self .named_type , self .not_ready_callback , local_errors ,
1297
1375
original_type = base_type , chk = self .chk )
1376
+ callable_name = None
1377
+ object_type = None
1378
+ if isinstance (base_type , Instance ):
1379
+ # TODO: Find out in which class the method was defined originally?
1380
+ # TODO: Support non-Instance types.
1381
+ callable_name = '{}.{}' .format (base_type .type .fullname (), method )
1382
+ object_type = base_type
1298
1383
return self .check_call (method_type , [arg ], [nodes .ARG_POS ],
1299
- context , arg_messages = local_errors )
1384
+ context , arg_messages = local_errors ,
1385
+ callable_name = callable_name , object_type = object_type )
1300
1386
1301
1387
def check_op (self , method : str , base_type : Type , arg : Expression ,
1302
1388
context : Context ,
@@ -1769,13 +1855,14 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
1769
1855
# an error, but returns the TypedDict type that matches the literal it found
1770
1856
# that would cause a second error when that TypedDict type is returned upstream
1771
1857
# to avoid the second error, we always return TypedDict type that was requested
1772
- if isinstance (self .type_context [- 1 ], TypedDictType ):
1858
+ typeddict_context = self .find_typeddict_context (self .type_context [- 1 ])
1859
+ if typeddict_context :
1773
1860
self .check_typeddict_call_with_dict (
1774
- callee = self . type_context [ - 1 ] ,
1861
+ callee = typeddict_context ,
1775
1862
kwargs = e ,
1776
1863
context = e
1777
1864
)
1778
- return self . type_context [ - 1 ] .copy_modified ()
1865
+ return typeddict_context .copy_modified ()
1779
1866
1780
1867
# Collect function arguments, watching out for **expr.
1781
1868
args = [] # type: List[Expression] # Regular "key: value"
@@ -1826,6 +1913,21 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
1826
1913
self .check_call (method , [arg ], [nodes .ARG_POS ], arg )
1827
1914
return rv
1828
1915
1916
+ def find_typeddict_context (self , context : Type ) -> Optional [TypedDictType ]:
1917
+ if isinstance (context , TypedDictType ):
1918
+ return context
1919
+ elif isinstance (context , UnionType ):
1920
+ items = []
1921
+ for item in context .items :
1922
+ item_context = self .find_typeddict_context (item )
1923
+ if item_context :
1924
+ items .append (item_context )
1925
+ if len (items ) == 1 :
1926
+ # Only one union item is TypedDict, so use the context as it's unambiguous.
1927
+ return items [0 ]
1928
+ # No TypedDict type in context.
1929
+ return None
1930
+
1829
1931
def visit_lambda_expr (self , e : LambdaExpr ) -> Type :
1830
1932
"""Type check lambda expression."""
1831
1933
inferred_type , type_override = self .infer_lambda_type_using_context (e )
0 commit comments