Skip to content

Commit 9e85f9b

Browse files
authored
[dataclass_transform] Support default parameters (#14580)
PEP 681 defines several parameters for `typing.dataclass_transform`. This commit adds support for collecting these arguments and forwarding them to the dataclasses plugin. For this first iteration, only the `*_default` parameters are supported; `field_specifiers` will be implemented in a separate commit, since it is more complicated.
1 parent f505614 commit 9e85f9b

10 files changed

+352
-75
lines changed

mypy/nodes.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -480,13 +480,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
480480
return visitor.visit_import_all(self)
481481

482482

483-
FUNCBASE_FLAGS: Final = [
484-
"is_property",
485-
"is_class",
486-
"is_static",
487-
"is_final",
488-
"is_dataclass_transform",
489-
]
483+
FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"]
490484

491485

492486
class FuncBase(Node):
@@ -512,7 +506,6 @@ class FuncBase(Node):
512506
"is_static", # Uses "@staticmethod"
513507
"is_final", # Uses "@final"
514508
"_fullname",
515-
"is_dataclass_transform", # Is decorated with "@typing.dataclass_transform" or similar
516509
)
517510

518511
def __init__(self) -> None:
@@ -531,7 +524,6 @@ def __init__(self) -> None:
531524
self.is_final = False
532525
# Name with module prefix
533526
self._fullname = ""
534-
self.is_dataclass_transform = False
535527

536528
@property
537529
@abstractmethod
@@ -758,6 +750,8 @@ class FuncDef(FuncItem, SymbolNode, Statement):
758750
"deco_line",
759751
"is_trivial_body",
760752
"is_mypy_only",
753+
# Present only when a function is decorated with @typing.datasclass_transform or similar
754+
"dataclass_transform_spec",
761755
)
762756

763757
__match_args__ = ("name", "arguments", "type", "body")
@@ -785,6 +779,7 @@ def __init__(
785779
self.deco_line: int | None = None
786780
# Definitions that appear in if TYPE_CHECKING are marked with this flag.
787781
self.is_mypy_only = False
782+
self.dataclass_transform_spec: DataclassTransformSpec | None = None
788783

789784
@property
790785
def name(self) -> str:
@@ -810,6 +805,11 @@ def serialize(self) -> JsonDict:
810805
"flags": get_flags(self, FUNCDEF_FLAGS),
811806
"abstract_status": self.abstract_status,
812807
# TODO: Do we need expanded, original_def?
808+
"dataclass_transform_spec": (
809+
None
810+
if self.dataclass_transform_spec is None
811+
else self.dataclass_transform_spec.serialize()
812+
),
813813
}
814814

815815
@classmethod
@@ -832,6 +832,11 @@ def deserialize(cls, data: JsonDict) -> FuncDef:
832832
ret.arg_names = data["arg_names"]
833833
ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]]
834834
ret.abstract_status = data["abstract_status"]
835+
ret.dataclass_transform_spec = (
836+
DataclassTransformSpec.deserialize(data["dataclass_transform_spec"])
837+
if data["dataclass_transform_spec"] is not None
838+
else None
839+
)
835840
# Leave these uninitialized so that future uses will trigger an error
836841
del ret.arguments
837842
del ret.max_pos
@@ -3857,6 +3862,56 @@ def deserialize(cls, data: JsonDict) -> SymbolTable:
38573862
return st
38583863

38593864

