Skip to content

Commit ff435a1

Browse files
erikwrightilevkivskyi
authored andcommitted
Fix instance/subclass checks of functions against runtime protocols. (python#580)
Fixes python#579 . When performing an `issubclass` check of a type against a protocol, the `__annotations__` member of the type is accessed and assumed to be iterable. `__annotations__` is a descriptor in the case of `types.FunctionType`, so while it is iterable when accessed on a function instance it is not iterable when accessed on the type of a function. This causes the `issubclass` check to fail with an exception. In some cases (AFAICT, non-data protocols), an `isinstance` check of an object will use, internally, a subclass check of the object's type. As a result, `isinstance` will also fail with an exception in these conditions. The above only seemed to occur in Python 3. This PR fixes the issue in the Python 3 implementation while adding test coverage for both Python 2 and 3 that ensures that functions (and `types.FunctionType`) can be correctly compared both against protocols that they legitimately implement as well as those that they do not implement.
1 parent 0eb1ce3 commit ff435a1

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

src_py2/test_typing_extensions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import collections
66
import pickle
77
import subprocess
8+
import types
89
from unittest import TestCase, main, skipUnless
910

1011
from typing_extensions import NoReturn, ClassVar, Final
@@ -274,20 +275,35 @@ class C(object): pass
274275
class D(object):
275276
def meth(self):
276277
pass
278+
def f():
279+
pass
277280
self.assertIsSubclass(D, P)
278281
self.assertIsInstance(D(), P)
279282
self.assertNotIsSubclass(C, P)
280283
self.assertNotIsInstance(C(), P)
284+
self.assertNotIsSubclass(types.FunctionType, P)
285+
self.assertNotIsInstance(f, P)
281286

282287
def test_everything_implements_empty_protocol(self):
283288
@runtime
284289
class Empty(Protocol): pass
285290
class C(object): pass
286-
for thing in (object, type, tuple, C):
291+
def f():
292+
pass
293+
for thing in (object, type, tuple, C, types.FunctionType):
287294
self.assertIsSubclass(thing, Empty)
288-
for thing in (object(), 1, (), typing):
295+
for thing in (object(), 1, (), typing, f):
289296
self.assertIsInstance(thing, Empty)
290297

298+
def test_function_implements_protocol(self):
299+
@runtime
300+
class Function(Protocol):
301+
def __call__(self, *args, **kwargs):
302+
pass
303+
def f():
304+
pass
305+
self.assertIsInstance(f, Function)
306+
291307
def test_no_inheritance_from_nominal(self):
292308
class C(object): pass
293309
class BP(Protocol): pass

src_py3/test_typing_extensions.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import collections
66
import pickle
77
import subprocess
8+
import types
89
from unittest import TestCase, main, skipUnless
910
from typing import TypeVar, Optional
1011
from typing import T, KT, VT # Not in __all__.
@@ -318,6 +319,11 @@ def __str__(self):
318319
def __add__(self, other):
319320
return 0
320321
322+
@runtime
323+
class HasCallProtocol(Protocol):
324+
__call__: typing.Callable
325+
326+
321327
async def g_with(am: AsyncContextManager[int]):
322328
x: int
323329
async with am as x:
@@ -335,7 +341,7 @@ async def g_with(am: AsyncContextManager[int]):
335341
# fake names for the sake of static analysis
336342
ann_module = ann_module2 = ann_module3 = None
337343
A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object
338-
XMeth = XRepr = NoneAndForward = Loop = object
344+
XMeth = XRepr = HasCallProtocol = NoneAndForward = Loop = object
339345

340346
gth = get_type_hints
341347

@@ -739,20 +745,32 @@ class C: pass
739745
class D:
740746
def meth(self):
741747
pass
748+
def f():
749+
pass
742750
self.assertIsSubclass(D, P)
743751
self.assertIsInstance(D(), P)
744752
self.assertNotIsSubclass(C, P)
745753
self.assertNotIsInstance(C(), P)
754+
self.assertNotIsSubclass(types.FunctionType, P)
755+
self.assertNotIsInstance(f, P)
746756

747757
def test_everything_implements_empty_protocol(self):
748758
@runtime
749759
class Empty(Protocol): pass
750760
class C: pass
751-
for thing in (object, type, tuple, C):
761+
def f():
762+
pass
763+
for thing in (object, type, tuple, C, types.FunctionType):
752764
self.assertIsSubclass(thing, Empty)
753-
for thing in (object(), 1, (), typing):
765+
for thing in (object(), 1, (), typing, f):
754766
self.assertIsInstance(thing, Empty)
755767

768+
@skipUnless(PY36, 'Python 3.6 required')
769+
def test_function_implements_protocol(self):
770+
def f():
771+
pass
772+
self.assertIsInstance(f, HasCallProtocol)
773+
756774
def test_no_inheritance_from_nominal(self):
757775
class C: pass
758776
class BP(Protocol): pass

src_py3/typing_extensions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,8 @@ def _proto_hook(other):
10431043
if base.__dict__[attr] is None:
10441044
return NotImplemented
10451045
break
1046-
if (attr in getattr(base, '__annotations__', {}) and
1046+
annotations = getattr(base, '__annotations__', {})
1047+
if (isinstance(annotations, typing.Mapping) and attr in annotations and
10471048
isinstance(other, _ProtocolMeta) and other._is_protocol):
10481049
break
10491050
else:
@@ -1328,7 +1329,8 @@ def _proto_hook(other):
13281329
if base.__dict__[attr] is None:
13291330
return NotImplemented
13301331
break
1331-
if (attr in getattr(base, '__annotations__', {}) and
1332+
annotations = getattr(base, '__annotations__', {})
1333+
if (isinstance(annotations, typing.Mapping) and attr in annotations and
13321334
isinstance(other, _ProtocolMeta) and other._is_protocol):
13331335
break
13341336
else:

0 commit comments

Comments
 (0)