|
| 1 | +# Copyright 2023 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import jax |
| 16 | +from jax.experimental.array_api._data_type_functions import result_type as _result_type |
| 17 | + |
| 18 | + |
| 19 | +def broadcast_arrays(*arrays): |
| 20 | + """Broadcasts one or more arrays against one another.""" |
| 21 | + return jax.numpy.broadcast_arrays(*arrays) |
| 22 | + |
| 23 | + |
| 24 | +def broadcast_to(x, /, shape): |
| 25 | + """Broadcasts an array to a specified shape.""" |
| 26 | + return jax.numpy.broadcast_to(x, shape=shape) |
| 27 | + |
| 28 | + |
| 29 | +def concat(arrays, /, *, axis=0): |
| 30 | + """Joins a sequence of arrays along an existing axis.""" |
| 31 | + dtype = _result_type(*arrays) |
| 32 | + return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype) |
| 33 | + |
| 34 | + |
| 35 | +def expand_dims(x, /, *, axis=0): |
| 36 | + """Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis.""" |
| 37 | + return jax.lax.expand_dims(x, dimensions=[axis]) |
| 38 | + |
| 39 | + |
| 40 | +def flip(x, /, *, axis=None): |
| 41 | + """Reverses the order of elements in an array along the given axis.""" |
| 42 | + dimensions = list(axis) if isinstance(axis, tuple) else [axis] |
| 43 | + return jax.lax.rev(x, dimensions=dimensions) |
| 44 | + |
| 45 | + |
| 46 | +def permute_dims(x, /, axes): |
| 47 | + """Permutes the axes (dimensions) of an array x.""" |
| 48 | + return jax.lax.transpose(x, axes) |
| 49 | + |
| 50 | + |
| 51 | +def reshape(x, /, shape, *, copy=None): |
| 52 | + """Reshapes an array without changing its data.""" |
| 53 | + del copy # unused |
| 54 | + return jax.lax.reshape(x, shape) |
| 55 | + |
| 56 | + |
| 57 | +def roll(x, /, shift, *, axis=None): |
| 58 | + """Rolls array elements along a specified axis.""" |
| 59 | + return jax.numpy.roll(x, shift=shift, axis=axis) |
| 60 | + |
| 61 | + |
| 62 | +def squeeze(x, /, axis): |
| 63 | + """Removes singleton dimensions (axes) from x.""" |
| 64 | + dimensions = list(axis) if isinstance(axis, tuple) else [axis] |
| 65 | + return jax.lax.squeeze(x, dimensions=dimensions) |
| 66 | + |
| 67 | + |
| 68 | +def stack(arrays, /, *, axis=0): |
| 69 | + """Joins a sequence of arrays along a new axis.""" |
| 70 | + dtype = _result_type(*arrays) |
| 71 | + return jax.numpy.stack(arrays, axis=axis, dtype=dtype) |
0 commit comments