diff --git a/pyproject.toml b/pyproject.toml index fa8442a..3677e09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,9 @@ ] dependencies = [ "typing-extensions>=4.14.0", + "optype>=0.9.3; python_version < '3.11'", + "optype>=0.12.0; python_version >= '3.11'", + "tomli>=1.2.0 ; python_full_version < '3.11'", ] [project.urls] @@ -127,9 +130,12 @@ version_tuple = {version_tuple!r} "D107", # Missing docstring in __init__ "D203", # 1 blank line required before class docstring "D213", # Multi-line docstring summary should start at the second line + "D401", # First line of docstring should be in imperative mood "FBT", # flake8-boolean-trap "FIX", # flake8-fixme "ISC001", # Conflicts with formatter + "PLW1641", # Object does not implement `__hash__` method + "PYI041", # Use `float` instead of `int | float` ] [tool.ruff.lint.pylint] @@ -143,10 +149,13 @@ version_tuple = {version_tuple!r} ] [tool.ruff.lint.flake8-import-conventions] - banned-from = ["array_api_typing"] + banned-from = ["array_api_typing", "optype", "optype.numpy", "optype.numpy.compat"] [tool.ruff.lint.flake8-import-conventions.extend-aliases] array_api_typing = "xpt" + optype = "op" + "optype.numpy" = "onp" + "optype.numpy.compat" = "npc" [tool.ruff.lint.isort] combine-as-imports = true diff --git a/src/array_api_typing/__init__.py b/src/array_api_typing/__init__.py index 3532743..5c541b3 100644 --- a/src/array_api_typing/__init__.py +++ b/src/array_api_typing/__init__.py @@ -1,10 +1,11 @@ """Static typing support for the array API standard.""" __all__ = ( + "Array", "HasArrayNamespace", "__version__", "__version_tuple__", ) -from ._namespace import HasArrayNamespace +from ._array import Array, HasArrayNamespace from ._version import version as __version__, version_tuple as __version_tuple__ diff --git a/src/array_api_typing/_array.py b/src/array_api_typing/_array.py new file mode 100644 index 0000000..d2d08a7 --- /dev/null +++ b/src/array_api_typing/_array.py @@ -0,0 +1,101 @@ +__all__ = ( + "Array", + "BoolArray", + "HasArrayNamespace", + "NumericArray", +) + +from pathlib import Path +from types import ModuleType +from typing import Literal, Protocol, TypeAlias +from typing_extensions import TypeVar + +import optype as op + +from ._utils import docstring_setter + +# Load docstrings from TOML file +try: + import tomllib +except ImportError: + import tomli as tomllib # type: ignore[import-not-found, no-redef] + +_docstrings_path = Path(__file__).parent / "_array_docstrings.toml" +with _docstrings_path.open("rb") as f: + _array_docstrings = tomllib.load(f)["docstrings"] + +NS_co = TypeVar("NS_co", covariant=True, default=ModuleType) +T_contra = TypeVar("T_contra", contravariant=True) + + +class HasArrayNamespace(Protocol[NS_co]): + """Protocol for classes that have an `__array_namespace__` method. + + Example: + >>> import array_api_typing as xpt + >>> + >>> class MyArray: + ... def __array_namespace__(self): + ... return object() + >>> + >>> x = MyArray() + >>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool: + ... return hasattr(x, "__array_namespace__") + >>> has_array_namespace(x) + True + + """ + + def __array_namespace__( + self, /, *, api_version: Literal["2021.12"] | None = None + ) -> NS_co: ... + + +@docstring_setter(**_array_docstrings) +class Array( + HasArrayNamespace[NS_co], + op.CanPosSelf, + op.CanNegSelf, + op.CanAddSame[T_contra], + op.CanIAddSelf[T_contra], + op.CanRAddSelf[T_contra], + op.CanSubSame[T_contra], + op.CanISubSelf[T_contra], + op.CanRSubSelf[T_contra], + op.CanMulSame[T_contra], + op.CanIMulSelf[T_contra], + op.CanRMulSelf[T_contra], + op.CanTruedivSame[T_contra], + op.CanITruedivSelf[T_contra], + op.CanRTruedivSelf[T_contra], + op.CanFloordivSame[T_contra], + op.CanIFloordivSelf[T_contra], + op.CanRFloordivSelf[T_contra], + op.CanModSame[T_contra], + op.CanIModSelf[T_contra], + op.CanRModSelf[T_contra], + op.CanPowSame[T_contra], + op.CanIPowSelf[T_contra], + op.CanRPowSelf[T_contra], + Protocol[T_contra, NS_co], +): + """Array API specification for array object attributes and methods.""" + + +BoolArray: TypeAlias = Array[bool, NS_co] +"""Array API specification for boolean array object attributes and methods. + +Specifically, this type alias fills the `T_contra` type variable with `bool`, +allowing for `bool` objects to be added, subtracted, multiplied, etc. to the +array object. + +""" + +NumericArray: TypeAlias = Array[float | int, NS_co] +"""Array API specification for numeric array object attributes and methods. + +Specifically, this type alias fills the `T_contra` type variable with `float | +int`, allowing for `float | int` objects to be added, subtracted, multiplied, +etc. to the array object. + +""" diff --git a/src/array_api_typing/_array_docstrings.toml b/src/array_api_typing/_array_docstrings.toml new file mode 100644 index 0000000..a7c0f25 --- /dev/null +++ b/src/array_api_typing/_array_docstrings.toml @@ -0,0 +1,398 @@ +[docstrings] + +__pos__ = ''' +Evaluates `+self_i` for each element of an array instance. + +Returns: + Self: An array containing the evaluated result for each element. + The returned array must have the same data type as `self`. + +See Also: + array_api_typing.Positive + +''' + +__neg__ = ''' +Evaluates `-self_i` for each element of an array instance. + +Returns: + Self: an array containing the evaluated result for each element in + `self`. The returned array must have a data type determined by Type + Promotion Rules. + +See Also: + array_api_typing.Negative + +''' + +__add__ = ''' +Calculates the sum for each element of an array instance with the respective +element of the array `other`. + +Args: + other: addend array. Must be compatible with `self` (see + Broadcasting). Should have a numeric data type. + +Returns: + Self: an array containing the element-wise sums. The returned array + must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Add + +''' + +__iadd__ = ''' +Calculates the in-place sum for each element of an array instance with the +respective element of the array `other`. + +Args: + other: addend array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place addition. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Add + +''' + +__radd__ = ''' +Calculates the sum for each element of the array `other` with the respective +element of an array instance. + +Args: + other: addend array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: an array containing the element-wise sums. The returned array + must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Add + +''' + +__sub__ = ''' +Calculates the difference for each element of an array instance with the +respective element of the array other. + +The result of `self_i - other_i` must be the same as `self_i + +(-other_i)` and must be governed by the same floating-point rules as +addition (see `CanArrayAdd`). + +Args: + other: subtrahend array. Must be compatible with `self` (see + Broadcasting). Should have a numeric data type. + +Returns: + Self: an array containing the element-wise differences. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Subtract + +''' + +__isub__ = ''' +Calculates the in-place difference for each element of an array instance +with the respective element of the array `other`. + +Args: + other: subtrahend array. Must be compatible with `self` (see + Broadcasting). Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place subtraction. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Subtract + +''' + +__rsub__ = ''' +Calculates the difference for each element of the array `other` with the +respective element of an array instance. + +The result of `other_i - self_i` must be the same as `other_i + (-self_i)` +and must be governed by the same floating-point rules as addition (see +`CanArrayAdd`). + +Args: + other: minuend array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: an array containing the element-wise differences. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Subtract + +''' + +__mul__ = ''' +Calculates the product for each element of an array instance with the +respective element of the array `other`. + +Args: + other: multiplicand array. Must be compatible with `self` (see + Broadcasting). Should have a numeric data type. + +Returns: + Self: an array containing the element-wise products. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Multiply + +''' + +__imul__ = ''' +Calculates the in-place product for each element of an array instance with +the respective element of the array `other`. + +Args: + other: multiplicand array. Must be compatible with `self` (see + Broadcasting). Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place multiplication. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Multiply + +''' + +__rmul__ = ''' +Calculates the product for each element of the array `other` with the +respective element of an array instance. + +Args: + other: multiplicand array. Must be compatible with `self` (see + Broadcasting). Should have a numeric data type. + +Returns: + Self: an array containing the element-wise products. The returned array + must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Multiply + +''' + +__truediv__ = ''' +Evaluates `self_i / other_i` for each element of an array instance with the +respective element of the array `other`. + +Args: + other: Must be compatible with `self` (see Broadcasting). Should have a + numeric data type. + +Returns: + Self: an array containing the element-wise results. The returned array + should have a floating-point data type determined by Type Promotion + Rules. + +See Also: + array_api_typing.TrueDiv + +''' + +__itruediv__ = ''' +Calculates the in-place quotient for each element of an array instance with +the respective element of the array `other`. + +Args: + other: divisor array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place true division. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.TrueDiv + +''' + +__rtruediv__ = ''' +Calculates the quotient for each element of the array `other` with the +respective element of an array instance. + +Args: + other: dividend array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: an array containing the element-wise quotients. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.TrueDiv + +''' + +__floordiv__ = ''' +Evaluates `self_i // other_i` for each element of an array instance with the +respective element of the array `other`. + +Args: + other: Must be compatible with `self` (see Broadcasting). Should have a + numeric data type. + +Returns: + Self: an array containing the element-wise results. The returned array + must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.FloorDiv + +''' + +__ifloordiv__ = ''' +Calculates the in-place floor division for each element of an array instance +with the respective element of the array `other`. + +Args: + other: divisor array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place floor division. The + returned array must have a data type determined by Type Promotion + Rules. + +See Also: + array_api_typing.FloorDiv + +''' + +__rfloordiv__ = ''' +Calculates the floor division for each element of the array `other` with the +respective element of an array instance. + +Args: + other: dividend array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: an array containing the element-wise floor division results. The + returned array must have a data type determined by Type Promotion + Rules. + +See Also: + array_api_typing.FloorDiv + +''' + +__imod__ = ''' +Calculates the in-place remainder for each element of an array instance with +the respective element of the array `other`. + +Args: + other: divisor array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place modulo operation. The + returned array must have a data type determined by Type Promotion + Rules. + +See Also: + array_api_typing.Remainder + +''' + +__mod__ = ''' +Evaluates `self_i % other_i` for each element of an array instance with the +respective element of the array `other`. + +Args: + other: Must be compatible with `self` (see Broadcasting). Should have a + numeric data type. + +Returns: + Self: an array containing the element-wise results. Each element-wise + result must have the same sign as the respective element `other_i`. + The returned array must have a floating-point data type determined + by Type Promotion Rules. + +See Also: + array_api_typing.Remainder + +''' + +__rmod__ = ''' +Calculates the remainder for each element of the array `other` with the +respective element of an array instance. + +Args: + other: dividend array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: an array containing the element-wise remainders. The returned + array must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Remainder + +''' + +__pow__ = ''' +Calculates an implementation-dependent approximation of exponentiation by +raising each element (the base) of an array instance to the power of +`other_i` (the exponent), where `other_i` is the corresponding element of +the array `other`. + +Args: + other: array whose elements correspond to the exponentiation exponent. + Must be compatible with `self` (see Broadcasting). Should have a + numeric data type. + +Returns: + Self: an array containing the element-wise results. The returned array + must have a data type determined by Type Promotion Rules. + +''' + +__ipow__ = ''' +Calculates the in-place power for each element of an array instance with the +respective element of the array `other`. + +Args: + other: exponent array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: `self`, after performing the in-place power operation. The + returned array must have a data type determined by Type Promotion + Rules. + +See Also: + array_api_typing.Power + +''' + +__rpow__ = ''' +Calculates the power for each element of the array `other` raised to the +respective element of an array instance. + +Args: + other: base array. Must be compatible with `self` (see Broadcasting). + Should have a numeric data type. + +Returns: + Self: an array containing the element-wise powers. The returned array + must have a data type determined by Type Promotion Rules. + +See Also: + array_api_typing.Power + +''' diff --git a/src/array_api_typing/_namespace.py b/src/array_api_typing/_namespace.py index 2074f4e..e69de29 100644 --- a/src/array_api_typing/_namespace.py +++ b/src/array_api_typing/_namespace.py @@ -1,30 +0,0 @@ -__all__ = ("HasArrayNamespace",) - -from types import ModuleType -from typing import Literal, Protocol -from typing_extensions import TypeVar - -T_co = TypeVar("T_co", covariant=True, default=ModuleType) - - -class HasArrayNamespace(Protocol[T_co]): - """Protocol for classes that have an `__array_namespace__` method. - - Example: - >>> import array_api_typing as xpt - >>> - >>> class MyArray: - ... def __array_namespace__(self): - ... return object() - >>> - >>> x = MyArray() - >>> def has_array_namespace(x: xpt.HasArrayNamespace) -> bool: - ... return hasattr(x, "__array_namespace__") - >>> has_array_namespace(x) - True - - """ - - def __array_namespace__( - self, /, *, api_version: Literal["2021.12"] | None = None - ) -> T_co: ... diff --git a/src/array_api_typing/_utils.py b/src/array_api_typing/_utils.py new file mode 100644 index 0000000..ea07370 --- /dev/null +++ b/src/array_api_typing/_utils.py @@ -0,0 +1,64 @@ +"""Utility functions.""" + +from collections.abc import Callable +from enum import Enum, auto +from typing import Literal, TypeVar + +ClassT = TypeVar("ClassT") +DocstringTypes = str | None + + +class _Sentinel(Enum): + SKIP = auto() + + +def set_docstrings( + obj: type[ClassT], + main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP, + /, + **method_docs: DocstringTypes, +) -> type[ClassT]: + """Set the docstring for a class and its methods. + + Args: + obj: The class to set the docstring for. + main: The main docstring for the class. If not provided, the + class docstring will not be modified. + method_docs: A mapping of method names to their docstrings. If a method + is not provided, its docstring will not be modified. + + Returns: + The class with updated docstrings. + + """ + if main is not _Sentinel.SKIP: + obj.__doc__ = main + + for name, doc in method_docs.items(): + method = getattr(obj, name) + method.__doc__ = doc + return obj + + +def docstring_setter( + main: DocstringTypes | Literal[_Sentinel.SKIP] = _Sentinel.SKIP, + /, + **method_docs: DocstringTypes, +) -> Callable[[type[ClassT]], type[ClassT]]: + """Decorator to set docstrings for a class and its methods. + + Args: + main: The main docstring for the class. If not provided, the + class docstring will not be modified. + method_docs: A mapping of method names to their docstrings. If a method + is not provided, its docstring will not be modified. + + Returns: + A decorator that sets the docstrings for the class and its methods. + + """ + + def decorator(cls: type[ClassT]) -> type[ClassT]: + return set_docstrings(cls, main, **method_docs) + + return decorator diff --git a/tests/integration/test_numpy1.pyi b/tests/integration/test_numpy1.pyi index 9367379..a933f2d 100644 --- a/tests/integration/test_numpy1.pyi +++ b/tests/integration/test_numpy1.pyi @@ -5,8 +5,16 @@ import numpy.array_api as np import array_api_typing as xpt +# Define an NDArray against which we can test the protocols +arr = np.eye(2) + ### # Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`. -arr = np.eye(2) arr_namespace: xpt.HasArrayNamespace[Any] = arr + +### +# Ensure that `np.ndarray` instances are assignable to `xpt.Array`. + +arr_array: xpt.Array[Any, Any] = arr +arr_floatarray: xpt.Array[float, Any] = arr diff --git a/tests/integration/test_numpy2.pyi b/tests/integration/test_numpy2.pyi index 64bed8c..affba67 100644 --- a/tests/integration/test_numpy2.pyi +++ b/tests/integration/test_numpy2.pyi @@ -4,8 +4,17 @@ import numpy.typing as npt import array_api_typing as xpt +# Define an NDArray against which we can test the protocols +arr: npt.NDArray[Any] + ### # Ensure that `np.ndarray` instances are assignable to `xpt.HasArrayNamespace`. -arr: npt.NDArray[Any] arr_namespace: xpt.HasArrayNamespace[Any] = arr + +### +# Ensure that `np.ndarray` instances are assignable to `xpt.Array`. + +arr_array: xpt.Array[Any, Any] = arr +arr_floatarray: xpt.Array[float, Any] = arr +arr_boolarray: xpt.Array[bool, Any] = arr diff --git a/uv.lock b/uv.lock index 276ae47..f9e9a5c 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,9 @@ resolution-markers = [ name = "array-api-typing" source = { editable = "." } dependencies = [ + { name = "optype", version = "0.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "optype", version = "0.12.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "typing-extensions" }, ] @@ -44,7 +47,12 @@ test-runtime = [ ] [package.metadata] -requires-dist = [{ name = "typing-extensions", specifier = ">=4.14.0" }] +requires-dist = [ + { name = "optype", marker = "python_full_version < '3.11'", specifier = ">=0.9.3" }, + { name = "optype", marker = "python_full_version >= '3.11'", specifier = ">=0.12.0" }, + { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=1.2.0" }, + { name = "typing-extensions", specifier = ">=4.14.0" }, +] [package.metadata.requires-dev] dev = [ @@ -350,6 +358,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/6b/1c6b515a83d5564b1698a61efa245727c8feecf308f4091f565988519d20/numpy-2.3.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:e610832418a2bc09d974cc9fecebfa51e9532d6190223bc5ef6a7402ebf3b5cb", size = 12927246, upload-time = "2025-06-21T12:27:38.618Z" }, ] +[[package]] +name = "optype" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/3c/9d59b0167458b839273ad0c4fc5f62f787058d8f5aed7f71294963a99471/optype-0.9.3.tar.gz", hash = "sha256:5f09d74127d316053b26971ce441a4df01f3a01943601d3712dd6f34cdfbaf48", size = 96143, upload-time = "2025-03-31T17:00:08.392Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/d8/ac50e2982bdc2d3595dc2bfe3c7e5a0574b5e407ad82d70b5f3707009671/optype-0.9.3-py3-none-any.whl", hash = "sha256:2935c033265938d66cc4198b0aca865572e635094e60e6e79522852f029d9e8d", size = 84357, upload-time = "2025-03-31T17:00:06.464Z" }, +] + +[[package]] +name = "optype" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", +] +dependencies = [ + { name = "typing-extensions", marker = "python_full_version >= '3.11' and python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/a5/f8faedc8bd43cff9f1d846ce9d1d6d6162886b7221c2c53b1b8d2c9fff4a/optype-0.12.0.tar.gz", hash = "sha256:d1314f486028bc8d53b8c3e6b65f493d999d983df11978272ff67c1876f8ce53", size = 98419, upload-time = "2025-07-16T16:27:29.666Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/0b/87427c7b4ea6480e18fa5f933f0407ac4ce0fd9b667568ff2c82b8d66069/optype-0.12.0-py3-none-any.whl", hash = "sha256:245163b14cb78a83f4bf862d2278a5f9aef001242ac8ce577d1891dd1e0b104f", size = 86090, upload-time = "2025-07-16T16:27:27.751Z" }, +] + [[package]] name = "orjson" version = "3.10.18"