3865+
class DataclassTransformSpec:
3866+
"""Specifies how a dataclass-like transform should be applied. The fields here are based on the
3867+
parameters accepted by `typing.dataclass_transform`."""
3868+
3869+
__slots__ = (
3870+
"eq_default",
3871+
"order_default",
3872+
"kw_only_default",
3873+
"frozen_default",
3874+
"field_specifiers",
3875+
)
3876+
3877+
def __init__(
3878+
self,
3879+
*,
3880+
eq_default: bool | None = None,
3881+
order_default: bool | None = None,
3882+
kw_only_default: bool | None = None,
3883+
field_specifiers: tuple[str, ...] | None = None,
3884+
# Specified outside of PEP 681:
3885+
# frozen_default was added to CPythonin https://github.com/python/cpython/pull/99958 citing
3886+
# positive discussion in typing-sig
3887+
frozen_default: bool | None = None,
3888+
):
3889+
self.eq_default = eq_default if eq_default is not None else True
3890+
self.order_default = order_default if order_default is not None else False
3891+
self.kw_only_default = kw_only_default if kw_only_default is not None else False
3892+
self.frozen_default = frozen_default if frozen_default is not None else False
3893+
self.field_specifiers = field_specifiers if field_specifiers is not None else ()
3894+
3895+
def serialize(self) -> JsonDict:
3896+
return {
3897+
"eq_default": self.eq_default,
3898+
"order_default": self.order_default,
3899+
"kw_only_default": self.kw_only_default,
3900+
"frozen_only_default": self.frozen_default,
3901+
"field_specifiers": self.field_specifiers,
3902+
}
3903+
3904+
@classmethod
3905+
def deserialize(cls, data: JsonDict) -> DataclassTransformSpec:
3906+
return DataclassTransformSpec(
3907+
eq_default=data.get("eq_default"),
3908+
order_default=data.get("order_default"),
3909+
kw_only_default=data.get("kw_only_default"),
3910+
frozen_default=data.get("frozen_default"),
3911+
field_specifiers=data.get("field_specifiers"),
3912+
)
3913+
3914+
38603915
def get_flags(node: Node, names: list[str]) -> list[str]:
38613916
return [name for name in names if getattr(node, name)]
38623917

mypy/plugins/dataclasses.py

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
AssignmentStmt,
1919
CallExpr,
2020
Context,
21+
DataclassTransformSpec,
2122
Expression,
2223
JsonDict,
2324
NameExpr,
25+
Node,
2426
PlaceholderNode,
2527
RefExpr,
2628
SymbolTableNode,
@@ -37,6 +39,7 @@
3739
add_method,
3840
deserialize_and_fixup_type,
3941
)
42+
from mypy.semanal_shared import find_dataclass_transform_spec
4043
from mypy.server.trigger import make_wildcard_trigger
4144
from mypy.state import state
4245
from mypy.typeops import map_type_from_supertype
@@ -56,11 +59,16 @@
5659

5760
# The set of decorators that generate dataclasses.
5861
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
59-
# The set of functions that generate dataclass fields.
60-
field_makers: Final = {"dataclasses.field"}
6162

6263

6364
SELF_TVAR_NAME: Final = "_DT"
65+
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
66+
eq_default=True,
67+
order_default=False,
68+
kw_only_default=False,
69+
frozen_default=False,
70+
field_specifiers=("dataclasses.Field", "dataclasses.field"),
71+
)
6472

6573

6674
class DataclassAttribute:
@@ -155,6 +163,7 @@ class DataclassTransformer:
155163

156164
def __init__(self, ctx: ClassDefContext) -> None:
157165
self._ctx = ctx
166+
self._spec = _get_transform_spec(ctx.reason)
158167

