Description
I've tried to tackle static typing and arrived at a vendorable protocol that can be checked both statically and at runtime, except for one case that I detail below.
Protocol
```python
import enum
from typing import Any, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable

A = TypeVar("A")


@runtime_checkable
class VendoredArrayProtocol(Protocol[A]):
    @property
    def dtype(self) -> Any:
        ...

    @property
    def device(self) -> Any:
        ...

    @property
    def ndim(self) -> int:
        ...

    @property
    def shape(self) -> Any:
        ...

    @property
    def size(self) -> int:
        ...

    @property
    def T(self) -> A:
        ...

    def __abs__(self) -> A:
        ...

    def __add__(self, other: Union[int, float, A], /) -> A:
        ...

    def __and__(self, other: Union[bool, int, A], /) -> A:
        ...

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any:
        ...

    def __bool__(self) -> bool:
        ...

    def __dlpack__(self, /, *, stream: Optional[Union[int, Any]] = None) -> Any:
        ...

    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
        ...

    # This overrides the input type, since object.__eq__ handles any input
    # This overrides the return type, since object.__eq__ returns a bool
    def __eq__(  # type: ignore[override]
        self,
        other: Union[bool, int, float, A],
        /,
    ) -> A:  # type: ignore[override]
        ...

    def __float__(self) -> float:
        ...

    def __floordiv__(self, other: Union[int, float, A], /) -> A:
        ...

    def __ge__(self, other: Union[int, float, A], /) -> A:
        ...

    def __getitem__(
        self,
        key: Union[int, slice, Tuple[Union[int, slice], ...], A],
        /,
    ) -> A:
        ...

    def __gt__(self, other: Union[int, float, A], /) -> A:
        ...

    def __int__(self) -> int:
        ...

    def __invert__(self) -> A:
        ...

    def __le__(self, other: Union[int, float, A], /) -> A:
        ...

    def __len__(self) -> int:
        ...

    def __lshift__(self, other: Union[int, A], /) -> A:
        ...

    def __lt__(self, other: Union[int, float, A], /) -> A:
        ...

    def __matmul__(self, other: A, /) -> A:
        ...

    def __mod__(self, other: Union[int, float, A], /) -> A:
        ...

    def __mul__(self, other: Union[int, float, A], /) -> A:
        ...

    # This overrides the input type, since object.__ne__ handles any input
    # This overrides the return type, since object.__ne__ returns a bool
    def __ne__(  # type: ignore[override]
        self, other: Union[bool, int, float, A], /
    ) -> A:  # type: ignore[override]
        ...

    def __neg__(self) -> A:
        ...

    def __or__(self, other: Union[bool, int, A], /) -> A:
        ...

    def __pos__(self) -> A:
        ...

    def __pow__(self, other: Union[int, float, A], /) -> A:
        ...

    def __rshift__(self, other: Union[int, A], /) -> A:
        ...

    def __setitem__(
        self,
        key: Union[int, slice, Tuple[Union[int, slice], ...], A],
        value: Union[bool, int, float, A],
        /,
    ) -> None:
        ...

    def __sub__(self, other: Union[int, float, A], /) -> A:
        ...

    def __truediv__(self, other: Union[int, float, A], /) -> A:
        ...

    def __xor__(self, other: Union[bool, int, A], /) -> A:
        ...
```
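To make clear what the runtime half of the check covers, here is a small sketch with a trimmed-down, non-generic stand-in for the protocol above. `MiniArrayProtocol`, `ToyArray`, `NotAnArray` and `WrongTypes` are hypothetical names used only for illustration. It relies on the standard `typing` behaviour that `isinstance` against a `@runtime_checkable` protocol only checks that the members exist, not their signatures or types, so type errors are still left to the static checker.

```python
from typing import Any, Optional, Protocol, Tuple, runtime_checkable


@runtime_checkable
class MiniArrayProtocol(Protocol):
    """Hypothetical, trimmed-down stand-in for VendoredArrayProtocol."""

    @property
    def ndim(self) -> int: ...

    @property
    def shape(self) -> Tuple[int, ...]: ...

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any: ...


class ToyArray:
    """Minimal object that provides every member of MiniArrayProtocol."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self._shape = shape

    @property
    def ndim(self) -> int:
        return len(self._shape)

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any:
        return None  # a real array would return its array API namespace


class NotAnArray:
    ndim = 0  # has `ndim`, but lacks `shape` and `__array_namespace__`


class WrongTypes:
    # Every member is present, but with the wrong types or signature.
    ndim = "two"
    shape = "not a tuple"

    def __array_namespace__(self) -> None:
        return None


print(isinstance(ToyArray((2, 3)), MiniArrayProtocol))  # True
print(isinstance(NotAnArray(), MiniArrayProtocol))      # False: members missing
print(isinstance(WrongTypes(), MiniArrayProtocol))      # True: only presence is checked
```

The last line is the reason the protocol is useful mainly for static checking: at runtime it can only reject objects that are missing members entirely.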
To test everything yourself, you can use this playground repo.
Current blocker
It is currently impossible to annotate a parameter as accepting `Ellipsis`: its type has no public name, and its alias `...` has a different meaning inside type annotations (e.g. `Tuple[int, ...]`). Thus, it is currently impossible to correctly annotate the `__getitem__` and `__setitem__` methods. There is a fix for this in python/cpython#22336 (which adds `types.EllipsisType`), but it will only ship with Python 3.10. If we leave `Ellipsis` out of the annotation, accessing the array with something like `Array()[..., 0]` will be flagged by mypy, although it should be supported according to the specification.
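A sketch of how the key type could look once `types.EllipsisType` is available. It assumes a simplified key type that leaves out the array keys used for boolean masking; `SingleIndex`, `IndexKey` and `describe_key` are hypothetical names for illustration only. On Python < 3.10 the fallback keeps the code importable at runtime, but static type checkers will not accept it as a type, which is exactly the blocker described above.

```python
import sys
from typing import Tuple, Union

if sys.version_info >= (3, 10):
    # Added in Python 3.10 (python/cpython#22336).
    from types import EllipsisType
else:
    # Runtime-only fallback: static type checkers reject this as a type,
    # which is the limitation described above.
    EllipsisType = type(Ellipsis)

# Hypothetical, simplified index types (array keys for masking are omitted).
SingleIndex = Union[int, slice, EllipsisType]
IndexKey = Union[SingleIndex, Tuple[SingleIndex, ...]]


def describe_key(key: IndexKey) -> Tuple[str, ...]:
    """Classify each component of an index key (illustrative helper)."""
    parts = key if isinstance(key, tuple) else (key,)
    kinds = []
    for part in parts:
        if part is Ellipsis:
            kinds.append("ellipsis")
        elif isinstance(part, slice):
            kinds.append("slice")
        elif isinstance(part, int):
            kinds.append("int")
        else:
            raise TypeError(f"unsupported index component: {part!r}")
    return tuple(kinds)


print(describe_key((..., 0)))  # ('ellipsis', 'int')
```

With a `__getitem__(self, key: IndexKey, /)` annotation of this shape, `Array()[..., 0]` would type-check on Python 3.10+.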
Suggested improvements
While working on the protocol I found a few issues that could be addressed:
- `Array.dtype`, `Array.device`, `Array.__array_namespace__()`, and `Array.__dlpack__()` should return custom objects, but the specification does not say what these objects look like. In the current state of the protocol I've typed them as `Any`, but the specification should be more precise.
- `Array.shape` should return `Tuple[int, ...]`, but a comment on array-api-tests#15 ("explicitly check for Python scalars in constant tests") implies that custom objects might also be possible. Maybe we can use `Sequence[int]`?
- The type annotation of the `stream` parameter of `Array.__dlpack__()` reads `Optional[Union[int, Any]]`, which is equivalent to the more concise `Any`.
- The binary dunder methods take specific input types for the `other` parameter. For example, `__add__` takes `Union[int, float, Array]`. IMO they should take `Any` and return `NotImplemented` in case they cannot work with the type. For example:

  ```python
  from typing import Any


  class Array:
      def __add__(self, other: Any, /) -> "Array":
          if not isinstance(other, (int, float, Array)):
              return NotImplemented
          # perform addition
  ```

  This makes it harder for static type checkers to catch bugs, because statically something like `Array() + None` would be allowed, but it gives the `other` object a chance to work with the `Array` object by implementing the reflected dunder (here `__radd__`). If neither object knows how to deal with the addition, Python will automatically raise a `TypeError`.

  Since the `object` class defines `__eq__` and `__ne__` according to the scheme proposed above (accepting any input), I needed to put `# type: ignore[override]` directives in the protocol for the input types.
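The `NotImplemented` scheme from the last point can be illustrated with a small runnable sketch. `Array` here is a toy 0-d wrapper around a float and `Quantity` is a hypothetical third-party type; neither is part of the specification. It walks through three cases: a supported operand, a foreign type that handles the operation through its reflected dunder, and an operand that neither side supports, which ends in a `TypeError`.

```python
from typing import Any


class Array:
    """Toy 0-d array wrapping a Python float (hypothetical, for illustration)."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other: Any, /) -> "Array":
        # Accept any input statically, but opt out at runtime for unknown types.
        if not isinstance(other, (int, float, Array)):
            return NotImplemented
        other_value = other.value if isinstance(other, Array) else other
        return Array(self.value + other_value)

    def __radd__(self, other: Any, /) -> "Array":
        # Reflected addition, e.g. `2 + Array(...)` after int.__add__ gives up.
        return self.__add__(other)


class Quantity:
    """Hypothetical third-party type that knows how to add itself to an Array."""

    def __init__(self, magnitude: float) -> None:
        self.magnitude = magnitude

    def __radd__(self, other: Any, /) -> "Quantity":
        if isinstance(other, Array):
            return Quantity(other.value + self.magnitude)
        return NotImplemented


print((Array(1.0) + 2).value)                   # 3.0
print((2 + Array(1.0)).value)                   # 3.0, via Array.__radd__
print((Array(1.0) + Quantity(2.0)).magnitude)   # 3.0, via Quantity.__radd__
try:
    Array(1.0) + None  # neither Array nor NoneType supports this
except TypeError:
    print("TypeError raised by Python")
```

The design trade-off is the one described above: a static checker will accept `Array(1.0) + None`, but `Quantity` can interoperate with `Array` without `Array` knowing about it.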