18
18
AssignmentStmt ,
19
19
CallExpr ,
20
20
Context ,
21
+ DataclassTransformSpec ,
21
22
Expression ,
22
23
JsonDict ,
23
24
NameExpr ,
25
+ Node ,
24
26
PlaceholderNode ,
25
27
RefExpr ,
26
28
SymbolTableNode ,
37
39
add_method ,
38
40
deserialize_and_fixup_type ,
39
41
)
42
+ from mypy .semanal_shared import find_dataclass_transform_spec
40
43
from mypy .server .trigger import make_wildcard_trigger
41
44
from mypy .state import state
42
45
from mypy .typeops import map_type_from_supertype
56
59
57
60
# The set of decorators that generate dataclasses.
58
61
dataclass_makers : Final = {"dataclass" , "dataclasses.dataclass" }
59
- # The set of functions that generate dataclass fields.
60
- field_makers : Final = {"dataclasses.field" }
61
62
62
63
63
64
SELF_TVAR_NAME : Final = "_DT"
65
+ _TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec (
66
+ eq_default = True ,
67
+ order_default = False ,
68
+ kw_only_default = False ,
69
+ frozen_default = False ,
70
+ field_specifiers = ("dataclasses.Field" , "dataclasses.field" ),
71
+ )
64
72
65
73
66
74
class DataclassAttribute :
@@ -155,6 +163,7 @@ class DataclassTransformer:
155
163
156
164
def __init__ (self , ctx : ClassDefContext ) -> None :
157
165
self ._ctx = ctx
166
+ self ._spec = _get_transform_spec (ctx .reason )
158
167
159
168
def transform (self ) -> bool :
160
169
"""Apply all the necessary transformations to the underlying
@@ -172,9 +181,9 @@ def transform(self) -> bool:
172
181
return False
173
182
decorator_arguments = {
174
183
"init" : _get_decorator_bool_argument (self ._ctx , "init" , True ),
175
- "eq" : _get_decorator_bool_argument (self ._ctx , "eq" , True ),
176
- "order" : _get_decorator_bool_argument (self ._ctx , "order" , False ),
177
- "frozen" : _get_decorator_bool_argument (self ._ctx , "frozen" , False ),
184
+ "eq" : _get_decorator_bool_argument (self ._ctx , "eq" , self . _spec . eq_default ),
185
+ "order" : _get_decorator_bool_argument (self ._ctx , "order" , self . _spec . order_default ),
186
+ "frozen" : _get_decorator_bool_argument (self ._ctx , "frozen" , self . _spec . frozen_default ),
178
187
"slots" : _get_decorator_bool_argument (self ._ctx , "slots" , False ),
179
188
"match_args" : _get_decorator_bool_argument (self ._ctx , "match_args" , True ),
180
189
}
@@ -411,7 +420,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
411
420
412
421
# Second, collect attributes belonging to the current class.
413
422
current_attr_names : set [str ] = set ()
414
- kw_only = _get_decorator_bool_argument (ctx , "kw_only" , False )
423
+ kw_only = _get_decorator_bool_argument (ctx , "kw_only" , self . _spec . kw_only_default )
415
424
for stmt in cls .defs .body :
416
425
# Any assignment that doesn't use the new type declaration
417
426
# syntax can be ignored out of hand.
@@ -461,7 +470,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
461
470
if self ._is_kw_only_type (node_type ):
462
471
kw_only = True
463
472
464
- has_field_call , field_args = _collect_field_args (stmt .rvalue , ctx )
473
+ has_field_call , field_args = self . _collect_field_args (stmt .rvalue , ctx )
465
474
466
475
is_in_init_param = field_args .get ("init" )
467
476
if is_in_init_param is None :
@@ -614,6 +623,36 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
614
623
kind = MDEF , node = var , plugin_generated = True
615
624
)
616
625
626
+ def _collect_field_args (
627
+ self , expr : Expression , ctx : ClassDefContext
628
+ ) -> tuple [bool , dict [str , Expression ]]:
629
+ """Returns a tuple where the first value represents whether or not
630
+ the expression is a call to dataclass.field and the second is a
631
+ dictionary of the keyword arguments that field() was called with.
632
+ """
633
+ if (
634
+ isinstance (expr , CallExpr )
635
+ and isinstance (expr .callee , RefExpr )
636
+ and expr .callee .fullname in self ._spec .field_specifiers
637
+ ):
638
+ # field() only takes keyword arguments.
639
+ args = {}
640
+ for name , arg , kind in zip (expr .arg_names , expr .args , expr .arg_kinds ):
641
+ if not kind .is_named ():
642
+ if kind .is_named (star = True ):
643
+ # This means that `field` is used with `**` unpacking,
644
+ # the best we can do for now is not to fail.
645
+ # TODO: we can infer what's inside `**` and try to collect it.
646
+ message = 'Unpacking **kwargs in "field()" is not supported'
647
+ else :
648
+ message = '"field()" does not accept positional arguments'
649
+ ctx .api .fail (message , expr )
650
+ return True , {}
651
+ assert name is not None
652
+ args [name ] = arg
653
+ return True , args
654
+ return False , {}
655
+
617
656
618
657
def dataclass_tag_callback (ctx : ClassDefContext ) -> None :
619
658
"""Record that we have a dataclass in the main semantic analysis pass.
@@ -631,32 +670,29 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
631
670
return transformer .transform ()
632
671
633
672
634
- def _collect_field_args (
635
- expr : Expression , ctx : ClassDefContext
636
- ) -> tuple [bool , dict [str , Expression ]]:
637
- """Returns a tuple where the first value represents whether or not
638
- the expression is a call to dataclass.field and the second is a
639
- dictionary of the keyword arguments that field() was called with.
673
+ def _get_transform_spec (reason : Expression ) -> DataclassTransformSpec :
674
+ """Find the relevant transform parameters from the decorator/parent class/metaclass that
675
+ triggered the dataclasses plugin.
676
+
677
+ Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform
678
+ function, we also use it for traditional dataclasses.dataclass classes as well for simplicity.
679
+ In those cases, we return a default spec rather than one based on a call to
680
+ `typing.dataclass_transform`.
640
681
"""
641
- if (
642
- isinstance (expr , CallExpr )
643
- and isinstance (expr .callee , RefExpr )
644
- and expr .callee .fullname in field_makers
645
- ):
646
- # field() only takes keyword arguments.
647
- args = {}
648
- for name , arg , kind in zip (expr .arg_names , expr .args , expr .arg_kinds ):
649
- if not kind .is_named ():
650
- if kind .is_named (star = True ):
651
- # This means that `field` is used with `**` unpacking,
652
- # the best we can do for now is not to fail.
653
- # TODO: we can infer what's inside `**` and try to collect it.
654
- message = 'Unpacking **kwargs in "field()" is not supported'
655
- else :
656
- message = '"field()" does not accept positional arguments'
657
- ctx .api .fail (message , expr )
658
- return True , {}
659
- assert name is not None
660
- args [name ] = arg
661
- return True , args
662
- return False , {}
682
+ if _is_dataclasses_decorator (reason ):
683
+ return _TRANSFORM_SPEC_FOR_DATACLASSES
684
+
685
+ spec = find_dataclass_transform_spec (reason )
686
+ assert spec is not None , (
687
+ "trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor "
688
+ "decorated with typing.dataclass_transform"
689
+ )
690
+ return spec
691
+
692
+
693
+ def _is_dataclasses_decorator (node : Node ) -> bool :
694
+ if isinstance (node , CallExpr ):
695
+ node = node .callee
696
+ if isinstance (node , RefExpr ):
697
+ return node .fullname in dataclass_makers
698
+ return False
0 commit comments