Skip to content

Commit 0afc115

Browse files
committed
Add support for attrs.fields
1 parent fe7007f commit 0afc115

File tree

5 files changed

+59
-2
lines changed

5 files changed

+59
-2
lines changed

mypy/plugins/attrs.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
Var,
4747
is_class_var,
4848
)
49-
from mypy.plugin import SemanticAnalyzerPluginInterface
49+
from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface
5050
from mypy.plugins.common import (
5151
_get_argument,
5252
_get_bool_argument,
@@ -1060,3 +1060,27 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
10601060
fallback=ctx.default_signature.fallback,
10611061
name=f"{ctx.default_signature.name} of {inst_type_str}",
10621062
)
1063+
1064+
1065+
def _get_cls_from_init(t: Type) -> TypeInfo | None:
1066+
if isinstance(t, CallableType):
1067+
return t.type_object()
1068+
return None
1069+
1070+
1071+
def fields_function_callback(ctx: FunctionContext) -> Type:
1072+
"""Provide the proper return value for `attrs.fields`."""
1073+
if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]:
1074+
first_arg_type = ctx.arg_types[0][0]
1075+
cls = _get_cls_from_init(first_arg_type)
1076+
if cls is not None:
1077+
if MAGIC_ATTR_NAME in cls.names:
1078+
# This is a proper attrs class.
1079+
ret_type = cls.names[MAGIC_ATTR_NAME].type
1080+
return ret_type
1081+
else:
1082+
ctx.api.fail(
1083+
f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class',
1084+
ctx.context,
1085+
)
1086+
return ctx.default_return_type

mypy/plugins/default.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ class DefaultPlugin(Plugin):
3939
"""Type checker plugin that is enabled by default."""
4040

4141
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
42-
from mypy.plugins import ctypes, singledispatch
42+
from mypy.plugins import attrs, ctypes, singledispatch
4343

4444
if fullname == "_ctypes.Array":
4545
return ctypes.array_constructor_callback
4646
elif fullname == "functools.singledispatch":
4747
return singledispatch.create_singledispatch_function_callback
48+
elif fullname in ("attr.fields", "attrs.fields"):
49+
return attrs.fields_function_callback
4850
return None
4951

5052
def get_function_signature_hook(

test-data/unit/check-plugin-attrs.test

+27
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,33 @@ takes_attrs_cls(A(1, "")) # E: Argument 1 to "takes_attrs_cls" has incompatible
15481548
takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompatible type "Type[A]"; expected "AttrsInstance" # N: ClassVar protocol member AttrsInstance.__attrs_attrs__ can never be matched by a class object
15491549
[builtins fixtures/plugin_attrs.pyi]
15501550

1551+
[case testAttrsFields]
1552+
import attr
1553+
from attrs import fields
1554+
1555+
@attr.define
1556+
class A:
1557+
b: int
1558+
c: str
1559+
1560+
reveal_type(fields(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1561+
reveal_type(fields(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1562+
reveal_type(fields(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1563+
fields(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1564+
1565+
[builtins fixtures/attr.pyi]
1566+
1567+
[case testNonattrsFields]
1568+
from attrs import fields
1569+
1570+
class A:
1571+
b: int
1572+
c: str
1573+
1574+
fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class
1575+
1576+
[builtins fixtures/attr.pyi]
1577+
15511578
[case testAttrsInitMethodAlwaysGenerates]
15521579
from typing import Tuple
15531580
import attr

test-data/unit/lib-stub/attr/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,5 @@ def field(
247247

248248
def evolve(inst: _T, **changes: Any) -> _T: ...
249249
def assoc(inst: _T, **changes: Any) -> _T: ...
250+
251+
def fields(cls: _C) -> Any: ...

test-data/unit/lib-stub/attrs/__init__.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,5 @@ def field(
129129

130130
def evolve(inst: _T, **changes: Any) -> _T: ...
131131
def assoc(inst: _T, **changes: Any) -> _T: ...
132+
133+
def fields(cls: _C) -> Any: ...

0 commit comments

Comments
 (0)