Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Unreleased

- Backport `Protocol.__init__` behaviour from python 3.11. (see
python/cpython#31628, by Adrian Garcia Badaracco). Patch by James
Hilton-Balfe (@Gobot1234).

# Release 4.5.0 (February 14, 2023)

- Runtime support for PEP 702, adding `typing_extensions.deprecated`. Patch
Expand Down
28 changes: 27 additions & 1 deletion src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,32 @@ class PG(Protocol[T]): pass
class CG(PG[T]): pass
self.assertIsInstance(CG[int](), CG)

def test_protocol_defining_init_does_not_get_overridden(self):
# check that P.__init__ doesn't get clobbered
# see https://bugs.python.org/issue44807

class P(Protocol):
x: int
def __init__(self, x: int) -> None:
self.x = x
class C: pass

c = C()
P.__init__(c, 1)
self.assertEqual(c.x, 1)

def test_concrete_class_inheriting_init_from_protocol(self):
class P(Protocol):
x: int
def __init__(self, x: int) -> None:
self.x = x

class C(P): pass

c = C(1)
self.assertIsInstance(c, C)
self.assertEqual(c.x, 1)

def test_cannot_instantiate_abstract(self):
@runtime
class P(Protocol):
Expand Down Expand Up @@ -3302,7 +3328,7 @@ def test_typing_extensions_defers_when_possible(self):
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
if sys.version_info < (3, 11):
exclude |= {'final', 'NamedTuple', 'Any'}
exclude |= {'final', 'NamedTuple', 'Any', 'Protocol'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand Down
215 changes: 130 additions & 85 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def _check_generic(cls, parameters, elen=_marker):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if cls is Protocol and not elen:
return
if not elen:
raise TypeError(f"{cls} is not a generic class")
if elen is _marker:
Expand Down Expand Up @@ -143,6 +145,11 @@ def _collect_type_vars(types, typevar_types=None):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)

def _caller(depth=1, default='__main__'):
try:
return sys._getframe(depth + 1).f_globals.get('__name__', default)
except (AttributeError, ValueError): # For platforms without _getframe()
return None

NoReturn = typing.NoReturn

Expand Down Expand Up @@ -457,162 +464,206 @@ def _maybe_adjust_parameters(cls):
cls.__parameters__ = tuple(tvars)


# 3.8+
if hasattr(typing, 'Protocol'):
# 3.11+
if sys.version_info >= (3, 11): # 3.8 has Protocol but it doesn't preserve __init__
Protocol = typing.Protocol
# 3.7

else:
_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
'_is_protocol', '_is_runtime_protocol']

_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
'__init__', '__module__', '__new__', '__slots__',
'__subclasshook__', '__weakref__', '__class_getitem__']

# These special attributes will be not collected as protocol members.
_EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']


def _get_protocol_attrs(cls):
"""Collect protocol members from a protocol class objects.
This includes names actually defined in the class dictionary, as well
as names that appear in annotations. Special names (above) are skipped.
"""
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
if not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRIBUTES:
attrs.add(attr)
return attrs


def _is_callable_members_only(cls):
# PEP 544 prohibits using issubclass() with protocols that have non-method members.
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))

def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:

def _no_init_or_replace_init(self, *args, **kwargs):
cls = type(self)

if cls._is_protocol:
raise TypeError('Protocols cannot be instantiated')

class _ProtocolMeta(abc.ABCMeta): # noqa: B024
# This metaclass is a bit unfortunate and exists only because of the lack
# of __instancehook__.
# Already using a custom `__init__`. No need to calculate correct
# `__init__` to call. This can lead to RecursionError. See bpo-45121.
if cls.__init__ is not _no_init_or_replace_init:
return

# Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
# The first instantiation of the subclass will call `_no_init_or_replace_init` which
# searches for a proper new `__init__` in the MRO. The new `__init__`
# replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
# instantiation of the protocol subclass will thus use the new
# `__init__` and no longer call `_no_init_or_replace_init`.
for base in cls.__mro__:
init = base.__dict__.get('__init__', _no_init_or_replace_init)
if init is not _no_init_or_replace_init:
cls.__init__ = init
break
else:
# should not happen
cls.__init__ = object.__init__

cls.__init__(self, *args, **kwargs)


def _allow_reckless_class_checks(depth=3):
"""Allow instance and class checks for special stdlib modules.
The abc and functools modules indiscriminately call isinstance() and
issubclass() on the whole MRO of a user class, which may contain protocols.
"""
return _caller(depth) in {'abc', 'functools', None}


_PROTO_ALLOWLIST = {
'collections.abc': [
'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
],
'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'],
}


class _ProtocolMeta(abc.ABCMeta):
# This metaclass is really unfortunate and exists only because of
# the lack of __instancehook__.
def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if (
getattr(cls, '_is_protocol', False) and
not getattr(cls, '_is_runtime_protocol', False) and
not _allow_reckless_class_checks(depth=2)
):
raise TypeError("Instance and class checks can only be used with"
" @runtime_checkable protocols")

if ((not getattr(cls, '_is_protocol', False) or
_is_callable_members_only(cls)) and
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):
return True
if cls._is_protocol:
if all(hasattr(instance, attr) and
(not callable(getattr(cls, attr, None)) or
# All *methods* can be blocked by setting them to None.
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
for attr in _get_protocol_attrs(cls)):
for attr in _get_protocol_attrs(cls)):
return True
return super().__instancecheck__(instance)

