Skip to content

Commit 4eb41d0

Browse files
bpo-42233: Add union type expression support for GenericAlias and fix de-duplicating of GenericAlias (GH-23077)
1 parent 23831a7 commit 4eb41d0

File tree

6 files changed

+51
-17
lines changed

6 files changed

+51
-17
lines changed

Include/internal/pycore_unionobject.h

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ extern "C" {
1010

1111
PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args);
1212
PyAPI_DATA(PyTypeObject) _Py_UnionType;
13+
PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject* self, PyObject* param);
1314

1415
#ifdef __cplusplus
1516
}

Lib/test/test_types.py

+22
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,28 @@ def test_or_type_repr(self):
713713
assert repr(int | None) == "int | None"
714714
assert repr(int | typing.GenericAlias(list, int)) == "int | list[int]"
715715

716+
def test_or_type_operator_with_genericalias(self):
717+
a = list[int]
718+
b = list[str]
719+
c = dict[float, str]
720+
# equivalence with typing.Union
721+
self.assertEqual(a | b | c, typing.Union[a, b, c])
722+
# de-duplicate
723+
self.assertEqual(a | c | b | b | a | c, a | b | c)
724+
# order shouldn't matter
725+
self.assertEqual(a | b, b | a)
726+
self.assertEqual(repr(a | b | c),
727+
"list[int] | list[str] | dict[float, str]")
728+
729+
class BadType(type):
730+
def __eq__(self, other):
731+
return 1 / 0
732+
733+
bt = BadType('bt', (), {})
734+
# Comparison should fail and errors should propagate out for bad types.
735+
with self.assertRaises(ZeroDivisionError):
736+
list[int] | list[bt]
737+
716738
def test_ellipsis_type(self):
717739
self.assertIsInstance(Ellipsis, types.EllipsisType)
718740

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Allow ``GenericAlias`` objects to use :ref:`union type expressions <types-union>`.
2+
This allows expressions like ``list[int] | dict[float, str]`` where previously a
3+
``TypeError`` would have been thrown. This also fixes union type expressions
4+
not de-duplicating ``GenericAlias`` objects. (Contributed by Ken Jin in
5+
:issue:`42233`.)

Objects/genericaliasobject.c

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "Python.h"
44
#include "pycore_object.h"
5+
#include "pycore_unionobject.h" // _Py_union_as_number
56
#include "structmember.h" // PyMemberDef
67

78
typedef struct {
@@ -573,6 +574,10 @@ ga_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
573574
return Py_GenericAlias(origin, arguments);
574575
}
575576

577+
static PyNumberMethods ga_as_number = {
578+
.nb_or = (binaryfunc)_Py_union_type_or, // Add __or__ function
579+
};
580+
576581
// TODO:
577582
// - argument clinic?
578583
// - __doc__?
@@ -586,6 +591,7 @@ PyTypeObject Py_GenericAliasType = {
586591
.tp_basicsize = sizeof(gaobject),
587592
.tp_dealloc = ga_dealloc,
588593
.tp_repr = ga_repr,
594+
.tp_as_number = &ga_as_number, // allow X | Y of GenericAlias objs
589595
.tp_as_mapping = &ga_as_mapping,
590596
.tp_hash = ga_hash,
591597
.tp_call = ga_call,

Objects/typeobject.c

+2-12
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "pycore_object.h"
77
#include "pycore_pyerrors.h"
88
#include "pycore_pystate.h" // _PyThreadState_GET()
9-
#include "pycore_unionobject.h" // _Py_Union()
9+
#include "pycore_unionobject.h" // _Py_Union(), _Py_union_type_or
1010
#include "frameobject.h"
1111
#include "structmember.h" // PyMemberDef
1212

@@ -3789,19 +3789,9 @@ type_is_gc(PyTypeObject *type)
37893789
return type->tp_flags & Py_TPFLAGS_HEAPTYPE;
37903790
}
37913791

3792-
static PyObject *
3793-
type_or(PyTypeObject* self, PyObject* param) {
3794-
PyObject *tuple = PyTuple_Pack(2, self, param);
3795-
if (tuple == NULL) {
3796-
return NULL;
3797-
}
3798-
PyObject *new_union = _Py_Union(tuple);
3799-
Py_DECREF(tuple);
3800-
return new_union;
3801-
}
38023792

38033793
static PyNumberMethods type_as_number = {
3804-
.nb_or = (binaryfunc)type_or, // Add __or__ function
3794+
.nb_or = _Py_union_type_or, // Add __or__ function
38053795
};
38063796

38073797
PyTypeObject PyType_Type = {

Objects/unionobject.c

+15-5
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,19 @@ dedup_and_flatten_args(PyObject* args)
237237
PyObject* i_element = PyTuple_GET_ITEM(args, i);
238238
for (Py_ssize_t j = i + 1; j < arg_length; j++) {
239239
PyObject* j_element = PyTuple_GET_ITEM(args, j);
240-
if (i_element == j_element) {
241-
is_duplicate = 1;
240+
int is_ga = Py_TYPE(i_element) == &Py_GenericAliasType &&
241+
Py_TYPE(j_element) == &Py_GenericAliasType;
242+
// RichCompare to also deduplicate GenericAlias types (slower)
243+
is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ)
244+
: i_element == j_element;
245+
// Should only happen if RichCompare fails
246+
if (is_duplicate < 0) {
247+
Py_DECREF(args);
248+
Py_DECREF(new_args);
249+
return NULL;
242250
}
251+
if (is_duplicate)
252+
break;
243253
}
244254
if (!is_duplicate) {
245255
Py_INCREF(i_element);
@@ -290,8 +300,8 @@ is_unionable(PyObject *obj)
290300
type == &_Py_UnionType);
291301
}
292302

293-
static PyObject *
294-
type_or(PyTypeObject* self, PyObject* param)
303+
PyObject *
304+
_Py_union_type_or(PyObject* self, PyObject* param)
295305
{
296306
PyObject *tuple = PyTuple_Pack(2, self, param);
297307
if (tuple == NULL) {
@@ -404,7 +414,7 @@ static PyMethodDef union_methods[] = {
404414
{0}};
405415

406416
static PyNumberMethods union_as_number = {
407-
.nb_or = (binaryfunc)type_or, // Add __or__ function
417+
.nb_or = _Py_union_type_or, // Add __or__ function
408418
};
409419

410420
PyTypeObject _Py_UnionType = {

0 commit comments

Comments
 (0)