Skip to content

Commit 4635a8c

Browse files
authored
[dataclass_transform] support field_specifiers (#14667)
These are analogous to `dataclasses.field`/`dataclasses.Field`. Like most dataclass_transform features so far, this commit mostly just plumbs through the necessary metadata so that we can re-use the existing `dataclasses` plugin logic. It also adds support for the `alias=` and `factory=` kwargs for fields, which are small; we rely on typeshed to enforce that these aren't used with `dataclasses.field`.
1 parent ec511c6 commit 4635a8c

File tree

5 files changed

+196
-5
lines changed

5 files changed

+196
-5
lines changed

mypy/message_registry.py

+4
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
270270
CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"'
271271
MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern'
272272
CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"'
273+
274+
DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
275+
'"alias" argument to dataclass field must be a string literal'
276+
)

mypy/plugin.py

+4
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ def parse_bool(self, expr: Expression) -> bool | None:
297297
"""Parse True/False literals."""
298298
raise NotImplementedError
299299

300+
@abstractmethod
301+
def parse_str_literal(self, expr: Expression) -> str | None:
302+
"""Parse string literals."""
303+
300304
@abstractmethod
301305
def fail(
302306
self,

mypy/plugins/dataclasses.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Optional
66
from typing_extensions import Final
77

8+
from mypy import errorcodes, message_registry
89
from mypy.expandtype import expand_type
910
from mypy.nodes import (
1011
ARG_NAMED,
@@ -77,6 +78,7 @@ class DataclassAttribute:
7778
def __init__(
7879
self,
7980
name: str,
81+
alias: str | None,
8082
is_in_init: bool,
8183
is_init_var: bool,
8284
has_default: bool,
@@ -87,6 +89,7 @@ def __init__(
8789
kw_only: bool,
8890
) -> None:
8991
self.name = name
92+
self.alias = alias
9093
self.is_in_init = is_in_init
9194
self.is_init_var = is_init_var
9295
self.has_default = has_default
@@ -121,12 +124,13 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
121124
return self.type
122125

123126
def to_var(self, current_info: TypeInfo) -> Var:
124-
return Var(self.name, self.expand_type(current_info))
127+
return Var(self.alias or self.name, self.expand_type(current_info))
125128

126129
def serialize(self) -> JsonDict:
127130
assert self.type
128131
return {
129132
"name": self.name,
133+
"alias": self.alias,
130134
"is_in_init": self.is_in_init,
131135
"is_init_var": self.is_init_var,
132136
"has_default": self.has_default,
@@ -495,7 +499,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
495499
# Ensure that something like x: int = field() is rejected
496500
# after an attribute with a default.
497501
if has_field_call:
498-
has_default = "default" in field_args or "default_factory" in field_args
502+
has_default = (
503+
"default" in field_args
504+
or "default_factory" in field_args
505+
# alias for default_factory defined in PEP 681
506+
or "factory" in field_args
507+
)
499508

500509
# All other assignments are already type checked.
501510
elif not isinstance(stmt.rvalue, TempNode):
@@ -511,7 +520,11 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
511520
# kw_only value from the decorator parameter.
512521
field_kw_only_param = field_args.get("kw_only")
513522
if field_kw_only_param is not None:
514-
is_kw_only = bool(self._api.parse_bool(field_kw_only_param))
523+
value = self._api.parse_bool(field_kw_only_param)
524+
if value is not None:
525+
is_kw_only = value
526+
else:
527+
self._api.fail('"kw_only" argument must be a boolean literal', stmt.rvalue)
515528

516529
if sym.type is None and node.is_final and node.is_inferred:
517530
# This is a special case, assignment like x: Final = 42 is classified
@@ -529,9 +542,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
529542
)
530543
node.type = AnyType(TypeOfAny.from_error)
531544

545+
alias = None
546+
if "alias" in field_args:
547+
alias = self._api.parse_str_literal(field_args["alias"])
548+
if alias is None:
549+
self._api.fail(
550+
message_registry.DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL,
551+
stmt.rvalue,
552+
code=errorcodes.LITERAL_REQ,
553+
)
554+
532555
current_attr_names.add(lhs.name)
533556
found_attrs[lhs.name] = DataclassAttribute(
534557
name=lhs.name,
558+
alias=alias,
535559
is_in_init=is_in_init,
536560
is_init_var=is_init_var,
537561
has_default=has_default,
@@ -624,6 +648,14 @@ def _is_kw_only_type(self, node: Type | None) -> bool:
624648
return node_type.type.fullname == "dataclasses.KW_ONLY"
625649

626650
def _add_dataclass_fields_magic_attribute(self) -> None:
651+
# Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform
652+
# classes.
653+
# It would be nice if this condition were reified rather than using an `is` check.
654+
# Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform
655+
# classes.
656+
if self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
657+
return
658+
627659
attr_name = "__dataclass_fields__"
628660
any_type = AnyType(TypeOfAny.explicit)
629661
field_type = self._api.named_type_or_none("dataclasses.Field", [any_type]) or any_type
@@ -657,6 +689,12 @@ def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Express
657689
# the best we can do for now is not to fail.
658690
# TODO: we can infer what's inside `**` and try to collect it.
659691
message = 'Unpacking **kwargs in "field()" is not supported'
692+
elif self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
693+
# dataclasses.field can only be used with keyword args, but this
694+
# restriction is only enforced for the *standardized* arguments to
695+
# dataclass_transform field specifiers. If this is not a
696+
# dataclasses.dataclass class, we can just skip positional args safely.
697+
continue
660698
else:
661699
message = '"field()" does not accept positional arguments'
662700
self._api.fail(message, expr)

mypy/semanal.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@
236236
remove_dups,
237237
type_constructors,
238238
)
239-
from mypy.typeops import function_type, get_type_vars
239+
from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type
240240
from mypy.types import (
241241
ASSERT_TYPE_NAMES,
242242
DATACLASS_TRANSFORM_NAMES,
@@ -6462,6 +6462,17 @@ def parse_bool(self, expr: Expression) -> bool | None:
64626462
return False
64636463
return None
64646464

6465+
def parse_str_literal(self, expr: Expression) -> str | None:
6466+
"""Attempt to find the string literal value of the given expression. Returns `None` if no
6467+
literal value can be found."""
6468+
if isinstance(expr, StrExpr):
6469+
return expr.value
6470+
if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.type is not None:
6471+
values = try_getting_str_literals_from_type(expr.node.type)
6472+
if values is not None and len(values) == 1:
6473+
return values[0]
6474+
return None
6475+
64656476
def set_future_import_flags(self, module_name: str) -> None:
64666477
if module_name in FUTURE_IMPORTS:
64676478
self.modules[self.cur_mod_id].future_import_flags.add(FUTURE_IMPORTS[module_name])
@@ -6482,7 +6493,9 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp
64826493
# field_specifiers is currently the only non-boolean argument; check for it first so
64836494
# so the rest of the block can fail through to handling booleans
64846495
if name == "field_specifiers":
6485-
self.fail('"field_specifiers" support is currently unimplemented', call)
6496+
parameters.field_specifiers = self.parse_dataclass_transform_field_specifiers(
6497+
value
6498+
)
64866499
continue
64876500

64886501
boolean = require_bool_literal_argument(self, value, name)
@@ -6502,6 +6515,19 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp
65026515

65036516
return parameters
65046517

6518+
def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[str, ...]:
6519+
if not isinstance(arg, TupleExpr):
6520+
self.fail('"field_specifiers" argument must be a tuple literal', arg)
6521+
return tuple()
6522+
6523+
names = []
6524+
for specifier in arg.items:
6525+
if not isinstance(specifier, RefExpr):
6526+
self.fail('"field_specifiers" must only contain identifiers', specifier)
6527+
return tuple()
6528+
names.append(specifier.fullname)
6529+
return tuple(names)
6530+
65056531

65066532
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
65076533
if isinstance(sig, CallableType):

test-data/unit/check-dataclass-transform.test

+119
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,125 @@ Foo(5)
210210
[typing fixtures/typing-full.pyi]
211211
[builtins fixtures/dataclasses.pyi]
212212

213+
[case testDataclassTransformFieldSpecifierRejectMalformed]
214+
# flags: --python-version 3.11
215+
from typing import dataclass_transform, Any, Callable, Final, Type
216+
217+
def some_type() -> Type: ...
218+
def some_function() -> Callable[[], None]: ...
219+
220+
def field(*args, **kwargs): ...
221+
def fields_tuple() -> tuple[type | Callable[..., Any], ...]: return (field,)
222+
CONSTANT: Final = (field,)
223+
224+
@dataclass_transform(field_specifiers=(some_type(),)) # E: "field_specifiers" must only contain identifiers
225+
def bad_dataclass1() -> None: ...
226+
@dataclass_transform(field_specifiers=(some_function(),)) # E: "field_specifiers" must only contain identifiers
227+
def bad_dataclass2() -> None: ...
228+
@dataclass_transform(field_specifiers=CONSTANT) # E: "field_specifiers" argument must be a tuple literal
229+
def bad_dataclass3() -> None: ...
230+
@dataclass_transform(field_specifiers=fields_tuple()) # E: "field_specifiers" argument must be a tuple literal
231+
def bad_dataclass4() -> None: ...
232+
233+
[typing fixtures/typing-full.pyi]
234+
[builtins fixtures/dataclasses.pyi]
235+
236+
[case testDataclassTransformFieldSpecifierParams]
237+
# flags: --python-version 3.11
238+
from typing import dataclass_transform, Any, Callable, Type, Final
239+
240+
def field(
241+
*,
242+
init: bool = True,
243+
kw_only: bool = False,
244+
alias: str | None = None,
245+
default: Any | None = None,
246+
default_factory: Callable[[], Any] | None = None,
247+
factory: Callable[[], Any] | None = None,
248+
): ...
249+
@dataclass_transform(field_specifiers=(field,))
250+
def my_dataclass(cls: Type) -> Type:
251+
return cls
252+
253+
B: Final = 'b_'
254+
@my_dataclass
255+
class Foo:
256+
a: int = field(alias='a_')
257+
b: int = field(alias=B)
258+
# cannot be passed as a positional
259+
kwonly: int = field(kw_only=True, default=0)
260+
# Safe to omit from constructor, error to pass
261+
noinit: int = field(init=False, default=1)
262+
# It should be safe to call the constructor without passing any of these
263+
unused1: int = field(default=0)
264+
unused2: int = field(factory=lambda: 0)
265+
unused3: int = field(default_factory=lambda: 0)
266+
267+
Foo(a=5, b_=1) # E: Unexpected keyword argument "a" for "Foo"
268+
Foo(a_=1, b_=1, noinit=1) # E: Unexpected keyword argument "noinit" for "Foo"
269+
Foo(1, 2, 3) # E: Too many positional arguments for "Foo"
270+
foo = Foo(1, 2, kwonly=3)
271+
reveal_type(foo.noinit) # N: Revealed type is "builtins.int"
272+
reveal_type(foo.unused1) # N: Revealed type is "builtins.int"
273+
Foo(a_=5, b_=1, unused1=2, unused2=3, unused3=4)
274+
275+
def some_str() -> str: ...
276+
def some_bool() -> bool: ...
277+
@my_dataclass
278+
class Bad:
279+
bad1: int = field(alias=some_str()) # E: "alias" argument to dataclass field must be a string literal
280+
bad2: int = field(kw_only=some_bool()) # E: "kw_only" argument must be a boolean literal
281+
282+
# this metadata should only exist for dataclasses.dataclass classes
283+
Foo.__dataclass_fields__ # E: "Type[Foo]" has no attribute "__dataclass_fields__"
284+
285+
[typing fixtures/typing-full.pyi]
286+
[builtins fixtures/dataclasses.pyi]
287+
288+
[case testDataclassTransformFieldSpecifierExtraArgs]
289+
# flags: --python-version 3.11
290+
from typing import dataclass_transform
291+
292+
def field(extra1, *, kw_only=False, extra2=0): ...
293+
@dataclass_transform(field_specifiers=(field,))
294+
def my_dataclass(cls):
295+
return cls
296+
297+
@my_dataclass
298+
class Good:
299+
a: int = field(5)
300+
b: int = field(5, extra2=1)
301+
c: int = field(5, kw_only=True)
302+
303+
@my_dataclass
304+
class Bad:
305+
a: int = field(kw_only=True) # E: Missing positional argument "extra1" in call to "field"
306+
307+
[typing fixtures/typing-full.pyi]
308+
[builtins fixtures/dataclasses.pyi]
309+
310+
[case testDataclassTransformMultipleFieldSpecifiers]
311+
# flags: --python-version 3.11
312+
from typing import dataclass_transform
313+
314+
def field1(*, default: int) -> int: ...
315+
def field2(*, default: str) -> str: ...
316+
317+
@dataclass_transform(field_specifiers=(field1, field2))
318+
def my_dataclass(cls): return cls
319+
320+
@my_dataclass
321+
class Foo:
322+
a: int = field1(default=0)
323+
b: str = field2(default='hello')
324+
325+
reveal_type(Foo) # N: Revealed type is "def (a: builtins.int =, b: builtins.str =) -> __main__.Foo"
326+
Foo()
327+
Foo(a=1, b='bye')
328+
329+
[typing fixtures/typing-full.pyi]
330+
[builtins fixtures/dataclasses.pyi]
331+
213332
[case testDataclassTransformOverloadsDecoratorOnOverload]
214333
# flags: --python-version 3.11
215334
from typing import dataclass_transform, overload, Any, Callable, Type, Literal

0 commit comments

Comments
 (0)