Skip to content

[dataclass_transform] support field_specifiers #14667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
4 changes: 4 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"'
MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern'
CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"'

DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
'"alias" argument to dataclass field must be a string literal'
)
4 changes: 4 additions & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def parse_bool(self, expr: Expression) -> bool | None:
"""Parse True/False literals."""
raise NotImplementedError

@abstractmethod
def parse_str_literal(self, expr: Expression) -> str | None:
    """Parse string literals.

    Implementations return the string value when ``expr`` can be resolved to
    a single string literal, and ``None`` otherwise.
    """
    # Mirror parse_bool: an un-overridden call must fail loudly rather than
    # silently return None, which callers would read as "not a literal".
    raise NotImplementedError

@abstractmethod
def fail(
self,
Expand Down
44 changes: 41 additions & 3 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional
from typing_extensions import Final

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -77,6 +78,7 @@ class DataclassAttribute:
def __init__(
self,
name: str,
alias: str | None,
is_in_init: bool,
is_init_var: bool,
has_default: bool,
Expand All @@ -87,6 +89,7 @@ def __init__(
kw_only: bool,
) -> None:
self.name = name
self.alias = alias
self.is_in_init = is_in_init
self.is_init_var = is_init_var
self.has_default = has_default
Expand Down Expand Up @@ -121,12 +124,13 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
return self.type

def to_var(self, current_info: TypeInfo) -> Var:
return Var(self.name, self.expand_type(current_info))
return Var(self.alias or self.name, self.expand_type(current_info))

def serialize(self) -> JsonDict:
assert self.type
return {
"name": self.name,
"alias": self.alias,
"is_in_init": self.is_in_init,
"is_init_var": self.is_init_var,
"has_default": self.has_default,
Expand Down Expand Up @@ -495,7 +499,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# Ensure that something like x: int = field() is rejected
# after an attribute with a default.
if has_field_call:
has_default = "default" in field_args or "default_factory" in field_args
has_default = (
"default" in field_args
or "default_factory" in field_args
# alias for default_factory defined in PEP 681
or "factory" in field_args
)

# All other assignments are already type checked.
elif not isinstance(stmt.rvalue, TempNode):
Expand All @@ -511,7 +520,11 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# kw_only value from the decorator parameter.
field_kw_only_param = field_args.get("kw_only")
if field_kw_only_param is not None:
is_kw_only = bool(self._api.parse_bool(field_kw_only_param))
value = self._api.parse_bool(field_kw_only_param)
if value is not None:
is_kw_only = value
else:
self._api.fail('"kw_only" argument must be a boolean literal', stmt.rvalue)

if sym.type is None and node.is_final and node.is_inferred:
# This is a special case, assignment like x: Final = 42 is classified
Expand All @@ -529,9 +542,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
)
node.type = AnyType(TypeOfAny.from_error)

alias = None
if "alias" in field_args:
alias = self._api.parse_str_literal(field_args["alias"])
if alias is None:
self._api.fail(
message_registry.DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL,
stmt.rvalue,
code=errorcodes.LITERAL_REQ,
)

current_attr_names.add(lhs.name)
found_attrs[lhs.name] = DataclassAttribute(
name=lhs.name,
alias=alias,
is_in_init=is_in_init,
is_init_var=is_init_var,
has_default=has_default,
Expand Down Expand Up @@ -624,6 +648,14 @@ def _is_kw_only_type(self, node: Type | None) -> bool:
return node_type.type.fullname == "dataclasses.KW_ONLY"

def _add_dataclass_fields_magic_attribute(self) -> None:
# Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform
# classes.
# It would be nice if this condition were reified rather than using an `is` check.
if self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
return

attr_name = "__dataclass_fields__"
any_type = AnyType(TypeOfAny.explicit)
field_type = self._api.named_type_or_none("dataclasses.Field", [any_type]) or any_type
Expand Down Expand Up @@ -657,6 +689,12 @@ def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Express
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
elif self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
# dataclasses.field can only be used with keyword args, but this
# restriction is only enforced for the *standardized* arguments to
# dataclass_transform field specifiers. If this is not a
# dataclasses.dataclass class, we can just skip positional args safely.
continue
else:
message = '"field()" does not accept positional arguments'
self._api.fail(message, expr)
Expand Down
30 changes: 28 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@
remove_dups,
type_constructors,
)
from mypy.typeops import function_type, get_type_vars
from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type
from mypy.types import (
ASSERT_TYPE_NAMES,
DATACLASS_TRANSFORM_NAMES,
Expand Down Expand Up @@ -6462,6 +6462,17 @@ def parse_bool(self, expr: Expression) -> bool | None:
return False
return None

def parse_str_literal(self, expr: Expression) -> str | None:
    """Return the string literal value of ``expr``, or ``None`` if there is none.

    A plain string expression yields its value directly. A reference to a
    variable with a known type yields a value only when that type resolves
    to exactly one string literal.
    """
    if isinstance(expr, StrExpr):
        return expr.value

    # Anything other than a reference to a typed variable has no literal value.
    if not (
        isinstance(expr, RefExpr)
        and isinstance(expr.node, Var)
        and expr.node.type is not None
    ):
        return None

    candidates = try_getting_str_literals_from_type(expr.node.type)
    if candidates is None or len(candidates) != 1:
        # Unknown or ambiguous (e.g. a union of several literals).
        return None
    return candidates[0]

def set_future_import_flags(self, module_name: str) -> None:
if module_name in FUTURE_IMPORTS:
self.modules[self.cur_mod_id].future_import_flags.add(FUTURE_IMPORTS[module_name])
Expand All @@ -6482,7 +6493,9 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp
# field_specifiers is currently the only non-boolean argument; check for it first
# so the rest of the block can fall through to handling booleans
if name == "field_specifiers":
self.fail('"field_specifiers" support is currently unimplemented', call)
parameters.field_specifiers = self.parse_dataclass_transform_field_specifiers(
value
)
continue

boolean = require_bool_literal_argument(self, value, name)
Expand All @@ -6502,6 +6515,19 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp

return parameters

def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[str, ...]:
if not isinstance(arg, TupleExpr):
self.fail('"field_specifiers" argument must be a tuple literal', arg)
return tuple()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we generate an error here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch; i was going to let typeshed handle it, but i overlooked that we need to enforce that a literal is passed


names = []
for specifier in arg.items:
if not isinstance(specifier, RefExpr):
self.fail('"field_specifiers" must only contain identifiers', specifier)
return tuple()
names.append(specifier.fullname)
return tuple(names)


def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
Expand Down
119 changes: 119 additions & 0 deletions test-data/unit/check-dataclass-transform.test
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,125 @@ Foo(5)
[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformFieldSpecifierRejectMalformed]
# flags: --python-version 3.11
from typing import dataclass_transform, Any, Callable, Final, Type

def some_type() -> Type: ...
def some_function() -> Callable[[], None]: ...

def field(*args, **kwargs): ...
def fields_tuple() -> tuple[type | Callable[..., Any], ...]: return (field,)
CONSTANT: Final = (field,)

@dataclass_transform(field_specifiers=(some_type(),)) # E: "field_specifiers" must only contain identifiers
def bad_dataclass1() -> None: ...
@dataclass_transform(field_specifiers=(some_function(),)) # E: "field_specifiers" must only contain identifiers
def bad_dataclass2() -> None: ...
@dataclass_transform(field_specifiers=CONSTANT) # E: "field_specifiers" argument must be a tuple literal
def bad_dataclass3() -> None: ...
@dataclass_transform(field_specifiers=fields_tuple()) # E: "field_specifiers" argument must be a tuple literal
def bad_dataclass4() -> None: ...

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformFieldSpecifierParams]
# flags: --python-version 3.11
from typing import dataclass_transform, Any, Callable, Type, Final

def field(
*,
init: bool = True,
kw_only: bool = False,
alias: str | None = None,
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
factory: Callable[[], Any] | None = None,
): ...
@dataclass_transform(field_specifiers=(field,))
def my_dataclass(cls: Type) -> Type:
return cls

B: Final = 'b_'
@my_dataclass
class Foo:
a: int = field(alias='a_')
b: int = field(alias=B)
# cannot be passed as a positional
kwonly: int = field(kw_only=True, default=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try passing kwonly below.

# Safe to omit from constructor, error to pass
noinit: int = field(init=False, default=1)
# It should be safe to call the constructor without passing any of these
unused1: int = field(default=0)
unused2: int = field(factory=lambda: 0)
unused3: int = field(default_factory=lambda: 0)

Foo(a=5, b_=1) # E: Unexpected keyword argument "a" for "Foo"
Foo(a_=1, b_=1, noinit=1) # E: Unexpected keyword argument "noinit" for "Foo"
Foo(1, 2, 3) # E: Too many positional arguments for "Foo"
foo = Foo(1, 2, kwonly=3)
reveal_type(foo.noinit) # N: Revealed type is "builtins.int"
reveal_type(foo.unused1) # N: Revealed type is "builtins.int"
Foo(a_=5, b_=1, unused1=2, unused2=3, unused3=4)

def some_str() -> str: ...
def some_bool() -> bool: ...
@my_dataclass
class Bad:
bad1: int = field(alias=some_str()) # E: "alias" argument to dataclass field must be a string literal
bad2: int = field(kw_only=some_bool()) # E: "kw_only" argument must be a boolean literal

# this metadata should only exist for dataclasses.dataclass classes
Foo.__dataclass_fields__ # E: "Type[Foo]" has no attribute "__dataclass_fields__"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformFieldSpecifierExtraArgs]
# flags: --python-version 3.11
from typing import dataclass_transform

def field(extra1, *, kw_only=False, extra2=0): ...
@dataclass_transform(field_specifiers=(field,))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test with more than one field specifier.

def my_dataclass(cls):
return cls

@my_dataclass
class Good:
a: int = field(5)
b: int = field(5, extra2=1)
c: int = field(5, kw_only=True)

@my_dataclass
class Bad:
a: int = field(kw_only=True) # E: Missing positional argument "extra1" in call to "field"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformMultipleFieldSpecifiers]
# flags: --python-version 3.11
from typing import dataclass_transform

def field1(*, default: int) -> int: ...
def field2(*, default: str) -> str: ...

@dataclass_transform(field_specifiers=(field1, field2))
def my_dataclass(cls): return cls

@my_dataclass
class Foo:
a: int = field1(default=0)
b: str = field2(default='hello')

reveal_type(Foo) # N: Revealed type is "def (a: builtins.int =, b: builtins.str =) -> __main__.Foo"
Foo()
Foo(a=1, b='bye')

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformOverloadsDecoratorOnOverload]
# flags: --python-version 3.11
from typing import dataclass_transform, overload, Any, Callable, Type, Literal
Expand Down