Skip to content

bpo-42233: Add union type expression support for GenericAlias and fix de-duplicating of GenericAlias #23077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Include/internal/pycore_unionobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ extern "C" {

PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args);
PyAPI_DATA(PyTypeObject) _Py_UnionType;
PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject* self, PyObject* param);

#ifdef __cplusplus
}
Expand Down
22 changes: 22 additions & 0 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,28 @@ def test_or_type_repr(self):
assert repr(int | None) == "int | None"
assert repr(int | typing.GenericAlias(list, int)) == "int | list[int]"

def test_or_type_operator_with_genericalias(self):
a = list[int]
b = list[str]
c = dict[float, str]
# equivalence with typing.Union
self.assertEqual(a | b | c, typing.Union[a, b, c])
# de-duplicate
self.assertEqual(a | c | b | b | a | c, a | b | c)
# order shouldn't matter
self.assertEqual(a | b, b | a)
self.assertEqual(repr(a | b | c),
"list[int] | list[str] | dict[float, str]")

class BadType(type):
def __eq__(self, other):
return 1 / 0

bt = BadType('bt', (), {})
# Comparison should fail and errors should propagate out for bad types.
with self.assertRaises(ZeroDivisionError):
list[int] | list[bt]

def test_ellipsis_type(self):
self.assertIsInstance(Ellipsis, types.EllipsisType)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Allow ``GenericAlias`` objects to use :ref:`union type expressions <types-union>`.
This allows expressions like ``list[int] | dict[float, str]`` where previously a
``TypeError`` would have been thrown. This also fixes union type expressions
not de-duplicating ``GenericAlias`` objects. (Contributed by Ken Jin in
:issue:`42233`.)
6 changes: 6 additions & 0 deletions Objects/genericaliasobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "Python.h"
#include "pycore_object.h"
#include "pycore_unionobject.h" // _Py_union_as_number
#include "structmember.h" // PyMemberDef

typedef struct {
Expand Down Expand Up @@ -573,6 +574,10 @@ ga_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return Py_GenericAlias(origin, arguments);
}

static PyNumberMethods ga_as_number = {
.nb_or = (binaryfunc)_Py_union_type_or, // Add __or__ function
};

// TODO:
// - argument clinic?
// - __doc__?
Expand All @@ -586,6 +591,7 @@ PyTypeObject Py_GenericAliasType = {
.tp_basicsize = sizeof(gaobject),
.tp_dealloc = ga_dealloc,
.tp_repr = ga_repr,
.tp_as_number = &ga_as_number, // allow X | Y of GenericAlias objs
.tp_as_mapping = &ga_as_mapping,
.tp_hash = ga_hash,
.tp_call = ga_call,
Expand Down
14 changes: 2 additions & 12 deletions Objects/typeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "pycore_object.h"
#include "pycore_pyerrors.h"
#include "pycore_pystate.h" // _PyThreadState_GET()
#include "pycore_unionobject.h" // _Py_Union()
#include "pycore_unionobject.h" // _Py_Union(), _Py_union_type_or
#include "frameobject.h"
#include "structmember.h" // PyMemberDef

Expand Down Expand Up @@ -3747,19 +3747,9 @@ type_is_gc(PyTypeObject *type)
return type->tp_flags & Py_TPFLAGS_HEAPTYPE;
}

static PyObject *
type_or(PyTypeObject* self, PyObject* param) {
PyObject *tuple = PyTuple_Pack(2, self, param);
if (tuple == NULL) {
return NULL;
}
PyObject *new_union = _Py_Union(tuple);
Py_DECREF(tuple);
return new_union;
}

static PyNumberMethods type_as_number = {
.nb_or = (binaryfunc)type_or, // Add __or__ function
.nb_or = _Py_union_type_or, // Add __or__ function
};

PyTypeObject PyType_Type = {
Expand Down
20 changes: 15 additions & 5 deletions Objects/unionobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,19 @@ dedup_and_flatten_args(PyObject* args)
PyObject* i_element = PyTuple_GET_ITEM(args, i);
for (Py_ssize_t j = i + 1; j < arg_length; j++) {
PyObject* j_element = PyTuple_GET_ITEM(args, j);
if (i_element == j_element) {
is_duplicate = 1;
int is_ga = Py_TYPE(i_element) == &Py_GenericAliasType &&
Py_TYPE(j_element) == &Py_GenericAliasType;
// RichCompare to also deduplicate GenericAlias types (slower)
is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ)
: i_element == j_element;
// Should only happen if RichCompare fails
if (is_duplicate < 0) {
Py_DECREF(args);
Py_DECREF(new_args);
return NULL;
}
if (is_duplicate)
break;
}
if (!is_duplicate) {
Py_INCREF(i_element);
Expand Down Expand Up @@ -290,8 +300,8 @@ is_unionable(PyObject *obj)
type == &_Py_UnionType);
}

static PyObject *
type_or(PyTypeObject* self, PyObject* param)
PyObject *
_Py_union_type_or(PyObject* self, PyObject* param)
{
PyObject *tuple = PyTuple_Pack(2, self, param);
if (tuple == NULL) {
Expand Down Expand Up @@ -404,7 +414,7 @@ static PyMethodDef union_methods[] = {
{0}};

static PyNumberMethods union_as_number = {
.nb_or = (binaryfunc)type_or, // Add __or__ function
.nb_or = _Py_union_type_or, // Add __or__ function
};

PyTypeObject _Py_UnionType = {
Expand Down