diff --git a/python2/test_typing.py b/python2/test_typing.py index 8dd1acb0..dca9a4f5 100644 --- a/python2/test_typing.py +++ b/python2/test_typing.py @@ -598,6 +598,62 @@ class MM1(MutableMapping[str, str], collections_abc.MutableMapping): class MM2(collections_abc.MutableMapping, MutableMapping[str, str]): pass + def test_orig_bases(self): + T = TypeVar('T') + class C(typing.Dict[str, T]): pass + self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],)) + + def test_naive_runtime_checks(self): + def naive_dict_check(obj, tp): + # Check if a dictionary conforms to Dict type + if len(tp.__parameters__) > 0: + raise NotImplementedError + if tp.__args__: + KT, VT = tp.__args__ + return all(isinstance(k, KT) and isinstance(v, VT) + for k, v in obj.items()) + self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[typing.Text, int])) + self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[typing.Text, int])) + with self.assertRaises(NotImplementedError): + naive_dict_check({1: 'x'}, typing.Dict[typing.Text, T]) + + def naive_generic_check(obj, tp): + # Check if an instance conforms to the generic class + if not hasattr(obj, '__orig_class__'): + raise NotImplementedError + return obj.__orig_class__ == tp + class Node(Generic[T]): pass + self.assertTrue(naive_generic_check(Node[int](), Node[int])) + self.assertFalse(naive_generic_check(Node[str](), Node[int])) + self.assertFalse(naive_generic_check(Node[str](), List)) + with self.assertRaises(NotImplementedError): + naive_generic_check([1,2,3], Node[int]) + + def naive_list_base_check(obj, tp): + # Check if list conforms to a List subclass + return all(isinstance(x, tp.__orig_bases__[0].__args__[0]) + for x in obj) + class C(List[int]): pass + self.assertTrue(naive_list_base_check([1, 2, 3], C)) + self.assertFalse(naive_list_base_check(['a', 'b'], C)) + + def test_multi_subscr_base(self): + T = TypeVar('T') + U = TypeVar('U') + V = TypeVar('V') + class C(List[T][U][V]): pass + class D(C, List[T][U][V]): pass + self.assertEqual(C.__parameters__, (V,)) + self.assertEqual(D.__parameters__, (V,)) + self.assertEqual(C[int].__parameters__, ()) + self.assertEqual(D[int].__parameters__, ()) + self.assertEqual(C[int].__args__, (int,)) + self.assertEqual(D[int].__args__, (int,)) + self.assertEqual(C.__bases__, (List,)) + self.assertEqual(D.__bases__, (C, List)) + self.assertEqual(C.__orig_bases__, (List[T][U][V],)) + self.assertEqual(D.__orig_bases__, (C, List[T][U][V])) + def test_pickle(self): global C # pickle wants to reference the class by name T = TypeVar('T') diff --git a/python2/typing.py b/python2/typing.py index cb5b3edd..0bec764a 100644 --- a/python2/typing.py +++ b/python2/typing.py @@ -1045,13 +1045,7 @@ class GenericMeta(TypingMeta, abc.ABCMeta): """Metaclass for generic types.""" def __new__(cls, name, bases, namespace, - tvars=None, args=None, origin=None, extra=None): - if extra is None: - extra = namespace.get('__extra__') - if extra is not None and type(extra) is abc.ABCMeta and extra not in bases: - bases = (extra,) + bases - self = super(GenericMeta, cls).__new__(cls, name, bases, namespace) - + tvars=None, args=None, origin=None, extra=None, orig_bases=None): if tvars is not None: # Called from __getitem__() below. assert origin is not None @@ -1092,12 +1086,27 @@ def __new__(cls, name, bases, namespace, ", ".join(str(g) for g in gvars))) tvars = gvars + initial_bases = bases + if extra is None: + extra = namespace.get('__extra__') + if extra is not None and type(extra) is abc.ABCMeta and extra not in bases: + bases = (extra,) + bases + bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases) + + # remove bare Generic from bases if there are other generic bases + if any(isinstance(b, GenericMeta) and b is not Generic for b in bases): + bases = tuple(b for b in bases if b is not Generic) + self = super(GenericMeta, cls).__new__(cls, name, bases, namespace) + self.__parameters__ = tvars self.__args__ = args self.__origin__ = origin self.__extra__ = extra # Speed hack (https://github.com/python/typing/issues/196). self.__next_in_mro__ = _next_in_mro(self) + # Preserve base classes on subclassing (__bases__ are type erased now). + if orig_bases is None: + self.__orig_bases__ = initial_bases # This allows unparameterized generic collections to be used # with issubclass() and isinstance() in the same way as their @@ -1180,12 +1189,13 @@ def __getitem__(self, params): tvars = _type_vars(params) args = params return self.__class__(self.__name__, - (self,) + self.__bases__, + self.__bases__, dict(self.__dict__), tvars=tvars, args=args, origin=self, - extra=self.__extra__) + extra=self.__extra__, + orig_bases=self.__orig_bases__) def __instancecheck__(self, instance): # Since we extend ABC.__subclasscheck__ and @@ -1232,6 +1242,10 @@ def __new__(cls, *args, **kwds): else: origin = _gorg(cls) obj = cls.__next_in_mro__.__new__(origin) + try: + obj.__orig_class__ = cls + except AttributeError: + pass obj.__init__(*args, **kwds) return obj @@ -1402,6 +1416,7 @@ def _get_protocol_attrs(self): attr != '__next_in_mro__' and attr != '__parameters__' and attr != '__origin__' and + attr != '__orig_bases__' and attr != '__extra__' and attr != '__module__'): attrs.add(attr) diff --git a/src/test_typing.py b/src/test_typing.py index dff737ae..bf7053b9 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -625,6 +625,63 @@ class MM1(MutableMapping[str, str], collections_abc.MutableMapping): class MM2(collections_abc.MutableMapping, MutableMapping[str, str]): pass + def test_orig_bases(self): + T = TypeVar('T') + class C(typing.Dict[str, T]): ... + self.assertEqual(C.__orig_bases__, (typing.Dict[str, T],)) + + def test_naive_runtime_checks(self): + def naive_dict_check(obj, tp): + # Check if a dictionary conforms to Dict type + if len(tp.__parameters__) > 0: + raise NotImplementedError + if tp.__args__: + KT, VT = tp.__args__ + return all(isinstance(k, KT) and isinstance(v, VT) + for k, v in obj.items()) + self.assertTrue(naive_dict_check({'x': 1}, typing.Dict[str, int])) + self.assertFalse(naive_dict_check({1: 'x'}, typing.Dict[str, int])) + with self.assertRaises(NotImplementedError): + naive_dict_check({1: 'x'}, typing.Dict[str, T]) + + def naive_generic_check(obj, tp): + # Check if an instance conforms to the generic class + if not hasattr(obj, '__orig_class__'): + raise NotImplementedError + return obj.__orig_class__ == tp + class Node(Generic[T]): ... + self.assertTrue(naive_generic_check(Node[int](), Node[int])) + self.assertFalse(naive_generic_check(Node[str](), Node[int])) + self.assertFalse(naive_generic_check(Node[str](), List)) + with self.assertRaises(NotImplementedError): + naive_generic_check([1,2,3], Node[int]) + + def naive_list_base_check(obj, tp): + # Check if list conforms to a List subclass + return all(isinstance(x, tp.__orig_bases__[0].__args__[0]) + for x in obj) + class C(List[int]): ... + self.assertTrue(naive_list_base_check([1, 2, 3], C)) + self.assertFalse(naive_list_base_check(['a', 'b'], C)) + + def test_multi_subscr_base(self): + T = TypeVar('T') + U = TypeVar('U') + V = TypeVar('V') + class C(List[T][U][V]): ... + class D(C, List[T][U][V]): ... + self.assertEqual(C.__parameters__, (V,)) + self.assertEqual(D.__parameters__, (V,)) + self.assertEqual(C[int].__parameters__, ()) + self.assertEqual(D[int].__parameters__, ()) + self.assertEqual(C[int].__args__, (int,)) + self.assertEqual(D[int].__args__, (int,)) + self.assertEqual(C.__bases__, (List,)) + self.assertEqual(D.__bases__, (C, List)) + self.assertEqual(C.__orig_bases__, (List[T][U][V],)) + self.assertEqual(D.__orig_bases__, (C, List[T][U][V])) + + def test_pickle(self): global C # pickle wants to reference the class by name T = TypeVar('T') diff --git a/src/typing.py b/src/typing.py index 35d562e0..930ba0c6 100644 --- a/src/typing.py +++ b/src/typing.py @@ -938,11 +938,7 @@ class GenericMeta(TypingMeta, abc.ABCMeta): """Metaclass for generic types.""" def __new__(cls, name, bases, namespace, - tvars=None, args=None, origin=None, extra=None): - if extra is not None and type(extra) is abc.ABCMeta and extra not in bases: - bases = (extra,) + bases - self = super().__new__(cls, name, bases, namespace, _root=True) - + tvars=None, args=None, origin=None, extra=None, orig_bases=None): if tvars is not None: # Called from __getitem__() below. assert origin is not None @@ -983,12 +979,25 @@ def __new__(cls, name, bases, namespace, ", ".join(str(g) for g in gvars))) tvars = gvars + initial_bases = bases + if extra is not None and type(extra) is abc.ABCMeta and extra not in bases: + bases = (extra,) + bases + bases = tuple(_gorg(b) if isinstance(b, GenericMeta) else b for b in bases) + + # remove bare Generic from bases if there are other generic bases + if any(isinstance(b, GenericMeta) and b is not Generic for b in bases): + bases = tuple(b for b in bases if b is not Generic) + self = super().__new__(cls, name, bases, namespace, _root=True) + self.__parameters__ = tvars self.__args__ = args self.__origin__ = origin self.__extra__ = extra # Speed hack (https://github.com/python/typing/issues/196). self.__next_in_mro__ = _next_in_mro(self) + # Preserve base classes on subclassing (__bases__ are type erased now). + if orig_bases is None: + self.__orig_bases__ = initial_bases # This allows unparameterized generic collections to be used # with issubclass() and isinstance() in the same way as their @@ -1071,12 +1080,13 @@ def __getitem__(self, params): tvars = _type_vars(params) args = params return self.__class__(self.__name__, - (self,) + self.__bases__, + self.__bases__, dict(self.__dict__), tvars=tvars, args=args, origin=self, - extra=self.__extra__) + extra=self.__extra__, + orig_bases=self.__orig_bases__) def __instancecheck__(self, instance): # Since we extend ABC.__subclasscheck__ and @@ -1120,6 +1130,10 @@ def __new__(cls, *args, **kwds): else: origin = _gorg(cls) obj = cls.__next_in_mro__.__new__(origin) + try: + obj.__orig_class__ = cls + except AttributeError: + pass obj.__init__(*args, **kwds) return obj @@ -1485,6 +1499,7 @@ def _get_protocol_attrs(self): attr != '__next_in_mro__' and attr != '__parameters__' and attr != '__origin__' and + attr != '__orig_bases__' and attr != '__extra__' and attr != '__module__'): attrs.add(attr)