Skip to content

Commit b35c628

Browse files
committed
more functions
1 parent 37f9901 commit b35c628

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

jax/experimental/array_api/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,18 @@
127127
tanh as tanh,
128128
trunc as trunc,
129129
)
130+
131+
from jax.experimental.array_api._indexing_functions import take as take
132+
133+
from jax.experimental.array_api._manipulation_functions import (
134+
broadcast_arrays as broadcast_arrays,
135+
broadcast_to as broadcast_to,
136+
concat as concat,
137+
expand_dims as expand_dims,
138+
flip as flip,
139+
permute_dims as permute_dims,
140+
reshape as reshape,
141+
roll as roll,
142+
squeeze as squeeze,
143+
stack as stack,
144+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
17+
def take(x, indices, /, *, axis):
18+
return jax.numpy.take(x, indices, axis=axis)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
from jax.experimental.array_api._data_type_functions import result_type as _result_type
17+
18+
19+
def broadcast_arrays(*arrays):
20+
"""Broadcasts one or more arrays against one another."""
21+
return jax.numpy.broadcast_arrays(*arrays)
22+
23+
24+
def broadcast_to(x, /, shape):
25+
"""Broadcasts an array to a specified shape."""
26+
return jax.numpy.broadcast_to(x, shape=shape)
27+
28+
29+
def concat(arrays, /, *, axis=0):
30+
"""Joins a sequence of arrays along an existing axis."""
31+
dtype = _result_type(*arrays)
32+
return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype)
33+
34+
35+
def expand_dims(x, /, *, axis=0):
36+
"""Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis."""
37+
return jax.lax.expand_dims(x, dimensions=[axis])
38+
39+
40+
def flip(x, /, *, axis=None):
41+
"""Reverses the order of elements in an array along the given axis."""
42+
dimensions = list(axis) if isinstance(axis, tuple) else [axis]
43+
return jax.lax.rev(x, dimensions=dimensions)
44+
45+
46+
def permute_dims(x, /, axes):
47+
"""Permutes the axes (dimensions) of an array x."""
48+
return jax.lax.transpose(x, axes)
49+
50+
51+
def reshape(x, /, shape, *, copy=None):
52+
"""Reshapes an array without changing its data."""
53+
del copy # unused
54+
return jax.lax.reshape(x, shape)
55+
56+
57+
def roll(x, /, shift, *, axis=None):
58+
"""Rolls array elements along a specified axis."""
59+
return jax.numpy.roll(x, shift=shift, axis=axis)
60+
61+
62+
def squeeze(x, /, axis):
63+
"""Removes singleton dimensions (axes) from x."""
64+
dimensions = list(axis) if isinstance(axis, tuple) else [axis]
65+
return jax.lax.squeeze(x, dimensions=dimensions)
66+
67+
68+
def stack(arrays, /, *, axis=0):
69+
"""Joins a sequence of arrays along a new axis."""
70+
dtype = _result_type(*arrays)
71+
return jax.numpy.stack(arrays, axis=axis, dtype=dtype)

0 commit comments

Comments
 (0)