class Protocol(metaclass=_ProtocolMeta):
# There is quite a lot of overlapping code with typing.Generic.
# Unfortunately it is hard to avoid this while these live in two different
# modules. The duplicated code will be removed when Protocol is moved to typing.
"""Base class for protocol classes. Protocol classes are defined as::

class Protocol(typing.Generic, metaclass=_ProtocolMeta):
"""Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...

Such classes are primarily used with static type checkers that recognize
structural subtyping (static duck-typing), for example::

class C:
def meth(self) -> int:
return 0

def func(x: Proto) -> int:
return x.meth()

func(C()) # Passes static type check

See PEP 544 for details. Protocol classes decorated with
@typing_extensions.runtime act as simple-minded runtime protocol that checks
@typing.runtime_checkable act as simple-minded runtime protocols that check
only the presence of given attributes, ignoring their type signatures.

Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
def meth(self) -> T:
...
"""
__slots__ = ()
_is_protocol = True

def __new__(cls, *args, **kwds):
if cls is Protocol:
raise TypeError("Type Protocol cannot be instantiated; "
"it can only be used as a base class")
return super().__new__(cls)

@typing._tp_cache
def __class_getitem__(cls, params):
if not isinstance(params, tuple):
params = (params,)
if not params and cls is not typing.Tuple:
raise TypeError(
f"Parameter list to {cls.__qualname__}[...] cannot be empty")
msg = "Parameters to generic types must be types."
params = tuple(typing._type_check(p, msg) for p in params) # noqa
if cls is Protocol:
# Generic can only be subscripted with unique type variables.
if not all(isinstance(p, typing.TypeVar) for p in params):
i = 0
while isinstance(params[i], typing.TypeVar):
i += 1
raise TypeError(
"Parameters to Protocol[...] must all be type variables."
f" Parameter {i + 1} is {params[i]}")
if len(set(params)) != len(params):
raise TypeError(
"Parameters to Protocol[...] must all be unique")
else:
# Subscripting a regular Generic subclass.
_check_generic(cls, params, len(cls.__parameters__))
return typing._GenericAlias(cls, params)
_is_runtime_protocol = False

def __init_subclass__(cls, *args, **kwargs):
if '__orig_bases__' in cls.__dict__:
error = typing.Generic in cls.__orig_bases__
else:
error = typing.Generic in cls.__bases__
if error:
raise TypeError("Cannot inherit from plain Generic")
_maybe_adjust_parameters(cls)
super().__init_subclass__(*args, **kwargs)

# Determine if this is a protocol or a concrete subclass.
if not cls.__dict__.get('_is_protocol', None):
if not cls.__dict__.get('_is_protocol', False):
cls._is_protocol = any(b is Protocol for b in cls.__bases__)

# Set (or override) the protocol subclass hook.
def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', None):
if not cls.__dict__.get('_is_protocol', False):
return NotImplemented

# First, perform various sanity checks.
if not getattr(cls, '_is_runtime_protocol', False):
if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
" @runtime_checkable protocols")
if not _is_callable_members_only(cls):
if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
if not isinstance(other, type):
# Same error as for issubclass(1, int)
# Same error message as for issubclass(1, int).
raise TypeError('issubclass() arg 1 must be a class')

# Second, perform the actual structural compatibility check.
for attr in _get_protocol_attrs(cls):
for base in other.__mro__:
# Check if the members appears in the class dictionary...
if attr in base.__dict__:
if base.__dict__[attr] is None:
return NotImplemented
break

# ...or in annotations, if it is a sub-protocol.
annotations = getattr(base, '__annotations__', {})
if (isinstance(annotations, typing.Mapping) and
if (isinstance(annotations, collections.abc.Mapping) and
attr in annotations and
isinstance(other, _ProtocolMeta) and
other._is_protocol):
issubclass(other, typing.Generic) and other._is_protocol):
break
else:
return NotImplemented
return True

if '__subclasshook__' not in cls.__dict__:
cls.__subclasshook__ = _proto_hook

# We have nothing more to do for non-protocols.
# We have nothing more to do for non-protocols...
if not cls._is_protocol:
return

# Check consistency of bases.
# ... otherwise check consistency of bases, and prohibit instantiation.
for base in cls.__bases__:
if not (base in (object, typing.Generic) or
base.__module__ == 'collections.abc' and
base.__name__ in _PROTO_WHITELIST or
isinstance(base, _ProtocolMeta) and base._is_protocol):
base.__module__ in _PROTO_ALLOWLIST and
base.__name__ in _PROTO_ALLOWLIST[base.__module__] or
issubclass(base, typing.Generic) and base._is_protocol):
raise TypeError('Protocols can only inherit from other'
f' protocols, got {repr(base)}')
cls.__init__ = _no_init
' protocols, got %r' % base)
if cls.__init__ is Protocol.__init__:
cls.__init__ = _no_init_or_replace_init


# 3.8+
Expand Down Expand Up @@ -2229,12 +2280,6 @@ def wrapper(*args, **kwargs):
if sys.version_info >= (3, 11):
NamedTuple = typing.NamedTuple
else:
def _caller():
try:
return sys._getframe(2).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError): # For platforms without _getframe()
return None

def _make_nmtuple(name, types, module, defaults=()):
fields = [n for n, t in types]
annotations = {n: typing._type_check(t, f"field {n} annotation must be a type")
Expand Down