Skip to content

Commit 77c7ca2

Browse files
committed
ensure dataclass init fn has the correct signature
1 parent bb67682 commit 77c7ca2

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

mypy/plugins/dataclasses.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ def transform(self) -> None:
130130
add_method(
131131
ctx,
132132
'__init__',
133-
args=[attr.to_argument() for attr in attributes if attr.is_in_init],
133+
args=[attr.to_argument() for attr in attributes if attr.is_in_init
134+
and not self._is_kw_only_type(attr.type)],
134135
return_type=NoneType(),
135136
)
136137

@@ -257,8 +258,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
257258
is_init_var = True
258259
node.type = node_type.args[0]
259260

260-
if (isinstance(node_type, Instance) and
261-
node_type.type.fullname == 'dataclasses._KW_ONLY_TYPE'):
261+
if self._is_kw_only_type(node_type):
262262
kw_only = True
263263

264264
has_field_call, field_args = _collect_field_args(stmt.rvalue)
@@ -395,6 +395,15 @@ def _propertize_callables(self, attributes: List[DataclassAttribute]) -> None:
395395
var._fullname = info.fullname + '.' + var.name
396396
info.names[var.name] = SymbolTableNode(MDEF, var)
397397

398+
def _is_kw_only_type(self, node: Optional[Type]) -> bool:
399+
"""Checks if the type of the node is the KW_ONLY sentinel value."""
400+
if node is None:
401+
return False
402+
node_type = get_proper_type(node)
403+
if not isinstance(node_type, Instance):
404+
return False
405+
return node_type.type.fullname == 'dataclasses._KW_ONLY_TYPE'
406+
398407

399408
def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
400409
"""Hooks into the class typechecking process to add support for dataclasses.

test-data/unit/check-dataclasses.test

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ class D(Base):
384384
z: str
385385
a: str = "a"
386386

387+
D("Hello", "World")
388+
387389
[builtins fixtures/list.pyi]
388390

389391
[case testDataclassesClassmethods]

0 commit comments

Comments
 (0)