From ae90f5682c986ffab1496eb690d0c650d4e582a5 Mon Sep 17 00:00:00 2001 From: Kalle Tuure Date: Wed, 18 May 2016 08:30:26 +0300 Subject: [PATCH] Stop generic subclasses from inheriting __extra__ Fixes several issues related to subclass checks against custom subclasses of generic collections. --- python2/test_typing.py | 25 ++++++++++++++++++++++--- python2/typing.py | 14 ++++++++------ src/test_typing.py | 25 ++++++++++++++++++++++--- src/typing.py | 20 +++++++++----------- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/python2/test_typing.py b/python2/test_typing.py index 0994662f..c11ba221 100644 --- a/python2/test_typing.py +++ b/python2/test_typing.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, unicode_literals +import collections import pickle import re import sys @@ -970,13 +971,17 @@ def test_no_list_instantiation(self): with self.assertRaises(TypeError): typing.List[int]() - def test_list_subclass_instantiation(self): + def test_list_subclass(self): class MyList(typing.List[int]): pass a = MyList() self.assertIsInstance(a, MyList) + self.assertIsInstance(a, typing.Sequence) + + self.assertIsSubclass(MyList, list) + self.assertNotIsSubclass(list, MyList) def test_no_dict_instantiation(self): with self.assertRaises(TypeError): @@ -986,13 +991,17 @@ def test_no_dict_instantiation(self): with self.assertRaises(TypeError): typing.Dict[str, int]() - def test_dict_subclass_instantiation(self): + def test_dict_subclass(self): class MyDict(typing.Dict[str, int]): pass d = MyDict() self.assertIsInstance(d, MyDict) + self.assertIsInstance(d, typing.MutableMapping) + + self.assertIsSubclass(MyDict, dict) + self.assertNotIsSubclass(dict, MyDict) def test_no_defaultdict_instantiation(self): with self.assertRaises(TypeError): @@ -1002,7 +1011,7 @@ def test_no_defaultdict_instantiation(self): with self.assertRaises(TypeError): typing.DefaultDict[str, int]() - def test_defaultdict_subclass_instantiation(self): + def test_defaultdict_subclass(self): class MyDefDict(typing.DefaultDict[str, int]): pass @@ -1010,6 +1019,9 @@ class MyDefDict(typing.DefaultDict[str, int]): dd = MyDefDict() self.assertIsInstance(dd, MyDefDict) + self.assertIsSubclass(MyDefDict, collections.defaultdict) + self.assertNotIsSubclass(collections.defaultdict, MyDefDict) + def test_no_set_instantiation(self): with self.assertRaises(TypeError): typing.Set() @@ -1090,6 +1102,13 @@ def __len__(self): self.assertEqual(len(MMB[str, str]()), 0) self.assertEqual(len(MMB[KT, VT]()), 0) + self.assertNotIsSubclass(dict, MMA) + self.assertNotIsSubclass(dict, MMB) + + self.assertIsSubclass(MMA, typing.Mapping) + self.assertIsSubclass(MMB, typing.Mapping) + self.assertIsSubclass(MMC, typing.Mapping) + class NamedTupleTests(BaseTestCase): diff --git a/python2/typing.py b/python2/typing.py index 57af77c5..44eb9f2e 100644 --- a/python2/typing.py +++ b/python2/typing.py @@ -911,8 +911,6 @@ def _next_in_mro(cls): class GenericMeta(TypingMeta, abc.ABCMeta): """Metaclass for generic types.""" - __extra__ = None - def __new__(cls, name, bases, namespace, tvars=None, args=None, origin=None, extra=None): self = super(GenericMeta, cls).__new__(cls, name, bases, namespace) @@ -960,10 +958,7 @@ def __new__(cls, name, bases, namespace, self.__parameters__ = tvars self.__args__ = args self.__origin__ = origin - if extra is not None: - self.__extra__ = extra - # Else __extra__ is inherited, eventually from the - # (meta-)class default above. + self.__extra__ = namespace.get('__extra__') # Speed hack (https://github.com/python/typing/issues/196). self.__next_in_mro__ = _next_in_mro(self) return self @@ -1289,6 +1284,7 @@ def _get_protocol_attrs(self): attr != '__next_in_mro__' and attr != '__parameters__' and attr != '__origin__' and + attr != '__extra__' and attr != '__module__'): attrs.add(attr) @@ -1414,10 +1410,12 @@ class ByteString(Sequence[int]): pass +ByteString.register(str) ByteString.register(bytearray) class List(list, MutableSequence[T]): + __extra__ = list def __new__(cls, *args, **kwds): if _geqv(cls, List): @@ -1427,6 +1425,7 @@ def __new__(cls, *args, **kwds): class Set(set, MutableSet[T]): + __extra__ = set def __new__(cls, *args, **kwds): if _geqv(cls, Set): @@ -1452,6 +1451,7 @@ def __subclasscheck__(self, cls): class FrozenSet(frozenset, AbstractSet[T_co]): __metaclass__ = _FrozenSetMeta __slots__ = () + __extra__ = frozenset def __new__(cls, *args, **kwds): if _geqv(cls, FrozenSet): @@ -1479,6 +1479,7 @@ class ValuesView(MappingView[VT_co]): class Dict(dict, MutableMapping[KT, VT]): + __extra__ = dict def __new__(cls, *args, **kwds): if _geqv(cls, Dict): @@ -1488,6 +1489,7 @@ def __new__(cls, *args, **kwds): class DefaultDict(collections.defaultdict, MutableMapping[KT, VT]): + __extra__ = collections.defaultdict def __new__(cls, *args, **kwds): if _geqv(cls, DefaultDict): diff --git a/src/test_typing.py b/src/test_typing.py index 90bad775..f9e54b27 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -1,4 +1,5 @@ import contextlib +import collections import pickle import re import sys @@ -1218,13 +1219,17 @@ def test_no_list_instantiation(self): with self.assertRaises(TypeError): typing.List[int]() - def test_list_subclass_instantiation(self): + def test_list_subclass(self): class MyList(typing.List[int]): pass a = MyList() self.assertIsInstance(a, MyList) + self.assertIsInstance(a, typing.Sequence) + + self.assertIsSubclass(MyList, list) + self.assertNotIsSubclass(list, MyList) def test_no_dict_instantiation(self): with self.assertRaises(TypeError): @@ -1234,13 +1239,17 @@ def test_no_dict_instantiation(self): with self.assertRaises(TypeError): typing.Dict[str, int]() - def test_dict_subclass_instantiation(self): + def test_dict_subclass(self): class MyDict(typing.Dict[str, int]): pass d = MyDict() self.assertIsInstance(d, MyDict) + self.assertIsInstance(d, typing.MutableMapping) + + self.assertIsSubclass(MyDict, dict) + self.assertNotIsSubclass(dict, MyDict) def test_no_defaultdict_instantiation(self): with self.assertRaises(TypeError): @@ -1250,7 +1259,7 @@ def test_no_defaultdict_instantiation(self): with self.assertRaises(TypeError): typing.DefaultDict[str, int]() - def test_defaultdict_subclass_instantiation(self): + def test_defaultdict_subclass(self): class MyDefDict(typing.DefaultDict[str, int]): pass @@ -1258,6 +1267,9 @@ class MyDefDict(typing.DefaultDict[str, int]): dd = MyDefDict() self.assertIsInstance(dd, MyDefDict) + self.assertIsSubclass(MyDefDict, collections.defaultdict) + self.assertNotIsSubclass(collections.defaultdict, MyDefDict) + def test_no_set_instantiation(self): with self.assertRaises(TypeError): typing.Set() @@ -1338,6 +1350,13 @@ def __len__(self): self.assertEqual(len(MMB[str, str]()), 0) self.assertEqual(len(MMB[KT, VT]()), 0) + self.assertNotIsSubclass(dict, MMA) + self.assertNotIsSubclass(dict, MMB) + + self.assertIsSubclass(MMA, typing.Mapping) + self.assertIsSubclass(MMB, typing.Mapping) + self.assertIsSubclass(MMC, typing.Mapping) + class OtherABCTests(BaseTestCase): diff --git a/src/typing.py b/src/typing.py index d2750111..841e7786 100644 --- a/src/typing.py +++ b/src/typing.py @@ -894,8 +894,6 @@ def _next_in_mro(cls): class GenericMeta(TypingMeta, abc.ABCMeta): """Metaclass for generic types.""" - __extra__ = None - def __new__(cls, name, bases, namespace, tvars=None, args=None, origin=None, extra=None): self = super().__new__(cls, name, bases, namespace, _root=True) @@ -943,10 +941,7 @@ def __new__(cls, name, bases, namespace, self.__parameters__ = tvars self.__args__ = args self.__origin__ = origin - if extra is not None: - self.__extra__ = extra - # Else __extra__ is inherited, eventually from the - # (meta-)class default above. + self.__extra__ = extra # Speed hack (https://github.com/python/typing/issues/196). self.__next_in_mro__ = _next_in_mro(self) return self @@ -1307,6 +1302,7 @@ def _get_protocol_attrs(self): attr != '__next_in_mro__' and attr != '__parameters__' and attr != '__origin__' and + attr != '__extra__' and attr != '__module__'): attrs.add(attr) @@ -1470,7 +1466,7 @@ class ByteString(Sequence[int], extra=collections_abc.ByteString): ByteString.register(type(memoryview(b''))) -class List(list, MutableSequence[T]): +class List(list, MutableSequence[T], extra=list): def __new__(cls, *args, **kwds): if _geqv(cls, List): @@ -1479,7 +1475,7 @@ def __new__(cls, *args, **kwds): return list.__new__(cls, *args, **kwds) -class Set(set, MutableSet[T]): +class Set(set, MutableSet[T], extra=set): def __new__(cls, *args, **kwds): if _geqv(cls, Set): @@ -1502,7 +1498,8 @@ def __subclasscheck__(self, cls): return super().__subclasscheck__(cls) -class FrozenSet(frozenset, AbstractSet[T_co], metaclass=_FrozenSetMeta): +class FrozenSet(frozenset, AbstractSet[T_co], metaclass=_FrozenSetMeta, + extra=frozenset): __slots__ = () def __new__(cls, *args, **kwds): @@ -1538,7 +1535,7 @@ class ContextManager(Generic[T_co], extra=contextlib.AbstractContextManager): __all__.append('ContextManager') -class Dict(dict, MutableMapping[KT, VT]): +class Dict(dict, MutableMapping[KT, VT], extra=dict): def __new__(cls, *args, **kwds): if _geqv(cls, Dict): @@ -1546,7 +1543,8 @@ def __new__(cls, *args, **kwds): "use dict() instead") return dict.__new__(cls, *args, **kwds) -class DefaultDict(collections.defaultdict, MutableMapping[KT, VT]): +class DefaultDict(collections.defaultdict, MutableMapping[KT, VT], + extra=collections.defaultdict): def __new__(cls, *args, **kwds): if _geqv(cls, DefaultDict):