Skip to content

Commit ecd0d43

Browse files
committed
custom overload decorator to support overloads in the api generate scripts
1 parent a33bec0 commit ecd0d43

File tree

8 files changed

+115
-116
lines changed

8 files changed

+115
-116
lines changed

playwright/_impl/_element_handle.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
parse_result,
3434
serialize_argument,
3535
)
36+
from playwright._impl._overload import api_overload
3637

3738
if sys.version_info >= (3, 8): # pragma: no cover
3839
from typing import Literal
@@ -309,7 +310,8 @@ async def wait_for_element_state(
309310
) -> None:
310311
await self._channel.send("waitForElementState", locals_to_params(locals()))
311312

312-
async def _overload1_wait_for_selector(
313+
@api_overload
314+
async def wait_for_selector(
313315
self,
314316
selector: str,
315317
*,
@@ -318,7 +320,8 @@ async def _overload1_wait_for_selector(
318320
) -> "ElementHandle":
319321
...
320322

321-
async def _overload2_wait_for_selector(
323+
@api_overload # type: ignore[no-redef]
324+
async def wait_for_selector(
322325
self,
323326
selector: str,
324327
*,
@@ -327,7 +330,7 @@ async def _overload2_wait_for_selector(
327330
) -> None:
328331
...
329332

330-
async def wait_for_selector(
333+
async def wait_for_selector( # type: ignore[no-redef]
331334
self,
332335
selector: str,
333336
*,

playwright/_impl/_frame.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from playwright._impl._locator import Locator
5050
from playwright._impl._network import Response
51+
from playwright._impl._overload import api_overload
5152
from playwright._impl._wait_helper import WaitHelper
5253

5354
if sys.version_info >= (3, 8): # pragma: no cover
@@ -270,7 +271,8 @@ async def query_selector_all(self, selector: str) -> List[ElementHandle]:
270271
)
271272
)
272273

273-
async def _overload1_wait_for_selector(
274+
@api_overload
275+
async def wait_for_selector(
274276
self,
275277
selector: str,
276278
*,
@@ -280,7 +282,8 @@ async def _overload1_wait_for_selector(
280282
) -> ElementHandle:
281283
...
282284

283-
async def _overload2_wait_for_selector(
285+
@api_overload # type: ignore[no-redef]
286+
async def wait_for_selector(
284287
self,
285288
selector: str,
286289
*,
@@ -290,7 +293,7 @@ async def _overload2_wait_for_selector(
290293
) -> None:
291294
...
292295

293-
async def wait_for_selector(
296+
async def wait_for_selector( # type: ignore[no-redef]
294297
self,
295298
selector: str,
296299
*,

playwright/_impl/_overload.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from collections import defaultdict
2+
import inspect
3+
from typing import Callable
4+
5+
6+
def api_overload(fn: Callable) -> Callable:
7+
"""
8+
Creates an overload that's visible to the api generate scripts. use this decorator instead of `typing.overload` when exposing overloads from `_impl`.
9+
You will need to suppress mypy errors using a `# type: ignore[no-redef]` comment
10+
"""
11+
dictionary = inspect.getmodule(fn).__dict__
12+
overloads_key = "__overloads__"
13+
if dictionary.get(overloads_key) is None:
14+
dictionary[overloads_key] = defaultdict(list)
15+
dictionary[overloads_key][fn.__name__].append(fn)
16+
return fn

playwright/_impl/_page.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
serialize_argument,
7272
)
7373
from playwright._impl._network import Request, Response, Route, serialize_headers
74+
from playwright._impl._overload import api_overload
7475
from playwright._impl._video import Video
7576
from playwright._impl._wait_helper import WaitHelper
7677

@@ -325,7 +326,8 @@ async def query_selector(
325326
async def query_selector_all(self, selector: str) -> List[ElementHandle]:
326327
return await self._main_frame.query_selector_all(selector)
327328

328-
async def _overload1_wait_for_selector(
329+
@api_overload
330+
async def wait_for_selector(
329331
self,
330332
selector: str,
331333
*,
@@ -335,7 +337,8 @@ async def _overload1_wait_for_selector(
335337
) -> ElementHandle:
336338
...
337339

338-
async def _overload2_wait_for_selector(
340+
@api_overload # type: ignore[no-redef]
341+
async def wait_for_selector(
339342
self,
340343
selector: str,
341344
*,
@@ -345,7 +348,7 @@ async def _overload2_wait_for_selector(
345348
) -> None:
346349
...
347350

348-
async def wait_for_selector(
351+
async def wait_for_selector( # type: ignore[no-redef]
349352
self,
350353
selector: str,
351354
*,

scripts/generate_api.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Callable,
2020
List,
2121
Match,
22-
Optional,
2322
Union,
2423
cast,
2524
get_args,
@@ -93,34 +92,6 @@ def process_type(value: Any, param: bool = False) -> str:
9392
]
9493

9594

96-
class Overload:
97-
def __init__(self, clazz: object, name: str):
98-
matches = re.match(r"_overload(\d+)_([a-zA-Z\d_]+)", name)
99-
if matches is None:
100-
raise ValueError(f"failed to get function name for overload: {name}")
101-
self.number = int(matches[1])
102-
self.name = matches[2]
103-
self.clazz = clazz
104-
105-
def assert_has_implementation(self) -> None:
106-
if self.implementation() is None:
107-
raise Exception(f"implementation for overload '{self.name}' not found")
108-
109-
def implementation(self) -> Optional[Callable]:
110-
for name, method in self.clazz.__dict__.items():
111-
if name == self.name:
112-
return method
113-
return None
114-
115-
116-
def is_overload(name: str) -> bool:
117-
try:
118-
Overload(None, name)
119-
return True
120-
except TypeError:
121-
return False
122-
123-
12495
def is_positional_exception(key: str) -> bool:
12596
for pattern in positional_exceptions:
12697
if re.match(pattern, key):
@@ -234,11 +205,15 @@ def arguments(func: FunctionType, indent: int) -> str:
234205
return split.join(tokens)
235206

236207

237-
def return_type(func: FunctionType) -> str:
208+
def return_type(func: Callable) -> str:
238209
value = get_type_hints(func, globals())["return"]
239210
return process_type(value)
240211

241212

213+
def return_type_value(func: Callable) -> str:
214+
return re.sub(r"\"([^\"]+)Impl\"", r"\1", return_type(func))
215+
216+
242217
def short_name(t: Any) -> str:
243218
match = cast(
244219
Match[str], re.compile(r"playwright\._impl\.[^.]+\.([^']+)").search(str(t))

scripts/generate_async_api.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,19 @@
1414
# limitations under the License.
1515

1616
import inspect
17-
import re
1817
from types import FunctionType
19-
from typing import Any, Optional, get_type_hints
18+
from typing import Any, Callable, Dict, List, get_type_hints
2019

2120
from playwright._impl._helper import to_snake_case
2221
from scripts.documentation_provider import DocumentationProvider
2322
from scripts.generate_api import (
24-
Overload,
2523
all_types,
2624
api_globals,
2725
arguments,
2826
header,
29-
is_overload,
3027
process_type,
3128
return_type,
29+
return_type_value,
3230
return_value,
3331
short_name,
3432
signature,
@@ -37,6 +35,12 @@
3735
documentation_provider = DocumentationProvider(True)
3836

3937

38+
def async_return_type_value(func: Callable) -> str:
39+
return return_type_value(func).replace(
40+
"EventContextManager", "AsyncEventContextManager"
41+
)
42+
43+
4044
def generate(t: Any) -> None:
4145
print("")
4246
class_name = short_name(t)
@@ -59,6 +63,9 @@ def generate(t: Any) -> None:
5963
[prefix, suffix] = return_value(type)
6064
prefix = " return " + prefix + f"self._impl_obj.{name}"
6165
print(f"{prefix}{suffix}")
66+
all_overloads: Dict[str, List[Callable]] = (
67+
inspect.getmodule(t).__dict__.get("__overloads__") or {}
68+
)
6269
for [name, value] in t.__dict__.items():
6370
if name.startswith("_"):
6471
continue
@@ -78,55 +85,48 @@ def generate(t: Any) -> None:
7885
prefix = " return " + prefix + f"self._impl_obj.{name}"
7986
print(f"{prefix}{arguments(value, len(prefix))}{suffix}")
8087
for [name, value] in t.__dict__.items():
81-
overload: Optional[Overload] = None
82-
if is_overload(name):
83-
overload = Overload(t, name)
84-
overload.assert_has_implementation()
85-
name = overload.name
88+
overloads = all_overloads.get(name) or []
8689
if (
8790
not name.startswith("_")
8891
and isinstance(value, FunctionType)
8992
and "remove_listener" != name
9093
):
9194
is_async = inspect.iscoroutinefunction(value)
92-
return_type_value = return_type(value)
93-
return_type_value = re.sub(r"\"([^\"]+)Impl\"", r"\1", return_type_value)
94-
return_type_value = return_type_value.replace(
95-
"EventContextManager", "AsyncEventContextManager"
96-
)
97-
print("")
9895
async_prefix = "async " if is_async else ""
99-
if overload is not None:
96+
indent = len(name) + 9
97+
for overload in overloads:
10098
print(" @typing.overload")
99+
print(
100+
f" {async_prefix}def {name}({signature(overload, indent, True)}) -> {async_return_type_value(overload)}:"
101+
)
102+
print(" pass")
103+
print("")
101104
print(
102-
f" {async_prefix}def {name}({signature(value, len(name) + 9, overload is not None)}) -> {return_type_value}:"
105+
f" {async_prefix}def {name}({signature(value, indent)}) -> {async_return_type_value(value)}:"
103106
)
104-
if overload is None:
105-
documentation_provider.print_entry(
106-
class_name, name, get_type_hints(value, api_globals)
107+
documentation_provider.print_entry(
108+
class_name, name, get_type_hints(value, api_globals)
109+
)
110+
if "expect_" in name:
111+
print("")
112+
print(
113+
f" return AsyncEventContextManager(self._impl_obj.{name}({arguments(value, 12)}).future)"
107114
)
108-
if "expect_" in name:
109-
print("")
110-
print(
111-
f" return AsyncEventContextManager(self._impl_obj.{name}({arguments(value, 12)}).future)"
112-
)
113-
else:
114-
[prefix, suffix] = return_value(
115-
get_type_hints(value, api_globals)["return"]
115+
else:
116+
[prefix, suffix] = return_value(
117+
get_type_hints(value, api_globals)["return"]
118+
)
119+
if is_async:
120+
prefix += (
121+
f'await self._async("{to_snake_case(class_name)}.{name}", '
116122
)
117-
if is_async:
118-
prefix += (
119-
f'await self._async("{to_snake_case(class_name)}.{name}", '
120-
)
121-
suffix += ")"
122-
prefix = prefix + f"self._impl_obj.{name}("
123-
suffix = ")" + suffix
124-
print(
125-
f"""
123+
suffix += ")"
124+
prefix = prefix + f"self._impl_obj.{name}("
125+
suffix = ")" + suffix
126+
print(
127+
f"""
126128
return {prefix}{arguments(value, len(prefix))}{suffix}"""
127-
)
128-
else:
129-
print(" pass")
129+
)
130130
if class_name == "Playwright":
131131
print(
132132
"""

0 commit comments

Comments
 (0)