159168
def transform(self) -> bool:
160169
"""Apply all the necessary transformations to the underlying
@@ -172,9 +181,9 @@ def transform(self) -> bool:
172181
return False
173182
decorator_arguments = {
174183
"init": _get_decorator_bool_argument(self._ctx, "init", True),
175-
"eq": _get_decorator_bool_argument(self._ctx, "eq", True),
176-
"order": _get_decorator_bool_argument(self._ctx, "order", False),
177-
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", False),
184+
"eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default),
185+
"order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default),
186+
"frozen": _get_decorator_bool_argument(self._ctx, "frozen", self._spec.frozen_default),
178187
"slots": _get_decorator_bool_argument(self._ctx, "slots", False),
179188
"match_args": _get_decorator_bool_argument(self._ctx, "match_args", True),
180189
}
@@ -411,7 +420,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
411420

412421
# Second, collect attributes belonging to the current class.
413422
current_attr_names: set[str] = set()
414-
kw_only = _get_decorator_bool_argument(ctx, "kw_only", False)
423+
kw_only = _get_decorator_bool_argument(ctx, "kw_only", self._spec.kw_only_default)
415424
for stmt in cls.defs.body:
416425
# Any assignment that doesn't use the new type declaration
417426
# syntax can be ignored out of hand.
@@ -461,7 +470,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
461470
if self._is_kw_only_type(node_type):
462471
kw_only = True
463472

464-
has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
473+
has_field_call, field_args = self._collect_field_args(stmt.rvalue, ctx)
465474

466475
is_in_init_param = field_args.get("init")
467476
if is_in_init_param is None:
@@ -614,6 +623,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
614623
kind=MDEF, node=var, plugin_generated=True
615624
)
616625

626+
def _collect_field_args(
627+
self, expr: Expression, ctx: ClassDefContext
628+
) -> tuple[bool, dict[str, Expression]]:
629+
"""Returns a tuple where the first value represents whether or not
630+
the expression is a call to dataclass.field and the second is a
631+
dictionary of the keyword arguments that field() was called with.
632+
"""
633+
if (
634+
isinstance(expr, CallExpr)
635+
and isinstance(expr.callee, RefExpr)
636+
and expr.callee.fullname in self._spec.field_specifiers
637+
):
638+
# field() only takes keyword arguments.
639+
args = {}
640+
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
641+
if not kind.is_named():
642+
if kind.is_named(star=True):
643+
# This means that `field` is used with `**` unpacking,
644+
# the best we can do for now is not to fail.
645+
# TODO: we can infer what's inside `**` and try to collect it.
646+
message = 'Unpacking **kwargs in "field()" is not supported'
647+
else:
648+
message = '"field()" does not accept positional arguments'
649+
ctx.api.fail(message, expr)
650+
return True, {}
651+
assert name is not None
652+
args[name] = arg
653+
return True, args
654+
return False, {}
655+
617656

618657
def dataclass_tag_callback(ctx: ClassDefContext) -> None:
619658
"""Record that we have a dataclass in the main semantic analysis pass.
@@ -631,32 +670,29 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
631670
return transformer.transform()
632671

633672

