16
16
Type , Instance , CallableType , TypedDictType , UnionType , NoneTyp , TypeVarType ,
17
17
AnyType , TypeList , UnboundType , TypeOfAny , TypeType ,
18
18
)
19
+ from mypy import messages
19
20
from mypy .messages import MessageBuilder
20
21
from mypy .options import Options
21
22
import mypy .interpreted_plugin
@@ -61,6 +62,10 @@ class CheckerPluginInterface:
61
62
msg = None # type: MessageBuilder
62
63
options = None # type: Options
63
64
65
+ @abstractmethod
66
+ def fail (self , msg : str , ctx : Context ) -> None :
67
+ raise NotImplementedError
68
+
64
69
@abstractmethod
65
70
def named_generic_type (self , name : str , args : List [Type ]) -> Instance :
66
71
raise NotImplementedError
@@ -400,6 +405,14 @@ def get_method_signature_hook(self, fullname: str
400
405
401
406
if fullname == 'typing.Mapping.get' :
402
407
return typed_dict_get_signature_callback
408
+ elif fullname == 'mypy_extensions._TypedDict.setdefault' :
409
+ return typed_dict_setdefault_signature_callback
410
+ elif fullname == 'mypy_extensions._TypedDict.pop' :
411
+ return typed_dict_pop_signature_callback
412
+ elif fullname == 'mypy_extensions._TypedDict.update' :
413
+ return typed_dict_update_signature_callback
414
+ elif fullname == 'mypy_extensions._TypedDict.__delitem__' :
415
+ return typed_dict_delitem_signature_callback
403
416
elif fullname == 'ctypes.Array.__setitem__' :
404
417
return ctypes .array_setitem_callback
405
418
return None
@@ -412,6 +425,12 @@ def get_method_hook(self, fullname: str
412
425
return typed_dict_get_callback
413
426
elif fullname == 'builtins.int.__pow__' :
414
427
return int_pow_callback
428
+ elif fullname == 'mypy_extensions._TypedDict.setdefault' :
429
+ return typed_dict_setdefault_callback
430
+ elif fullname == 'mypy_extensions._TypedDict.pop' :
431
+ return typed_dict_pop_callback
432
+ elif fullname == 'mypy_extensions._TypedDict.__delitem__' :
433
+ return typed_dict_delitem_callback
415
434
elif fullname == 'ctypes.Array.__getitem__' :
416
435
return ctypes .array_getitem_callback
417
436
elif fullname == 'ctypes.Array.__iter__' :
@@ -544,6 +563,136 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
544
563
return ctx .default_return_type
545
564
546
565
566
+ def typed_dict_pop_signature_callback (ctx : MethodSigContext ) -> CallableType :
567
+ """Try to infer a better signature type for TypedDict.pop.
568
+
569
+ This is used to get better type context for the second argument that
570
+ depends on a TypedDict value type.
571
+ """
572
+ signature = ctx .default_signature
573
+ str_type = ctx .api .named_generic_type ('builtins.str' , [])
574
+ if (isinstance (ctx .type , TypedDictType )
575
+ and len (ctx .args ) == 2
576
+ and len (ctx .args [0 ]) == 1
577
+ and isinstance (ctx .args [0 ][0 ], StrExpr )
578
+ and len (signature .arg_types ) == 2
579
+ and len (signature .variables ) == 1
580
+ and len (ctx .args [1 ]) == 1 ):
581
+ key = ctx .args [0 ][0 ].value
582
+ value_type = ctx .type .items .get (key )
583
+ if value_type :
584
+ # Tweak the signature to include the value type as context. It's
585
+ # only needed for type inference since there's a union with a type
586
+ # variable that accepts everything.
587
+ tv = TypeVarType (signature .variables [0 ])
588
+ typ = UnionType .make_simplified_union ([value_type , tv ])
589
+ return signature .copy_modified (
590
+ arg_types = [str_type , typ ],
591
+ ret_type = typ )
592
+ return signature .copy_modified (arg_types = [str_type , signature .arg_types [1 ]])
593
+
594
+
595
+ def typed_dict_pop_callback (ctx : MethodContext ) -> Type :
596
+ """Type check and infer a precise return type for TypedDict.pop."""
597
+ if (isinstance (ctx .type , TypedDictType )
598
+ and len (ctx .arg_types ) >= 1
599
+ and len (ctx .arg_types [0 ]) == 1 ):
600
+ if isinstance (ctx .args [0 ][0 ], StrExpr ):
601
+ key = ctx .args [0 ][0 ].value
602
+ if key in ctx .type .required_keys :
603
+ ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
604
+ value_type = ctx .type .items .get (key )
605
+ if value_type :
606
+ if len (ctx .args [1 ]) == 0 :
607
+ return value_type
608
+ elif (len (ctx .arg_types ) == 2 and len (ctx .arg_types [1 ]) == 1
609
+ and len (ctx .args [1 ]) == 1 ):
610
+ return UnionType .make_simplified_union ([value_type , ctx .arg_types [1 ][0 ]])
611
+ else :
612
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
613
+ return AnyType (TypeOfAny .from_error )
614
+ else :
615
+ ctx .api .fail (messages .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
616
+ return AnyType (TypeOfAny .from_error )
617
+ return ctx .default_return_type
618
+
619
+
620
+ def typed_dict_setdefault_signature_callback (ctx : MethodSigContext ) -> CallableType :
621
+ """Try to infer a better signature type for TypedDict.setdefault.
622
+
623
+ This is used to get better type context for the second argument that
624
+ depends on a TypedDict value type.
625
+ """
626
+ signature = ctx .default_signature
627
+ str_type = ctx .api .named_generic_type ('builtins.str' , [])
628
+ if (isinstance (ctx .type , TypedDictType )
629
+ and len (ctx .args ) == 2
630
+ and len (ctx .args [0 ]) == 1
631
+ and isinstance (ctx .args [0 ][0 ], StrExpr )
632
+ and len (signature .arg_types ) == 2
633
+ and len (ctx .args [1 ]) == 1 ):
634
+ key = ctx .args [0 ][0 ].value
635
+ value_type = ctx .type .items .get (key )
636
+ if value_type :
637
+ return signature .copy_modified (arg_types = [str_type , value_type ])
638
+ return signature .copy_modified (arg_types = [str_type , signature .arg_types [1 ]])
639
+
640
+
641
+ def typed_dict_setdefault_callback (ctx : MethodContext ) -> Type :
642
+ """Type check TypedDict.setdefault and infer a precise return type."""
643
+ if (isinstance (ctx .type , TypedDictType )
644
+ and len (ctx .arg_types ) == 2
645
+ and len (ctx .arg_types [0 ]) == 1 ):
646
+ if isinstance (ctx .args [0 ][0 ], StrExpr ):
647
+ key = ctx .args [0 ][0 ].value
648
+ value_type = ctx .type .items .get (key )
649
+ if value_type :
650
+ return value_type
651
+ else :
652
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
653
+ return AnyType (TypeOfAny .from_error )
654
+ else :
655
+ ctx .api .fail (messages .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
656
+ return AnyType (TypeOfAny .from_error )
657
+ return ctx .default_return_type
658
+
659
+
660
+ def typed_dict_delitem_signature_callback (ctx : MethodSigContext ) -> CallableType :
661
+ # Replace NoReturn as the argument type.
662
+ str_type = ctx .api .named_generic_type ('builtins.str' , [])
663
+ return ctx .default_signature .copy_modified (arg_types = [str_type ])
664
+
665
+
666
+ def typed_dict_delitem_callback (ctx : MethodContext ) -> Type :
667
+ """Type check TypedDict.__delitem__."""
668
+ if (isinstance (ctx .type , TypedDictType )
669
+ and len (ctx .arg_types ) == 1
670
+ and len (ctx .arg_types [0 ]) == 1 ):
671
+ if isinstance (ctx .args [0 ][0 ], StrExpr ):
672
+ key = ctx .args [0 ][0 ].value
673
+ if key in ctx .type .required_keys :
674
+ ctx .api .msg .typeddict_key_cannot_be_deleted (ctx .type , key , ctx .context )
675
+ elif key not in ctx .type .items :
676
+ ctx .api .msg .typeddict_key_not_found (ctx .type , key , ctx .context )
677
+ else :
678
+ ctx .api .fail (messages .TYPEDDICT_KEY_MUST_BE_STRING_LITERAL , ctx .context )
679
+ return AnyType (TypeOfAny .from_error )
680
+ return ctx .default_return_type
681
+
682
+
683
+ def typed_dict_update_signature_callback (ctx : MethodSigContext ) -> CallableType :
684
+ """Try to infer a better signature type for TypedDict.update."""
685
+ signature = ctx .default_signature
686
+ if (isinstance (ctx .type , TypedDictType )
687
+ and len (signature .arg_types ) == 1 ):
688
+ arg_type = signature .arg_types [0 ]
689
+ assert isinstance (arg_type , TypedDictType )
690
+ arg_type = arg_type .as_anonymous ()
691
+ arg_type = arg_type .copy_modified (required_keys = set ())
692
+ return signature .copy_modified (arg_types = [arg_type ])
693
+ return signature
694
+
695
+
547
696
def int_pow_callback (ctx : MethodContext ) -> Type :
548
697
"""Infer a more precise return type for int.__pow__."""
549
698
if (len (ctx .arg_types ) == 1
0 commit comments