Skip to content

Commit 23e2a51

Browse files
authored
Add a class attribute hook to the plugin system (#9881)
This adds a get_class_attribute_hook to plugins to modify attributes on classes (as opposed to the existing get_attribute_hook, which is for attributes on instances). Fixes #9645
1 parent 91e890f commit 23e2a51

File tree

5 files changed

+161
-9
lines changed

5 files changed

+161
-9
lines changed

docs/source/extending_mypy.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ fields which already exist on the class. *Exception:* if :py:meth:`__getattr__ <
198198
:py:meth:`__getattribute__ <object.__getattribute__>` is a method on the class, the hook is called for all
199199
fields which do not refer to methods.
200200

201+
**get_class_attribute_hook()** is similar to above, but for attributes on classes rather than instances.
202+
Unlike above, this does not have special casing for :py:meth:`__getattr__ <object.__getattr__>` or
203+
:py:meth:`__getattribute__ <object.__getattribute__>`.
204+
201205
**get_class_decorator_hook()** can be used to update class definition for
202206
given class decorators. For example, you can add some attributes to the class
203207
to match runtime behaviour:

mypy/checkmember.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -704,10 +704,13 @@ def analyze_class_attribute_access(itype: Instance,
704704
if override_info:
705705
info = override_info
706706

707+
fullname = '{}.{}'.format(info.fullname, name)
708+
hook = mx.chk.plugin.get_class_attribute_hook(fullname)
709+
707710
node = info.get(name)
708711
if not node:
709712
if info.fallback_to_any:
710-
return AnyType(TypeOfAny.special_form)
713+
return apply_class_attr_hook(mx, hook, AnyType(TypeOfAny.special_form))
711714
return None
712715

713716
is_decorated = isinstance(node.node, Decorator)
@@ -732,14 +735,16 @@ def analyze_class_attribute_access(itype: Instance,
732735
if info.is_enum and not (mx.is_lvalue or is_decorated or is_method):
733736
enum_class_attribute_type = analyze_enum_class_attribute_access(itype, name, mx)
734737
if enum_class_attribute_type:
735-
return enum_class_attribute_type
738+
return apply_class_attr_hook(mx, hook, enum_class_attribute_type)
736739

737740
t = node.type
738741
if t:
739742
if isinstance(t, PartialType):
740743
symnode = node.node
741744
assert isinstance(symnode, Var)
742-
return mx.chk.handle_partial_var_type(t, mx.is_lvalue, symnode, mx.context)
745+
return apply_class_attr_hook(mx, hook,
746+
mx.chk.handle_partial_var_type(t, mx.is_lvalue, symnode,
747+
mx.context))
743748

744749
# Find the class where method/variable was defined.
745750
if isinstance(node.node, Decorator):
@@ -790,7 +795,8 @@ def analyze_class_attribute_access(itype: Instance,
790795
mx.self_type, original_vars=original_vars)
791796
if not mx.is_lvalue:
792797
result = analyze_descriptor_access(result, mx)
793-
return result
798+
799+
return apply_class_attr_hook(mx, hook, result)
794800
elif isinstance(node.node, Var):
795801
mx.not_ready_callback(name, mx.context)
796802
return AnyType(TypeOfAny.special_form)
@@ -814,7 +820,7 @@ def analyze_class_attribute_access(itype: Instance,
814820
if is_decorated:
815821
assert isinstance(node.node, Decorator)
816822
if node.node.type:
817-
return node.node.type
823+
return apply_class_attr_hook(mx, hook, node.node.type)
818824
else:
819825
mx.not_ready_callback(name, mx.context)
820826
return AnyType(TypeOfAny.from_error)
@@ -826,7 +832,17 @@ def analyze_class_attribute_access(itype: Instance,
826832
# unannotated implicit class methods we do this here.
827833
if node.node.is_class:
828834
typ = bind_self(typ, is_classmethod=True)
829-
return typ
835+
return apply_class_attr_hook(mx, hook, typ)
836+
837+
838+
def apply_class_attr_hook(mx: MemberContext,
839+
hook: Optional[Callable[[AttributeContext], Type]],
840+
result: Type,
841+
) -> Optional[Type]:
842+
if hook:
843+
result = hook(AttributeContext(get_proper_type(mx.original_type),
844+
result, mx.context, mx.chk))
845+
return result
830846

831847

832848
def analyze_enum_class_attribute_access(itype: Instance,

mypy/plugin.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,10 +637,10 @@ def get_method_hook(self, fullname: str
637637

638638
def get_attribute_hook(self, fullname: str
639639
) -> Optional[Callable[[AttributeContext], Type]]:
640-
"""Adjust type of a class attribute.
640+
"""Adjust type of an instance attribute.
641641
642-
This method is called with attribute full name using the class where the attribute was
643-
defined (or Var.info.fullname for generated attributes).
642+
This method is called with attribute full name using the class of the instance where
643+
the attribute was defined (or Var.info.fullname for generated attributes).
644644
645645
For classes without __getattr__ or __getattribute__, this hook is only called for
646646
names of fields/properties (but not methods) that exist in the instance MRO.
@@ -667,6 +667,25 @@ class Derived(Base):
667667
"""
668668
return None
669669

670+
def get_class_attribute_hook(self, fullname: str
671+
) -> Optional[Callable[[AttributeContext], Type]]:
672+
"""
673+
Adjust type of a class attribute.
674+
675+
This method is called with attribute full name using the class where the attribute was
676+
defined (or Var.info.fullname for generated attributes).
677+
678+
For example:
679+
680+
class Cls:
681+
x: Any
682+
683+
Cls.x
684+
685+
get_class_attribute_hook is called with '__main__.Cls.x' as fullname.
686+
"""
687+
return None
688+
670689
def get_class_decorator_hook(self, fullname: str
671690
) -> Optional[Callable[[ClassDefContext], None]]:
672691
"""Update class definition for given class decorators.
@@ -788,6 +807,10 @@ def get_attribute_hook(self, fullname: str
788807
) -> Optional[Callable[[AttributeContext], Type]]:
789808
return self._find_hook(lambda plugin: plugin.get_attribute_hook(fullname))
790809

810+
def get_class_attribute_hook(self, fullname: str
811+
) -> Optional[Callable[[AttributeContext], Type]]:
812+
return self._find_hook(lambda plugin: plugin.get_class_attribute_hook(fullname))
813+
791814
def get_class_decorator_hook(self, fullname: str
792815
) -> Optional[Callable[[ClassDefContext], None]]:
793816
return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname))

test-data/unit/check-custom-plugin.test

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,3 +902,92 @@ reveal_type(f()) # N: Revealed type is "builtins.str"
902902
[file mypy.ini]
903903
\[mypy]
904904
plugins=<ROOT>/test-data/unit/plugins/method_in_decorator.py
905+
906+
[case testClassAttrPluginClassVar]
907+
# flags: --config-file tmp/mypy.ini
908+
909+
from typing import Type
910+
911+
class Cls:
912+
attr = 'test'
913+
unchanged = 'test'
914+
915+
reveal_type(Cls().attr) # N: Revealed type is "builtins.str"
916+
reveal_type(Cls.attr) # N: Revealed type is "builtins.int"
917+
reveal_type(Cls.unchanged) # N: Revealed type is "builtins.str"
918+
x: Type[Cls]
919+
reveal_type(x.attr) # N: Revealed type is "builtins.int"
920+
[file mypy.ini]
921+
\[mypy]
922+
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py
923+
924+
[case testClassAttrPluginMethod]
925+
# flags: --config-file tmp/mypy.ini
926+
927+
class Cls:
928+
def attr(self) -> None:
929+
pass
930+
931+
reveal_type(Cls.attr) # N: Revealed type is "builtins.int"
932+
[file mypy.ini]
933+
\[mypy]
934+
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py
935+
936+
[case testClassAttrPluginEnum]
937+
# flags: --config-file tmp/mypy.ini
938+
939+
import enum
940+
941+
class Cls(enum.Enum):
942+
attr = 'test'
943+
944+
reveal_type(Cls.attr) # N: Revealed type is "builtins.int"
945+
[file mypy.ini]
946+
\[mypy]
947+
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py
948+
949+
[case testClassAttrPluginMetaclassAnyBase]
950+
# flags: --config-file tmp/mypy.ini
951+
952+
from typing import Any, Type
953+
class M(type):
954+
attr = 'test'
955+
956+
B: Any
957+
class Cls(B, metaclass=M):
958+
pass
959+
960+
reveal_type(Cls.attr) # N: Revealed type is "builtins.int"
961+
[file mypy.ini]
962+
\[mypy]
963+
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py
964+
965+
[case testClassAttrPluginMetaclassRegularBase]
966+
# flags: --config-file tmp/mypy.ini
967+
968+
from typing import Any, Type
969+
class M(type):
970+
attr = 'test'
971+
972+
class B:
973+
attr = None
974+
975+
class Cls(B, metaclass=M):
976+
pass
977+
978+
reveal_type(Cls.attr) # N: Revealed type is "builtins.int"
979+
[file mypy.ini]
980+
\[mypy]
981+
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py
982+
983+
[case testClassAttrPluginPartialType]
984+
# flags: --config-file tmp/mypy.ini
985+
986+
class Cls:
987+
attr = None
988+
def f(self) -> int:
989+
return Cls.attr + 1
990+
991+
[file mypy.ini]
992+
\[mypy]
993+
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Callable, Optional
2+
3+
from mypy.plugin import AttributeContext, Plugin
4+
from mypy.types import Type as MypyType
5+
6+
7+
class ClassAttrPlugin(Plugin):
8+
def get_class_attribute_hook(self, fullname: str
9+
) -> Optional[Callable[[AttributeContext], MypyType]]:
10+
if fullname == '__main__.Cls.attr':
11+
return my_hook
12+
return None
13+
14+
15+
def my_hook(ctx: AttributeContext) -> MypyType:
16+
return ctx.api.named_generic_type('builtins.int', [])
17+
18+
19+
def plugin(_version: str):
20+
return ClassAttrPlugin

0 commit comments

Comments
 (0)