
Support decorators that modify parameter types and preserve keywords #1505

@osandov


For osandov/drgn#364, I'd like to define a decorator that changes the types of some parameters of the functions it wraps, while still letting callers pass those parameters by keyword.

As a toy example, imagine a decorator that wraps a function whose first parameter is an int (which may be passed by keyword) so that it also accepts a str, which is automatically converted to an int. I can do this manually with overloads:

```python
import inspect
import functools
from typing import Any, Callable, TypeVar, overload


R = TypeVar("R")


def takes_int_or_str(f: Callable[..., R]) -> Callable[..., R]:
    # Name of the wrapped function's first parameter.
    param = next(iter(inspect.signature(f).parameters))

    @functools.wraps(f)
    def wrapper(*args: Any, **kwds: Any) -> R:
        if param in kwds:
            # The first parameter was passed by keyword.
            if isinstance(kwds[param], str):
                kwds[param] = int(kwds[param])
        elif isinstance(args[0], str):
            # The first parameter was passed positionally as a str.
            return f(int(args[0]), *args[1:], **kwds)
        return f(*args, **kwds)

    return wrapper


@overload
def f(x: int) -> int: ...
@overload
def f(x: str) -> int: ...
@takes_int_or_str
def f(x: int) -> int:
    return x * x


f(1)
f("2")
f(x="3")
```
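The runtime half of the decorator can also be written with inspect.Signature.bind, which normalizes positional and keyword calls so the first parameter is found either way. It also turns a call with no arguments into the usual TypeError instead of an IndexError from args[0]. A minimal sketch, assuming the same rule as the toy example (a str in the first parameter is converted to int); square is a hypothetical example function:

```python
import functools
import inspect
from typing import Any, Callable, TypeVar

R = TypeVar("R")


def takes_int_or_str(f: Callable[..., R]) -> Callable[..., R]:
    # Variant of the decorator above: bind the call against the wrapped
    # function's signature so positional and keyword forms look the same.
    sig = inspect.signature(f)
    param = next(iter(sig.parameters))

    @functools.wraps(f)
    def wrapper(*args: Any, **kwds: Any) -> R:
        bound = sig.bind(*args, **kwds)  # raises TypeError for invalid calls
        value = bound.arguments.get(param)
        if isinstance(value, str):
            # Updating bound.arguments is reflected in bound.args/bound.kwargs.
            bound.arguments[param] = int(value)
        return f(*bound.args, **bound.kwargs)

    return wrapper


@takes_int_or_str
def square(x: int) -> int:
    return x * x


print(square(1), square("2"), square(x="3"))  # prints: 1 4 9
```

This only tidies the runtime side; the static types still come from the hand-written @overload stubs.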

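But writing these overloads by hand is tedious and error-prone if many functions use the decorator. ParamSpec with Concatenate almost gets me what I want, but it doesn't preserve keywords: Concatenate prepends a positional-only parameter, so the first parameter can no longer be passed by keyword: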

```python
import inspect
import functools
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar


P = ParamSpec("P")
R = TypeVar("R")


def takes_int_or_str(
    f: Callable[Concatenate[int, P], R]
) -> Callable[Concatenate[int | str, P], R]:
    param = next(iter(inspect.signature(f).parameters))

    @functools.wraps(f)
    def wrapper(*args: Any, **kwds: Any) -> R:
        if param in kwds:
            if isinstance(kwds[param], str):
                kwds[param] = int(kwds[param])
        elif isinstance(args[0], str):
            return f(int(args[0]), *args[1:], **kwds)
        return f(*args, **kwds)

    return wrapper


@takes_int_or_str
def f(x: int) -> int:
    return x * x


f(1)
f("2")
# error: Unexpected keyword argument "x" for "f"  [call-arg]
f(x="3")
```

This request seems somewhat related to #1273, although I want to specify the transformation per parameter rather than by input type.

P.S. I used Callable here, but my real use cases need a Protocol with __call__ because they have more complicated signatures.

Labels: topic: feature (discussions about new features for Python's type annotations)
