diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 971f636f9714d7..b4cecb41d13c46 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -392,12 +392,19 @@ def binop(self, other): __mod__ = _make_binop(ast.Mod()) __lshift__ = _make_binop(ast.LShift()) __rshift__ = _make_binop(ast.RShift()) - __or__ = _make_binop(ast.BitOr()) __xor__ = _make_binop(ast.BitXor()) __and__ = _make_binop(ast.BitAnd()) __floordiv__ = _make_binop(ast.FloorDiv()) __pow__ = _make_binop(ast.Pow()) + def __or__(self, other): + if self.__stringifier_dict__.create_unions: + return types.UnionType[self, other] + + return self.__make_new( + ast.BinOp(self.__get_ast(), ast.BitOr(), self.__convert_to_ast(other)) + ) + del _make_binop def _make_rbinop(op: ast.AST): @@ -416,12 +423,19 @@ def rbinop(self, other): __rmod__ = _make_rbinop(ast.Mod()) __rlshift__ = _make_rbinop(ast.LShift()) __rrshift__ = _make_rbinop(ast.RShift()) - __ror__ = _make_rbinop(ast.BitOr()) __rxor__ = _make_rbinop(ast.BitXor()) __rand__ = _make_rbinop(ast.BitAnd()) __rfloordiv__ = _make_rbinop(ast.FloorDiv()) __rpow__ = _make_rbinop(ast.Pow()) + def __ror__(self, other): + if self.__stringifier_dict__.create_unions: + return types.UnionType[other, self] + + return self.__make_new( + ast.BinOp(self.__convert_to_ast(other), ast.BitOr(), self.__get_ast()) + ) + del _make_rbinop def _make_compare(op): @@ -459,12 +473,13 @@ def unary_op(self): class _StringifierDict(dict): - def __init__(self, namespace, globals=None, owner=None, is_class=False): + def __init__(self, namespace, globals=None, owner=None, is_class=False, create_unions=False): super().__init__(namespace) self.namespace = namespace self.globals = globals self.owner = owner self.is_class = is_class + self.create_unions = create_unions self.stringifiers = [] def __missing__(self, key): @@ -569,7 +584,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # that returns a bool and an defined set of attributes. namespace = {**annotate.__builtins__, **annotate.__globals__} is_class = isinstance(owner, type) - globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class) + globals = _StringifierDict( + namespace, + annotate.__globals__, + owner, + is_class, + create_unions=True + ) if annotate.__closure__: freevars = annotate.__code__.co_freevars new_closure = [] diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 0890be529a7e52..6f17c85659c34e 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -115,8 +115,11 @@ def f( self.assertEqual(z_anno, support.EqualToForwardRef("some(module)", owner=f)) alpha_anno = anno["alpha"] - self.assertIsInstance(alpha_anno, ForwardRef) - self.assertEqual(alpha_anno, support.EqualToForwardRef("some | obj", owner=f)) + self.assertIsInstance(alpha_anno, Union) + self.assertEqual( + typing.get_args(alpha_anno), + (support.EqualToForwardRef("some", owner=f), support.EqualToForwardRef("obj", owner=f)) + ) beta_anno = anno["beta"] self.assertIsInstance(beta_anno, ForwardRef) @@ -126,6 +129,27 @@ def f( self.assertIsInstance(gamma_anno, ForwardRef) self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", owner=f)) + def test_partially_nonexistent_union(self): + # Test unions with '|' syntax equal unions with typing.Union[] with some forwardrefs + class UnionForwardrefs: + pipe: str | undefined + union: Union[str, undefined] + + annos = get_annotations(UnionForwardrefs, format=Format.FORWARDREF) + + match = ( + str, + support.EqualToForwardRef("undefined", is_class=True, owner=UnionForwardrefs) + ) + + self.assertEqual( + typing.get_args(annos["pipe"]), + typing.get_args(annos["union"]) + ) + + self.assertEqual(typing.get_args(annos["pipe"]), match) + self.assertEqual(typing.get_args(annos["union"]), match) + class TestSourceFormat(unittest.TestCase): def test_closure(self):