Skip to content

Commit bb5ec6e

Browse files
JelleZijlstrasobolevncarljm
authored
gh-82129: Improve annotations for make_dataclass() (#133406)
Co-authored-by: sobolevn <[email protected]> Co-authored-by: Carl Meyer <[email protected]>
1 parent 4e498d1 commit bb5ec6e

File tree

3 files changed

+97
-12
lines changed

3 files changed

+97
-12
lines changed

Lib/dataclasses.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ def __repr__(self):
244244
property,
245245
})
246246

247+
# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
248+
# without importing `typing` module.
249+
_ANY_MARKER = object()
250+
247251

248252
class InitVar:
249253
__slots__ = ('type', )
@@ -1591,7 +1595,7 @@ class C(Base):
15911595
for item in fields:
15921596
if isinstance(item, str):
15931597
name = item
1594-
tp = 'typing.Any'
1598+
tp = _ANY_MARKER
15951599
elif len(item) == 2:
15961600
name, tp, = item
15971601
elif len(item) == 3:
@@ -1610,15 +1614,49 @@ class C(Base):
16101614
seen.add(name)
16111615
annotations[name] = tp
16121616

1617+
# We initially block the VALUE format, because inside dataclass() we'll
1618+
# call get_annotations(), which will try the VALUE format first. If we don't
1619+
# block, that means we'd always end up eagerly importing typing here, which
1620+
# is what we're trying to avoid.
1621+
value_blocked = True
1622+
1623+
def annotate_method(format):
1624+
def get_any():
1625+
match format:
1626+
case annotationlib.Format.STRING:
1627+
return 'typing.Any'
1628+
case annotationlib.Format.FORWARDREF:
1629+
typing = sys.modules.get("typing")
1630+
if typing is None:
1631+
return annotationlib.ForwardRef("Any", module="typing")
1632+
else:
1633+
return typing.Any
1634+
case annotationlib.Format.VALUE:
1635+
if value_blocked:
1636+
raise NotImplementedError
1637+
from typing import Any
1638+
return Any
1639+
case _:
1640+
raise NotImplementedError
1641+
annos = {
1642+
ann: get_any() if t is _ANY_MARKER else t
1643+
for ann, t in annotations.items()
1644+
}
1645+
if format == annotationlib.Format.STRING:
1646+
return annotationlib.annotations_to_string(annos)
1647+
else:
1648+
return annos
1649+
16131650
# Update 'ns' with the user-supplied namespace plus our calculated values.
16141651
def exec_body_callback(ns):
16151652
ns.update(namespace)
16161653
ns.update(defaults)
1617-
ns['__annotations__'] = annotations
16181654

16191655
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
16201656
# of generic dataclasses.
16211657
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
1658+
# For now, set annotations including the _ANY_MARKER.
1659+
cls.__annotate__ = annotate_method
16221660

16231661
# For pickling to work, the __module__ variable needs to be set to the frame
16241662
# where the dataclass is created.
@@ -1634,10 +1672,13 @@ def exec_body_callback(ns):
16341672
cls.__module__ = module
16351673

16361674
# Apply the normal provided decorator.
1637-
return decorator(cls, init=init, repr=repr, eq=eq, order=order,
1638-
unsafe_hash=unsafe_hash, frozen=frozen,
1639-
match_args=match_args, kw_only=kw_only, slots=slots,
1640-
weakref_slot=weakref_slot)
1675+
cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
1676+
unsafe_hash=unsafe_hash, frozen=frozen,
1677+
match_args=match_args, kw_only=kw_only, slots=slots,
1678+
weakref_slot=weakref_slot)
1679+
# Now that the class is ready, allow the VALUE format.
1680+
value_blocked = False
1681+
return cls
16411682

16421683

16431684
def replace(obj, /, **changes):

Lib/test/test_dataclasses/__init__.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from dataclasses import *
66

77
import abc
8+
import annotationlib
89
import io
910
import pickle
1011
import inspect
1112
import builtins
1213
import types
1314
import weakref
1415
import traceback
16+
import sys
1517
import textwrap
1618
import unittest
1719
from unittest.mock import Mock
@@ -25,6 +27,7 @@
2527
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
2628

2729
from test import support
30+
from test.support import import_helper
2831

2932
# Just any custom exception we can catch.
3033
class CustomError(Exception): pass
@@ -3754,7 +3757,6 @@ class A(WithDictSlot): ...
37543757
@support.cpython_only
37553758
def test_dataclass_slot_dict_ctype(self):
37563759
# https://github.com/python/cpython/issues/123935
3757-
from test.support import import_helper
37583760
# Skips test if `_testcapi` is not present:
37593761
_testcapi = import_helper.import_module('_testcapi')
37603762

@@ -4246,16 +4248,56 @@ def test_no_types(self):
42464248
C = make_dataclass('Point', ['x', 'y', 'z'])
42474249
c = C(1, 2, 3)
42484250
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
4249-
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
4250-
'y': 'typing.Any',
4251-
'z': 'typing.Any'})
4251+
self.assertEqual(C.__annotations__, {'x': typing.Any,
4252+
'y': typing.Any,
4253+
'z': typing.Any})
42524254

42534255
C = make_dataclass('Point', ['x', ('y', int), 'z'])
42544256
c = C(1, 2, 3)
42554257
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
4256-
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
4258+
self.assertEqual(C.__annotations__, {'x': typing.Any,
42574259
'y': int,
4258-
'z': 'typing.Any'})
4260+
'z': typing.Any})
4261+
4262+
def test_no_types_get_annotations(self):
4263+
C = make_dataclass('C', ['x', ('y', int), 'z'])
4264+
4265+
self.assertEqual(
4266+
annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
4267+
{'x': typing.Any, 'y': int, 'z': typing.Any},
4268+
)
4269+
self.assertEqual(
4270+
annotationlib.get_annotations(
4271+
C, format=annotationlib.Format.FORWARDREF),
4272+
{'x': typing.Any, 'y': int, 'z': typing.Any},
4273+
)
4274+
self.assertEqual(
4275+
annotationlib.get_annotations(
4276+
C, format=annotationlib.Format.STRING),
4277+
{'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
4278+
)
4279+
4280+
def test_no_types_no_typing_import(self):
4281+
with import_helper.CleanImport('typing'):
4282+
self.assertNotIn('typing', sys.modules)
4283+
C = make_dataclass('C', ['x', ('y', int)])
4284+
4285+
self.assertNotIn('typing', sys.modules)
4286+
self.assertEqual(
4287+
C.__annotate__(annotationlib.Format.FORWARDREF),
4288+
{
4289+
'x': annotationlib.ForwardRef('Any', module='typing'),
4290+
'y': int,
4291+
},
4292+
)
4293+
self.assertNotIn('typing', sys.modules)
4294+
4295+
for field in fields(C):
4296+
if field.name == "x":
4297+
self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
4298+
else:
4299+
self.assertEqual(field.name, "y")
4300+
self.assertIs(field.type, int)
42594301

42604302
def test_module_attr(self):
42614303
self.assertEqual(ByMakeDataClass.__module__, __name__)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix :exc:`NameError` when calling :func:`typing.get_type_hints` on a :func:`dataclasses.dataclass` created by
2+
:func:`dataclasses.make_dataclass` with un-annotated fields.

0 commit comments

Comments
 (0)