Skip to content

Commit 596164b

Browse files
committed
Add initial array_api interface
1 parent 1831b3c commit 596164b

File tree

6 files changed

+828
-0
lines changed

6 files changed

+828
-0
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
__array_api_version__ = '2022.12'
16+
17+
from jax.experimental.array_api._constants import (
18+
e as e,
19+
inf as inf,
20+
nan as nan,
21+
newaxis as newaxis,
22+
pi as pi,
23+
)
24+
25+
from jax.experimental.array_api._creation_functions import (
26+
arange as arange,
27+
asarray as asarray,
28+
empty as empty,
29+
empty_like as empty_like,
30+
eye as eye,
31+
from_dlpack as from_dlpack,
32+
full as full,
33+
full_like as full_like,
34+
linspace as linspace,
35+
meshgrid as meshgrid,
36+
ones as ones,
37+
ones_like as ones_like,
38+
tril as tril,
39+
triu as triu,
40+
zeros as zeros,
41+
zeros_like as zeros_like,
42+
)
43+
44+
from jax.experimental.array_api._data_type_functions import (
45+
astype as astype,
46+
can_cast as can_cast,
47+
finfo as finfo,
48+
iinfo as iinfo,
49+
isdtype as isdtype,
50+
result_type as result_type,
51+
)
52+
53+
from jax.experimental.array_api._dtypes import (
54+
bool as bool,
55+
int8 as int8,
56+
int16 as int16,
57+
int32 as int32,
58+
int64 as int64,
59+
uint8 as uint8,
60+
uint16 as uint16,
61+
uint32 as uint32,
62+
uint64 as uint64,
63+
float32 as float32,
64+
float64 as float64,
65+
complex64 as complex64,
66+
complex128 as complex128,
67+
)
68+
69+
from jax.experimental.array_api._elementwise_functions import (
70+
abs as abs,
71+
acos as acos,
72+
acosh as acosh,
73+
add as add,
74+
asin as asin,
75+
asinh as asinh,
76+
atan as atan,
77+
atan2 as atan2,
78+
atanh as atanh,
79+
bitwise_and as bitwise_and,
80+
bitwise_invert as bitwise_invert,
81+
bitwise_left_shift as bitwise_left_shift,
82+
bitwise_or as bitwise_or,
83+
bitwise_right_shift as bitwise_right_shift,
84+
bitwise_xor as bitwise_xor,
85+
ceil as ceil,
86+
conj as conj,
87+
cos as cos,
88+
cosh as cosh,
89+
divide as divide,
90+
equal as equal,
91+
exp as exp,
92+
expm1 as expm1,
93+
floor as floor,
94+
floor_divide as floor_divide,
95+
greater as greater,
96+
greater_equal as greater_equal,
97+
imag as imag,
98+
isfinite as isfinite,
99+
isinf as isinf,
100+
isnan as isnan,
101+
jax as jax,
102+
less as less,
103+
less_equal as less_equal,
104+
log as log,
105+
log10 as log10,
106+
log1p as log1p,
107+
log2 as log2,
108+
logaddexp as logaddexp,
109+
logical_and as logical_and,
110+
logical_not as logical_not,
111+
logical_or as logical_or,
112+
logical_xor as logical_xor,
113+
multiply as multiply,
114+
negative as negative,
115+
not_equal as not_equal,
116+
np as np,
117+
positive as positive,
118+
pow as pow,
119+
real as real,
120+
remainder as remainder,
121+
round as round,
122+
sign as sign,
123+
sin as sin,
124+
sinh as sinh,
125+
sqrt as sqrt,
126+
square as square,
127+
subtract as subtract,
128+
tan as tan,
129+
tanh as tanh,
130+
trunc as trunc,
131+
)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import numpy as np
2+
3+
e = np.e
4+
inf = np.inf
5+
nan = np.nan
6+
newaxis = np.newaxis
7+
pi = np.pi
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
import jax.numpy as jnp
17+
18+
19+
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
20+
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
21+
22+
def asarray(obj, /, *, dtype=None, device=None, copy=None):
23+
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
24+
25+
def empty(shape, *, dtype=None, device=None):
26+
return jax.device_put(jnp.empty(shape, dtype=dtype), device=device)
27+
28+
def empty_like(x, /, *, dtype=None, device=None):
29+
return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device)
30+
31+
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
32+
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
33+
34+
def from_dlpack(x, /):
35+
return jnp.from_dlpack(x)
36+
37+
def full(shape, fill_value, *, dtype=None, device=None):
38+
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
39+
40+
def full_like(x, /, fill_value, *, dtype=None, device=None):
41+
return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device)
42+
43+
def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
44+
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)
45+
46+
def meshgrid(*arrays, indexing='xy'):
47+
return jnp.meshgrid(*arrays, indexing=indexing)
48+
49+
def ones(shape, *, dtype=None, device=None):
50+
return jax.device_put(jnp.ones(shape, dtype=dtype), device=device)
51+
52+
def ones_like(x, /, *, dtype=None, device=None):
53+
return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device)
54+
55+
def tril(x, /, *, k=0):
56+
return jnp.tril(x, k=k)
57+
58+
def triu(x, /, *, k=0):
59+
return jnp.triu(x, k=k)
60+
61+
def zeros(shape, *, dtype=None, device=None):
62+
return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device)
63+
64+
def zeros_like(x, /, *, dtype=None, device=None):
65+
return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device)

0 commit comments

Comments
 (0)