Description
For osandov/drgn#364, I'd like to be able to define a decorator that modifies the parameter types of the functions it wraps, but also preserves keywords.
For a toy example, imagine a decorator that wraps a function taking an `int` as the first parameter (possibly as a keyword argument) so that it can also take a `str` that is automatically converted to an `int`. I can do this manually with overloads:
```python
import inspect
import functools
from typing import Any, Callable, TypeVar, overload

R = TypeVar("R")


def takes_int_or_str(f: Callable[..., R]) -> Callable[..., R]:
    # Name of the first parameter, so keyword calls can be converted too.
    param = next(iter(inspect.signature(f).parameters))

    @functools.wraps(f)
    def wrapper(*args: Any, **kwds: Any) -> R:
        if param in kwds:
            if isinstance(kwds[param], str):
                kwds[param] = int(kwds[param])
        elif isinstance(args[0], str):
            return f(int(args[0]), *args[1:], **kwds)
        return f(*args, **kwds)

    return wrapper


@overload
def f(x: int) -> int: ...
@overload
def f(x: str) -> int: ...
@takes_int_or_str
def f(x: int) -> int:
    return x * x


f(1)
f("2")
f(x="3")
```
But this is tedious and error-prone if you have a lot of functions using the decorator. `ParamSpec` almost gets me what I want, but it doesn't preserve keywords:
```python
import inspect
import functools
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def takes_int_or_str(
    f: Callable[Concatenate[int, P], R]
) -> Callable[Concatenate[int | str, P], R]:
    # Name of the first parameter, so keyword calls can be converted too.
    param = next(iter(inspect.signature(f).parameters))

    @functools.wraps(f)
    def wrapper(*args: Any, **kwds: Any) -> R:
        if param in kwds:
            if isinstance(kwds[param], str):
                kwds[param] = int(kwds[param])
        elif isinstance(args[0], str):
            return f(int(args[0]), *args[1:], **kwds)
        return f(*args, **kwds)

    return wrapper


@takes_int_or_str
def f(x: int) -> int:
    return x * x


f(1)
f("2")
# error: Unexpected keyword argument "x" for "f"  [call-arg]
f(x="3")
```
This seems somewhat related to #1273, although I want to specify the transformation by parameter, not by input type.
P.S. I used `Callable` here, but my real use cases need a `Protocol` with `__call__` because they have more complicated signatures.
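To make the P.S. concrete, here is a minimal sketch of the callback-protocol form for the toy example. The protocol name `IntOrStrToInt` and the function `square` are hypothetical, and the `args and` guard is my addition. Because `__call__` names its parameter `x`, a type checker should accept the keyword call, but the protocol is tied to one signature, so it does not remove the per-function boilerplate this issue is about.

```python
import functools
import inspect
from typing import Any, Callable, Protocol, Union, cast


class IntOrStrToInt(Protocol):
    # Hypothetical callback protocol: it names the parameter `x`, so both
    # positional and keyword calls are visible to a type checker.
    def __call__(self, x: Union[int, str]) -> int: ...


def takes_int_or_str(f: Callable[..., int]) -> IntOrStrToInt:
    # Name of the first parameter, so keyword calls can be converted too.
    param = next(iter(inspect.signature(f).parameters))

    @functools.wraps(f)
    def wrapper(*args: Any, **kwds: Any) -> int:
        if param in kwds:
            if isinstance(kwds[param], str):
                kwds[param] = int(kwds[param])
        elif args and isinstance(args[0], str):
            return f(int(args[0]), *args[1:], **kwds)
        return f(*args, **kwds)

    # The cast states the intended return type explicitly instead of relying
    # on how a given checker compares (*args, **kwds) with the protocol.
    return cast(IntOrStrToInt, wrapper)


@takes_int_or_str
def square(x: int) -> int:
    return x * x


square(1)
square("2")
square(x="3")
```

Each distinct signature would need its own protocol class, which is the same duplication as the overload approach, only moved into the return type.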