Skip to content

Commit 283cc7e

Browse files
Add basic support for PEP 695 type parameters in functions (#17)
* Add support for basic PEP 695 type parameters in functions * Add tests * Reorder imports * Reformat * Reformat * Add stubs * Add exported names * Do not add typevar if name is a parameter If the name was given a a type parameter a TypeVar variable is not needed.
1 parent 5a5da58 commit 283cc7e

File tree

13 files changed

+219
-16
lines changed

13 files changed

+219
-16
lines changed

pybind11_stubgen/parser/interface.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import abc
44
import types
5-
from typing import Any
5+
from typing import Any, Callable, TypeVar
66

77
from pybind11_stubgen.parser.errors import ParserError
88
from pybind11_stubgen.structs import (
@@ -23,6 +23,8 @@
2323
Value,
2424
)
2525

26+
T = TypeVar("T")
27+
2628

2729
class IParser(abc.ABC):
2830
@abc.abstractmethod
@@ -95,6 +97,14 @@ def handle_type(self, type_: type) -> QualifiedName:
9597
def handle_value(self, value: Any) -> Value:
9698
...
9799

100+
@abc.abstractmethod
101+
def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
102+
"""
103+
PEP 695 added template syntax to classes and functions.
104+
This will call the function with these additional local types.
105+
"""
106+
...
107+
98108
@abc.abstractmethod
99109
def parse_args_str(self, args_str: str) -> list[Argument]:
100110
...

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import types
99
from logging import getLogger
10-
from typing import Any, Sequence
10+
from typing import Any, Callable, Sequence, TypeVar
1111

1212
from pybind11_stubgen.parser.errors import (
1313
InvalidExpressionError,
@@ -38,6 +38,8 @@
3838

3939
logger = getLogger("pybind11_stubgen")
4040

41+
T = TypeVar("T")
42+
4143

4244
class RemoveSelfAnnotation(IParser):
4345

@@ -88,6 +90,7 @@ def __init__(self):
8890
self.__extra_imports: set[Import] = set()
8991
self.__current_module: types.ModuleType | None = None
9092
self.__current_class: type | None = None
93+
self.__local_types: set[str] = set()
9194

9295
def handle_alias(self, path: QualifiedName, origin: Any) -> Alias | None:
9396
result = super().handle_alias(path, origin)
@@ -144,6 +147,13 @@ def handle_value(self, value: Any) -> Value:
144147
self._add_import(QualifiedName.from_str(result.repr))
145148
return result
146149

150+
def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
151+
original_local_types = self.__local_types.copy()
152+
self.__local_types.update(parameters)
153+
result = super().call_with_local_types(parameters, func)
154+
self.__local_types = original_local_types
155+
return result
156+
147157
def parse_annotation_str(
148158
self, annotation_str: str
149159
) -> ResolvedType | InvalidExpression | Value:
@@ -155,7 +165,7 @@ def parse_annotation_str(
155165
def _add_import(self, name: QualifiedName) -> None:
156166
if len(name) == 0:
157167
return
158-
if len(name) == 1 and len(name[0]) == 0:
168+
if len(name) == 1 and (len(name[0]) == 0 or name[0] in self.__local_types):
159169
return
160170
if hasattr(builtins, name[0]):
161171
return
@@ -636,6 +646,7 @@ class FixNumpyArrayDimTypeVar(IParser):
636646
numpy_primitive_types = FixNumpyArrayDimAnnotation.numpy_primitive_types
637647

638648
__DIM_VARS: set[str] = set()
649+
__local_types: set[str] = set()
639650

640651
def handle_module(
641652
self, path: QualifiedName, module: types.ModuleType
@@ -662,6 +673,13 @@ def handle_module(
662673

663674
return result
664675

676+
def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
677+
original_local_types = self.__local_types.copy()
678+
self.__local_types.update(parameters)
679+
result = super().call_with_local_types(parameters, func)
680+
self.__local_types = original_local_types
681+
return result
682+
665683
def parse_annotation_str(
666684
self, annotation_str: str
667685
) -> ResolvedType | InvalidExpression | Value:
@@ -675,6 +693,9 @@ def parse_annotation_str(
675693
if not isinstance(result, ResolvedType):
676694
return result
677695

696+
if len(result.name) == 1 and result.name[0] in self.__local_types:
697+
return result
698+
678699
# handle unqualified, single-letter annotation as a TypeVar
679700
if len(result.name) == 1 and len(result.name[0]) == 1:
680701
result.name = QualifiedName.from_str(result.name[0].upper())

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
import re
66
import types
7-
from typing import Any
7+
from typing import Any, Callable, TypeVar
88

99
from pybind11_stubgen.parser.errors import (
1010
InvalidExpressionError,
@@ -40,6 +40,8 @@
4040
Argument(name=Identifier("kwargs"), kw_variadic=True),
4141
]
4242

43+
T = TypeVar("T")
44+
4345

4446
class ParserDispatchMixin(IParser):
4547
def handle_class(self, path: QualifiedName, class_: type) -> Class | None:
@@ -384,6 +386,9 @@ def handle_type(self, type_: type) -> QualifiedName:
384386
)
385387
)
386388

389+
def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
390+
return func()
391+
387392
def parse_value_str(self, value: str) -> Value | InvalidExpression:
388393
return self._parse_expression_str(value)
389394

@@ -624,32 +629,48 @@ def parse_function_docstring(
624629
return []
625630

626631
top_signature_regex = re.compile(
627-
rf"^{func_name}\((?P<args>.*)\)\s*(->\s*(?P<returns>.+))?$"
632+
rf"^{func_name}(\[(?P<type_vars>[\w\s,]*)])?"
633+
rf"\((?P<args>.*)\)\s*(->\s*(?P<returns>.+))?$"
628634
)
629635

630636
match = top_signature_regex.match(doc_lines[0])
631637
if match is None:
632638
return []
633639

634640
if len(doc_lines) < 2 or doc_lines[1] != "Overloaded function.":
641+
# TODO: Update to support more complex formats.
642+
# This only supports bare type parameters.
643+
type_vars: list[str] = list(
644+
filter(
645+
bool, map(str.strip, (match.group("type_vars") or "").split(","))
646+
)
647+
)
648+
args = self.call_with_local_types(
649+
type_vars, lambda: self.parse_args_str(match.group("args"))
650+
)
651+
635652
returns_str = match.group("returns")
636653
if returns_str is not None:
637-
returns = self.parse_annotation_str(returns_str)
654+
returns = self.call_with_local_types(
655+
type_vars, lambda: self.parse_annotation_str(returns_str)
656+
)
638657
else:
639658
returns = None
640659

641660
return [
642661
Function(
643662
name=func_name,
644-
args=self.parse_args_str(match.group("args")),
663+
args=args,
645664
doc=self._strip_empty_lines(doc_lines[1:]),
646665
returns=returns,
666+
type_vars=type_vars,
647667
)
648668
]
649669

650670
overload_signature_regex = re.compile(
651-
rf"^(\s*(?P<overload_number>\d+).\s*)"
652-
rf"{func_name}\((?P<args>.*)\)\s*->\s*(?P<returns>.+)$"
671+
rf"^(\s*(?P<overload_number>\d+)\.\s*)"
672+
rf"{func_name}(\[(?P<type_vars>[\w\s,]*)])?"
673+
rf"\((?P<args>.*)\)\s*->\s*(?P<returns>.+)$"
653674
)
654675

655676
doc_start = 0
@@ -663,16 +684,31 @@ def parse_function_docstring(
663684
continue
664685
overloads[-1].doc = self._strip_empty_lines(doc_lines[doc_start:i])
665686
doc_start = i + 1
687+
# TODO: Update to support more complex formats.
688+
# This only supports bare type parameters.
689+
type_vars: list[str] = list(
690+
filter(
691+
bool,
692+
map(str.strip, (match.group("type_vars") or "").split(",")),
693+
)
694+
)
695+
args = self.call_with_local_types(
696+
type_vars, lambda: self.parse_args_str(match.group("args"))
697+
)
698+
returns = self.call_with_local_types(
699+
type_vars, lambda: self.parse_annotation_str(match.group("returns"))
700+
)
666701
overloads.append(
667702
Function(
668703
name=func_name,
669-
args=self.parse_args_str(match.group("args")),
670-
returns=self.parse_annotation_str(match.group("returns")),
704+
args=args,
705+
returns=returns,
671706
doc=None,
672707
decorators=[
673708
# use `parse_annotation_str()` to trigger typing import
674709
Decorator(str(self.parse_annotation_str("typing.overload")))
675710
],
711+
type_vars=type_vars,
676712
)
677713
)
678714

pybind11_stubgen/printer.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,18 @@ def print_function(self, func: Function) -> list[str]:
151151
args.append(self.print_argument(arg))
152152
if len(args) > 0 and args[0] == "/":
153153
args = args[1:]
154-
signature = [
155-
f"def {func.name}(",
156-
", ".join(args),
157-
")",
158-
]
154+
signature = [f"def {func.name}"]
155+
156+
if func.type_vars:
157+
signature.extend(["[", ", ".join(func.type_vars), "]"])
158+
159+
signature.extend(
160+
[
161+
"(",
162+
", ".join(args),
163+
")",
164+
]
165+
)
159166

160167
if func.returns is not None:
161168
signature.append(f" -> {self.print_annotation(func.returns)}")

pybind11_stubgen/structs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ class Function:
142142
returns: Annotation | None = field_(default=None)
143143
doc: Docstring | None = field_(default=None)
144144
decorators: list[Decorator] = field_(default_factory=list)
145+
type_vars: list[str] = field_(default_factory=list)
145146

146147
def __str__(self):
147148
return (

tests/py-demo/bindings/src/modules/functions.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,34 @@ void bind_functions_module(py::module &&m) {
9797
pyFoo.def(py::init<int>());
9898
m.def("default_custom_arg", [](Foo &foo) {}, py::arg_v("foo", Foo(5), "Foo(5)"));
9999
m.def("pass_callback", [](std::function<Foo(Foo &)> &callback) { return Foo(13); });
100+
101+
#if PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 12)
102+
py::options options;
103+
options.disable_function_signatures();
104+
m.def(
105+
"passthrough1",
106+
[](py::object obj) { return obj; },
107+
py::doc("passthrough1[T](obj: T) -> T\n"));
108+
m.def(
109+
"passthrough2",
110+
[](py::object obj) { return obj; },
111+
py::doc(
112+
"passthrough2(*args, **kwargs)\n"
113+
"Overloaded function.\n"
114+
"1. passthrough2() -> None\n"
115+
"2. passthrough2[T](obj: T) -> T\n"),
116+
py::arg("obj") = py::none());
117+
m.def(
118+
"passthrough3",
119+
[](py::object obj1, py::object obj2) { return py::make_tuple(obj1, obj2); },
120+
py::doc(
121+
"passthrough3(*args, **kwargs)\n"
122+
"Overloaded function.\n"
123+
"1. passthrough3() -> tuple[None, None]\n"
124+
"2. passthrough3[T](obj: T) -> tuple[T, None]\n"
125+
"3. passthrough3[T1, T2](obj1: T1, obj2: T2) -> tuple[T1, T2]\n"),
126+
py::arg("obj1") = py::none(),
127+
py::arg("obj2") = py::none());
128+
options.enable_function_signatures();
129+
#endif
100130
}

tests/stubs/python-3.12/pybind11-v2.11/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ __all__: list[str] = [
1919
"generic",
2020
"mul",
2121
"pass_callback",
22+
"passthrough1",
23+
"passthrough2",
24+
"passthrough3",
2225
"pos_kw_only_mix",
2326
"pos_kw_only_variadic_mix",
2427
]
@@ -52,5 +55,16 @@ def mul(p: float, q: float) -> float:
5255
"""
5356

5457
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
58+
def passthrough1[T](obj: T) -> T: ...
59+
@typing.overload
60+
def passthrough2() -> None: ...
61+
@typing.overload
62+
def passthrough2[T](obj: T) -> T: ...
63+
@typing.overload
64+
def passthrough3() -> tuple[None, None]: ...
65+
@typing.overload
66+
def passthrough3[T](obj: T) -> tuple[T, None]: ...
67+
@typing.overload
68+
def passthrough3[T1, T2](obj1: T1, obj2: T2) -> tuple[T1, T2]: ...
5569
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5670
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

tests/stubs/python-3.12/pybind11-v2.12/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ __all__: list[str] = [
2020
"generic",
2121
"mul",
2222
"pass_callback",
23+
"passthrough1",
24+
"passthrough2",
25+
"passthrough3",
2326
"pos_kw_only_mix",
2427
"pos_kw_only_variadic_mix",
2528
]
@@ -54,5 +57,16 @@ def mul(p: float, q: float) -> float:
5457
"""
5558

5659
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
60+
def passthrough1[T](obj: T) -> T: ...
61+
@typing.overload
62+
def passthrough2() -> None: ...
63+
@typing.overload
64+
def passthrough2[T](obj: T) -> T: ...
65+
@typing.overload
66+
def passthrough3() -> tuple[None, None]: ...
67+
@typing.overload
68+
def passthrough3[T](obj: T) -> tuple[T, None]: ...
69+
@typing.overload
70+
def passthrough3[T1, T2](obj1: T1, obj2: T2) -> tuple[T1, T2]: ...
5771
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5872
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

tests/stubs/python-3.12/pybind11-v2.13/numpy-array-use-type-var/demo/_bindings/functions.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ __all__: list[str] = [
2020
"generic",
2121
"mul",
2222
"pass_callback",
23+
"passthrough1",
24+
"passthrough2",
25+
"passthrough3",
2326
"pos_kw_only_mix",
2427
"pos_kw_only_variadic_mix",
2528
]
@@ -54,5 +57,16 @@ def mul(p: float, q: float) -> float:
5457
"""
5558

5659
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
60+
def passthrough1[T](obj: T) -> T: ...
61+
@typing.overload
62+
def passthrough2() -> None: ...
63+
@typing.overload
64+
def passthrough2[T](obj: T) -> T: ...
65+
@typing.overload
66+
def passthrough3() -> tuple[None, None]: ...
67+
@typing.overload
68+
def passthrough3[T](obj: T) -> tuple[T, None]: ...
69+
@typing.overload
70+
def passthrough3[T1, T2](obj1: T1, obj2: T2) -> tuple[T1, T2]: ...
5771
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5872
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

0 commit comments

Comments
 (0)