634-
def _collect_field_args(
635-
expr: Expression, ctx: ClassDefContext
636-
) -> tuple[bool, dict[str, Expression]]:
637-
"""Returns a tuple where the first value represents whether or not
638-
the expression is a call to dataclass.field and the second is a
639-
dictionary of the keyword arguments that field() was called with.
673+
def _get_transform_spec(reason: Expression) -> DataclassTransformSpec:
674+
"""Find the relevant transform parameters from the decorator/parent class/metaclass that
675+
triggered the dataclasses plugin.
676+
677+
Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform
678+
function, we also use it for traditional dataclasses.dataclass classes as well for simplicity.
679+
In those cases, we return a default spec rather than one based on a call to
680+
`typing.dataclass_transform`.
640681
"""
641-
if (
642-
isinstance(expr, CallExpr)
643-
and isinstance(expr.callee, RefExpr)
644-
and expr.callee.fullname in field_makers
645-
):
646-
# field() only takes keyword arguments.
647-
args = {}
648-
for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds):
649-
if not kind.is_named():
650-
if kind.is_named(star=True):
651-
# This means that `field` is used with `**` unpacking,
652-
# the best we can do for now is not to fail.
653-
# TODO: we can infer what's inside `**` and try to collect it.
654-
message = 'Unpacking **kwargs in "field()" is not supported'
655-
else:
656-
message = '"field()" does not accept positional arguments'
657-
ctx.api.fail(message, expr)
658-
return True, {}
659-
assert name is not None
660-
args[name] = arg
661-
return True, args
662-
return False, {}
682+
if _is_dataclasses_decorator(reason):
683+
return _TRANSFORM_SPEC_FOR_DATACLASSES
684+
685+
spec = find_dataclass_transform_spec(reason)
686+
assert spec is not None, (
687+
"trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor "
688+
"decorated with typing.dataclass_transform"
689+
)
690+
return spec
691+
692+
693+
def _is_dataclasses_decorator(node: Node) -> bool:
694+
if isinstance(node, CallExpr):
695+
node = node.callee
696+
if isinstance(node, RefExpr):
697+
return node.fullname in dataclass_makers
698+
return False

mypy/semanal.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
ConditionalExpr,
100100
Context,
101101
ContinueStmt,
102+
DataclassTransformSpec,
102103
Decorator,
103104
DelStmt,
104105
DictExpr,
@@ -213,6 +214,7 @@
213214
PRIORITY_FALLBACKS,
214215
SemanticAnalyzerInterface,
215216
calculate_tuple_fallback,
217+
find_dataclass_transform_spec,
216218
has_placeholder,
217219
set_callable_name as set_callable_name,
218220
)
@@ -1523,7 +1525,7 @@ def visit_decorator(self, dec: Decorator) -> None:
15231525
elif isinstance(d, CallExpr) and refers_to_fullname(
15241526
d.callee, DATACLASS_TRANSFORM_NAMES
15251527
):
1526-
dec.func.is_dataclass_transform = True
1528+
dec.func.dataclass_transform_spec = self.parse_dataclass_transform_spec(d)
15271529
elif not dec.var.is_property:
15281530
# We have seen a "non-trivial" decorator before seeing @property, if
15291531
# we will see a @property later, give an error, as we don't support this.
@@ -1728,7 +1730,7 @@ def apply_class_plugin_hooks(self, defn: ClassDef) -> None:
17281730
# Special case: if the decorator is itself decorated with
17291731
# typing.dataclass_transform, apply the hook for the dataclasses plugin
17301732
# TODO: remove special casing here
1731-
if hook is None and is_dataclass_transform_decorator(decorator):
1733+
if hook is None and find_dataclass_transform_spec(decorator):
17321734
hook = dataclasses_plugin.dataclass_tag_callback
17331735
if hook:
17341736
hook(ClassDefContext(defn, decorator, self))
@@ -6456,6 +6458,35 @@ def set_future_import_flags(self, module_name: str) -> None:
64566458
def is_future_flag_set(self, flag: str) -> bool:
64576459
return self.modules[self.cur_mod_id].is_future_flag_set(flag)
64586460

6461+
def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSpec:
6462+
"""Build a DataclassTransformSpec from the arguments passed to the given call to
6463+
typing.dataclass_transform."""
6464+
parameters = DataclassTransformSpec()
6465+
for name, value in zip(call.arg_names, call.args):
6466+
# field_specifiers is currently the only non-boolean argument; check for it first so
6467+
# so the rest of the block can fail through to handling booleans
6468+
if name == "field_specifiers":
6469+
self.fail('"field_specifiers" support is currently unimplemented', call)
6470+
continue
6471+
6472+
boolean = self.parse_bool(value)
6473+
if boolean is None:
6474+
self.fail(f'"{name}" argument must be a True or False literal', call)
6475+
continue
6476+
6477+
if name == "eq_default":
6478+
parameters.eq_default = boolean
6479+
elif name == "order_default":
6480+
parameters.order_default = boolean
6481+
elif name == "kw_only_default":
6482+
parameters.kw_only_default = boolean
6483+
elif name == "frozen_default":
6484+
parameters.frozen_default = boolean
6485+
else:
6486+
self.fail(f'Unrecognized dataclass_transform parameter "{name}"', call)
6487+
6488+
return parameters
6489+
64596490

64606491
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
64616492
if isinstance(sig, CallableType):
@@ -6645,21 +6676,3 @@ def halt(self, reason: str = ...) -> NoReturn:
66456676
return isinstance(stmt, PassStmt) or (
66466677
isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
66476678
)
6648-
6649-
6650-
def is_dataclass_transform_decorator(node: Node | None) -> bool:
6651-
if isinstance(node, RefExpr):
6652-
return is_dataclass_transform_decorator(node.node)
6653-
if isinstance(node, CallExpr):
6654-
# Like dataclasses.dataclass, transform-based decorators can be applied either with or
6655-
# without parameters; ie, both of these forms are accepted:
6656-
#
6657-
# @typing.dataclass_transform
6658-
# class Foo: ...
6659-
# @typing.dataclass_transform(eq=True, order=True, ...)
6660-
# class Bar: ...
6661-
#
6662-
# We need to unwrap the call for the second variant.
6663-
return is_dataclass_transform_decorator(node.callee)
6664-
6665-
return isinstance(node, Decorator) and node.func.is_dataclass_transform

0 commit comments

Comments
 (0)