Skip to content

Commit 501a00e

Browse files
authored
Backport the ability to define __init__ methods on Protocol classes (#142)
1 parent 90c866b commit 501a00e

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
`typing_extensions` may no longer be considered instances of that protocol
3030
using the new release, and vice versa. Most users are unlikely to be affected
3131
by this change. Patch by Alex Waygood.
32+
- Backport the ability to define `__init__` methods on Protocol classes, a
33+
change made in Python 3.11 (originally implemented in
34+
https://github.com/python/cpython/pull/31628 by Adrian Garcia Badaracco).
35+
Patch by Alex Waygood.
3236
- Speedup `isinstance(3, typing_extensions.SupportsIndex)` by >10x on Python
3337
<3.12. Patch by Alex Waygood.
3438

src/test_typing_extensions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,32 @@ class PG(Protocol[T]): pass
14541454
class CG(PG[T]): pass
14551455
self.assertIsInstance(CG[int](), CG)
14561456

1457+
def test_protocol_defining_init_does_not_get_overridden(self):
1458+
# check that P.__init__ doesn't get clobbered
1459+
# see https://bugs.python.org/issue44807
1460+
1461+
class P(Protocol):
1462+
x: int
1463+
def __init__(self, x: int) -> None:
1464+
self.x = x
1465+
class C: pass
1466+
1467+
c = C()
1468+
P.__init__(c, 1)
1469+
self.assertEqual(c.x, 1)
1470+
1471+
def test_concrete_class_inheriting_init_from_protocol(self):
1472+
class P(Protocol):
1473+
x: int
1474+
def __init__(self, x: int) -> None:
1475+
self.x = x
1476+
1477+
class C(P): pass
1478+
1479+
c = C(1)
1480+
self.assertIsInstance(c, C)
1481+
self.assertEqual(c.x, 1)
1482+
14571483
def test_cannot_instantiate_abstract(self):
14581484
@runtime_checkable
14591485
class P(Protocol):

src/typing_extensions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,8 @@ def _proto_hook(other):
662662
isinstance(base, _ProtocolMeta) and base._is_protocol):
663663
raise TypeError('Protocols can only inherit from other'
664664
f' protocols, got {repr(base)}')
665-
cls.__init__ = _no_init
665+
if cls.__init__ is Protocol.__init__:
666+
cls.__init__ = _no_init
666667

667668
def runtime_checkable(cls):
668669
"""Mark a protocol class as a runtime protocol, so that it

0 commit comments

Comments
 (0)