Skip to content

Commit fc61a0a

Browse files
committed
Add support for arr.to_device()
1 parent eb4e4e6 commit fc61a0a

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

array-api-skips.txt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,9 @@
33
# JAX doesn't yet support scalar boolean indexing
44
array_api_tests/test_array_object.py::test_getitem_masking
55

6-
# JAX arrays don't have a to_device() method
7-
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
8-
96
# Hypothesis warning
107
array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
118

12-
# JAX arrays don't yet support to_device
13-
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
14-
159
# Test suite attempts in-place mutation:
1610
array_api_tests/test_special_cases.py::test_binary
1711
array_api_tests/test_special_cases.py::test_iop
@@ -38,6 +32,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[__
3832
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
3933
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)]
4034
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x, s)]
35+
array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x, s)]
4136
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)]
4237

4338
# JAX's NaN sorting doesn't match specification

jax/experimental/array_api/__init__.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
__array_api_version__ = '2022.12'
17+
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
1818

1919
from jax.experimental.array_api import linalg as linalg
2020

@@ -190,17 +190,6 @@
190190
vecdot as vecdot,
191191
)
192192

193-
def _array_namespace(self, /, *, api_version: None | str = None):
194-
import sys
195-
if api_version is not None and api_version != __array_api_version__:
196-
raise ValueError(f"{api_version=!r} is not available; "
197-
f"available versions are: {[__array_api_version__]}")
198-
return sys.modules[__name__]
199-
200-
def _setup_array_type():
201-
# TODO(jakevdp): set on tracers as well?
202-
from jax._src.array import ArrayImpl
203-
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
204-
205-
_setup_array_type()
206-
del _setup_array_type
193+
from jax.experimental.array_api import _array_methods
194+
_array_methods.add_array_object_methods()
195+
del _array_methods
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any, Callable, Optional, Union
18+
19+
import jax
20+
from jax._src.array import ArrayImpl
21+
from jax.experimental.array_api._version import __array_api_version__
22+
23+
from jax._src.lib import xla_extension as xe
24+
25+
26+
def _array_namespace(self, /, *, api_version: None | str = None):
27+
if api_version is not None and api_version != __array_api_version__:
28+
raise ValueError(f"{api_version=!r} is not available; "
29+
f"available versions are: {[__array_api_version__]}")
30+
return jax.experimental.array_api
31+
32+
33+
def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
34+
stream: Optional[Union[int, Any]] = None):
35+
if stream is not None:
36+
raise NotImplementedError("stream argument of array.to_device()")
37+
# The type of device is defined by Array.device. In JAX, this is a callable that
38+
# returns a device, so we must handle this case to satisfy the API spec.
39+
return jax.device_put(self, device() if callable(device) else device)
40+
41+
42+
def add_array_object_methods():
43+
# TODO(jakevdp): set on tracers as well?
44+
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
45+
setattr(ArrayImpl, "to_device", _to_device)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
__array_api_version__ = '2022.12'

0 commit comments

Comments
 (0)