Protocol for array objects #229

Open
@pmeier

I've tried to tackle static typing and got a vendorable protocol that can be checked statically as well as at runtime for all but one case I'm going to detail below.

Protocol

import enum
from typing import Any, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable

A = TypeVar("A")


@runtime_checkable
class VendoredArrayProtocol(Protocol[A]):
    @property
    def dtype(self) -> Any:
        ...

    @property
    def device(self) -> Any:
        ...

    @property
    def ndim(self) -> int:
        ...

    @property
    def shape(self) -> Any:
        ...

    @property
    def size(self) -> int:
        ...

    @property
    def T(self) -> A:
        ...

    def __abs__(self) -> A:
        ...

    def __add__(self, other: Union[int, float, A], /) -> A:
        ...

    def __and__(self, other: Union[bool, int, A], /) -> A:
        ...

    def __array_namespace__(self, /, *, api_version: Optional[str] = None) -> Any:
        ...

    def __bool__(self) -> bool:
        ...

    def __dlpack__(self, /, *, stream: Optional[Union[int, Any]] = None) -> Any:
        ...

    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
        ...

    # This overrides the input type, since object.__eq__ handles any input
    # This overrides the return type, since object.__eq__ returns a bool
    def __eq__(  # type: ignore[override]
        self,
        other: Union[bool, int, float, A],
        /,
    ) -> A:  # type: ignore[override]
        ...

    def __float__(self) -> float:
        ...

    def __floordiv__(self, other: Union[int, float, A], /) -> A:
        ...

    def __ge__(self, other: Union[int, float, A], /) -> A:
        ...

    def __getitem__(
        self,
        key: Union[int, slice, Tuple[Union[int, slice], ...], A],
        /,
    ) -> A:
        ...

    def __gt__(self, other: Union[int, float, A], /) -> A:
        ...

    def __int__(self) -> int:
        ...

    def __invert__(self) -> A:
        ...

    def __le__(self, other: Union[int, float, A], /) -> A:
        ...

    def __len__(self) -> int:
        ...

    def __lshift__(self, other: Union[int, A], /) -> A:
        ...

    def __lt__(self, other: Union[int, float, A], /) -> A:
        ...

    def __matmul__(self, other: A, /) -> A:
        ...

    def __mod__(self, other: Union[int, float, A], /) -> A:
        ...

    def __mul__(self, other: Union[int, float, A], /) -> A:
        ...

    # This overrides the input type, since object.__ne__ handles any input
    # This overrides the return type, since object.__ne__ returns a bool
    def __ne__(  # type: ignore[override]
        self, other: Union[bool, int, float, A], /
    ) -> A:  # type: ignore[override]
        ...

    def __neg__(self) -> A:
        ...

    def __or__(self, other: Union[bool, int, A], /) -> A:
        ...

    def __pos__(self) -> A:
        ...

    def __pow__(self, other: Union[int, float, A], /) -> A:
        ...

    def __rshift__(self, other: Union[int, A], /) -> A:
        ...

    def __setitem__(
        self,
        key: Union[int, slice, Tuple[Union[int, slice], ...], A],
        value: Union[bool, int, float, A],
        /,
    ) -> None:
        ...

    def __sub__(self, other: Union[int, float, A], /) -> A:
        ...

    def __truediv__(self, other: Union[int, float, A], /) -> A:
        ...

    def __xor__(self, other: Union[bool, int, A], /) -> A:
        ...
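
As a rough illustration of how such a protocol can be checked both statically and at runtime, here is a minimal sketch. `MiniArrayProtocol`, `ToyArray`, and `add_one` are hypothetical names invented for this example; the protocol is a cut-down stand-in for the full one above, not a replacement for it.

```python
from typing import Any, Protocol, Tuple, TypeVar, Union, runtime_checkable

A = TypeVar("A")


@runtime_checkable
class MiniArrayProtocol(Protocol[A]):
    """Hypothetical, cut-down stand-in for VendoredArrayProtocol."""

    @property
    def shape(self) -> Any:
        ...

    def __add__(self, other: Union[int, float, A], /) -> A:
        ...


class ToyArray:
    """Hypothetical 1-D array used only to exercise the protocol."""

    def __init__(self, data) -> None:
        self._data = list(data)

    @property
    def shape(self) -> Tuple[int, ...]:
        return (len(self._data),)

    def __add__(self, other: Union[int, float, "ToyArray"], /) -> "ToyArray":
        if isinstance(other, ToyArray):
            return ToyArray(a + b for a, b in zip(self._data, other._data))
        return ToyArray(a + other for a in self._data)


def add_one(x: MiniArrayProtocol[A]) -> A:
    # Statically, A is resolved from the argument's own __add__ signature.
    return x + 1


# Runtime check: isinstance() only verifies that the members exist,
# it does not validate their signatures.
print(isinstance(ToyArray([1, 2, 3]), MiniArrayProtocol))
print(isinstance([1, 2, 3], MiniArrayProtocol))  # list has __add__ but no shape
```

Note that the runtime check is purely structural: an object with a `shape` attribute and an `__add__` method passes even if the signatures do not match the protocol.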

To test everything yourself you can use this playground repo.

Current blocker

It is currently impossible to annotate a parameter as accepting Ellipsis, because its alias ... already has a different meaning inside type annotations (for example in Tuple[int, ...]). Thus, the __getitem__ and __setitem__ methods cannot be annotated correctly right now. There is a fix for this in python/cpython#22336, but it will only be shipped with Python 3.10. If we leave Ellipsis out of the annotation, accessing the array with something like Array()[..., 0] will be flagged by mypy, although it should be supported according to the specification.
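
For illustration, here is a sketch of how an index type could be spelled once `types.EllipsisType` is available (it exists in the standard library from Python 3.10 on; that this is what the PR referenced above delivers is my assumption). `IndexItem`, `IndexKey`, and `count_ellipsis` are hypothetical names. The fallback branch only works at runtime; static checkers will not understand it on older versions.

```python
import sys
from typing import Tuple, Union

if sys.version_info >= (3, 10):
    from types import EllipsisType
else:
    # Runtime-only fallback (assumption): type checkers cannot interpret this.
    EllipsisType = type(Ellipsis)

# Hypothetical aliases for the key accepted by __getitem__ / __setitem__.
IndexItem = Union[int, slice, EllipsisType]
IndexKey = Union[IndexItem, Tuple[IndexItem, ...]]


def count_ellipsis(key: IndexKey) -> int:
    """Count how many `...` entries an index key contains."""
    items = key if isinstance(key, tuple) else (key,)
    return sum(1 for item in items if isinstance(item, EllipsisType))


# The key produced by an expression such as Array()[..., 0].
print(count_ellipsis((..., 0)))
```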

Suggested improvements

While working on the protocol I found a few issues that could be addressed:

  • Array.dtype, Array.device, Array.__array_namespace__(), and Array.__dlpack__() should return custom objects, but the specification does not say what these objects look like. For now I've typed them as Any in the protocol, but the specification should be more precise.

  • Array.shape should return Tuple[int, ...], but the discussion in array-api-tests#15 ("explicitly check for Python scalars in constant tests", comment) implies that custom objects might also be possible. Maybe we can use Sequence[int]?

  • The type annotation of the stream parameter of Array.__dlpack__() reads Optional[Union[int, Any]]. This is equivalent to Any, which would be more concise.

  • The binary dunder methods accept only specific input types for the other parameter. For example, __add__ takes Union[int, float, Array]. IMO they should accept Any and return NotImplemented in case they cannot work with the type. For example:

    from typing import Any

    class Array:
        def __add__(self, other: Any, /) -> "Array":
            # Defer to the other operand for types we do not know
            if not isinstance(other, (int, float, Array)):
                return NotImplemented

            # perform addition
            ...

    This makes it harder for static type checkers to catch bugs, because statically something like Array() + None would be allowed, but it gives the other object a chance to work with the Array object by implementing the reflected dunder (here __radd__). If both objects do not know how to deal with the addition, Python will automatically raise a TypeError.

    Since the object class defines __eq__ and __ne__ methods according to the scheme proposed above, I needed to put # type: ignore[override] directives in the protocol for the input types.
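
To make the NotImplemented scheme from the last bullet concrete, here is a small runnable sketch. `Array` is a toy scalar wrapper and `Wrapper` is a hypothetical third-party type; neither is part of the specification.

```python
from typing import Any


class Array:
    """Toy scalar 'array' used only to demonstrate operator dispatch."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other: Any, /) -> "Array":
        if isinstance(other, Array):
            return Array(self.value + other.value)
        if isinstance(other, (int, float)):
            return Array(self.value + other)
        # Unknown type: let Python try the reflected method on `other`.
        return NotImplemented


class Wrapper:
    """Hypothetical third-party type that knows how to add itself to Array."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __radd__(self, other: Any, /) -> Array:
        if isinstance(other, Array):
            return Array(other.value + self.value)
        return NotImplemented


# Array.__add__ returns NotImplemented, so Wrapper.__radd__ handles it.
print((Array(1.0) + Wrapper(5.0)).value)

# Neither side can handle None, so Python raises TypeError on its own.
try:
    Array(1.0) + None
except TypeError as exc:
    print("TypeError:", exc)
```

The trade-off is the one described above: a static checker accepts Array(1.0) + None because other is typed as Any, and the error only surfaces at runtime.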
