From 12df773dfb17100c6db171dbb9daa87f22e424e1 Mon Sep 17 00:00:00 2001 From: Felix Seele <3756270+citruz@users.noreply.github.com> Date: Sat, 5 Feb 2022 11:09:25 +0100 Subject: [PATCH 1/2] stubgen: fixed handling of Protocol and added testcase --- mypy/stubgen.py | 3 +++ test-data/unit/stubgen.test | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index b0a2c4e64587..4e1a0fe083f6 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -849,6 +849,9 @@ def visit_class_def(self, o: ClassDef) -> None: base_types.append('metaclass=abc.ABCMeta') self.import_tracker.add_import('abc') self.import_tracker.require_name('abc') + elif self.analyzed and o.info.is_protocol: + base_types.append('Protocol') + self.add_typing_import('Protocol') if base_types: self.add('(%s)' % ', '.join(base_types)) self.add(':\n') diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index e20073027db2..2ff04f79db9d 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -2580,3 +2580,18 @@ class A: def f(x: int, y: int) -> int: ... @t.overload def f(x: t.Tuple[int, int]) -> int: ... + + +[case testProtocol_semanal] +from typing import Protocol + +class P(Protocol): + def f(self, x: int, y: int) -> str: + ... + + +[out] +from typing import Protocol + +class P(Protocol): + def f(self, x: int, y: int) -> str: ... From 62ceaa3c747d4c7f2eff66463e8aca483eb011bd Mon Sep 17 00:00:00 2001 From: Felix Seele <3756270+citruz@users.noreply.github.com> Date: Mon, 7 Feb 2022 21:40:44 +0100 Subject: [PATCH 2/2] added support for generic protocols --- mypy/stubgen.py | 5 ++++- test-data/unit/stubgen.test | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 4e1a0fe083f6..6a604da57ba7 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -850,7 +850,10 @@ def visit_class_def(self, o: ClassDef) -> None: self.import_tracker.add_import('abc') self.import_tracker.require_name('abc') elif self.analyzed and o.info.is_protocol: - base_types.append('Protocol') + type_str = 'Protocol' + if o.info.type_vars: + type_str += f'[{", ".join(o.info.type_vars)}]' + base_types.append(type_str) self.add_typing_import('Protocol') if base_types: self.add('(%s)' % ', '.join(base_types)) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 2ff04f79db9d..fb67bb2893e4 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -2583,15 +2583,26 @@ def f(x: t.Tuple[int, int]) -> int: ... [case testProtocol_semanal] -from typing import Protocol +from typing import Protocol, TypeVar class P(Protocol): def f(self, x: int, y: int) -> str: ... +T = TypeVar('T') +T2 = TypeVar('T2') +class PT(Protocol[T, T2]): + def f(self, x: T) -> T2: + ... + [out] -from typing import Protocol +from typing import Protocol, TypeVar class P(Protocol): def f(self, x: int, y: int) -> str: ... +T = TypeVar('T') +T2 = TypeVar('T2') + +class PT(Protocol[T, T2]): + def f(self, x: T) -> T2: ...