Commit ab421e9 (1 parent: dd7650a)

Add one_hot

5 files changed: +196 −4 lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 expand_dims
 isclose
 kron
+one_hot
 nunique
 pad
 setdiff1d

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,6 @@
 """Extra array functions built on top of the array API standard."""
 
-from ._delegation import isclose, pad
+from ._delegation import isclose, one_hot, pad
 from ._lib._at import at
 from ._lib._funcs import (
     apply_where,
@@ -32,6 +32,7 @@
     "kron",
     "lazy_apply",
     "nunique",
+    "one_hot",
     "pad",
     "setdiff1d",
     "sinc",

src/array_api_extra/_delegation.py

Lines changed: 92 additions & 2 deletions

@@ -9,15 +9,18 @@
     array_namespace,
     is_cupy_namespace,
     is_dask_namespace,
+    is_jax_array,
     is_jax_namespace,
     is_numpy_namespace,
     is_pydata_sparse_namespace,
+    is_torch_array,
     is_torch_namespace,
 )
+from ._lib._utils._compat import device as get_device
 from ._lib._utils._helpers import asarrays
-from ._lib._utils._typing import Array
+from ._lib._utils._typing import Array, DType
 
-__all__ = ["isclose", "pad"]
+__all__ = ["isclose", "one_hot", "pad"]
 
 
 def isclose(
@@ -112,6 +115,93 @@ def isclose(
     return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
 
 
+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    dtype: DType | None = None,
+    axis: int = -1,
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    One-hot encode the given indices.
+
+    Each index in the input `x` is encoded as a vector of zeros of length
+    `num_classes` with the element at the given index set to one.
+
+    Parameters
+    ----------
+    x : array
+        An array with integral dtype and concrete size (``x.size`` cannot be `None`).
+    num_classes : int
+        Number of classes in the one-hot dimension.
+    dtype : DType, optional
+        The dtype of the return value. Defaults to the default float dtype (usually
+        float64).
+    axis : int, optional
+        Position in the output array where the new axis of size `num_classes` is
+        placed. Default: -1 (append as the last axis).
+    xp : array_namespace, optional
+        The standard-compatible namespace for `x`. Default: infer.
+
+    Returns
+    -------
+    array
+        An array having the same shape as `x` except for a new axis at the position
+        given by `axis` having size `num_classes`. If `axis` is unspecified, it
+        defaults to -1, which appends a new axis.
+
+        If ``x < 0`` or ``x >= num_classes``, then the result is undefined, may raise
+        an exception, or may even cause a bad state. `x` is not checked.
+
+    Examples
+    --------
+    >>> import array_api_extra as xpx
+    >>> import array_api_strict as xp
+    >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
+    Array([[0., 1., 0.],
+           [0., 0., 1.],
+           [1., 0., 0.]], dtype=array_api_strict.float64)
+    """
+    # Validate inputs.
+    if xp is None:
+        xp = array_namespace(x)
+    if not xp.isdtype(x.dtype, "integral"):
+        msg = "x must have an integral dtype."
+        raise TypeError(msg)
+    if dtype is None:
+        dtype = xp.__array_namespace_info__().default_dtypes(device=get_device(x))[
+            "real floating"
+        ]
+    # Delegate where possible.
+    if is_jax_namespace(xp):
+        assert is_jax_array(x)
+        from jax.nn import one_hot as jax_one_hot
+
+        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)
+    if is_torch_namespace(xp):
+        assert is_torch_array(x)
+        from torch.nn.functional import one_hot as torch_one_hot
+
+        x = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
+        try:
+            out = torch_one_hot(x, num_classes)
+        except RuntimeError as e:
+            raise IndexError from e
+        out = xp.astype(out, dtype)
+    else:
+        out = _funcs.one_hot(
+            x,
+            num_classes,
+            dtype=dtype,
+            xp=xp,
+        )
+
+    if axis != -1:
+        out = xp.moveaxis(out, -1, axis)
+    return out
+
+
 def pad(
     x: Array,
     pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],

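The public `one_hot` above validates the dtype, resolves the default real floating dtype for the array's device, delegates to `jax.nn.one_hot` or `torch.nn.functional.one_hot` when those namespaces are detected, and otherwise falls back to the generic `_funcs.one_hot`, moving the class axis afterwards when `axis != -1`. A minimal usage sketch of the API added by this commit, assuming NumPy as the backing namespace (any array-API-compatible namespace should behave the same way):

    import numpy as np
    import array_api_extra as xpx

    x = np.asarray([1, 2, 0])

    # Default axis=-1 appends the class axis: result has shape (3, 3)
    # with the default real floating dtype.
    print(xpx.one_hot(x, 3))

    # dtype and axis are keyword-only; axis=0 puts the class axis first,
    # equivalent to moveaxis(one_hot(x, 3), -1, 0).
    print(xpx.one_hot(x, 3, dtype=np.bool_, axis=0))
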
src/array_api_extra/_lib/_funcs.py

Lines changed: 30 additions & 1 deletion

@@ -16,7 +16,7 @@
     meta_namespace,
     ndindex,
 )
-from ._utils._typing import Array
+from ._utils._typing import Array, DType
 
 __all__ = [
     "apply_where",
@@ -375,6 +375,35 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
     return xp.squeeze(c, axis=axes)
 
 
+def one_hot(
+    x: Array,
+    /,
+    num_classes: int,
+    *,
+    dtype: DType,
+    xp: ModuleType,
+) -> Array:  # numpydoc ignore=PR01,RT01
+    """See docstring in `array_api_extra._delegation.py`."""
+    x_size = _compat.size(x)
+    if x_size is None:  # pragma: no cover
+        # This cannot be tested because there is no way to create an array with
+        # abstract size today. However, it is blocked for the sake of type-checking
+        # and future-proofing since x.size is allowed to be None according to the
+        # specification.
+        msg = "x must have a concrete size."
+        raise TypeError(msg)
+    # TODO: Benchmark whether this is faster on the numpy backend:
+    #     x_flattened = xp.reshape(x, (-1,))
+    #     out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
+    #     at(out)[xp.arange(x_size), x_flattened].set(1)
+    #     if x.ndim != 1:
+    #         out = xp.reshape(out, (*x.shape, num_classes))
+    out = x[..., None] == xp.arange(
+        num_classes, dtype=x.dtype, device=_compat.device(x)
+    )
+    return xp.astype(out, dtype)
+
+
 def create_diagonal(
     x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
 ) -> Array:

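The generic fallback avoids scatter writes entirely: it compares `x[..., None]` against `arange(num_classes)` and lets broadcasting produce the boolean one-hot array, which is then cast to the requested dtype. A standalone sketch of the same comparison trick in plain NumPy (illustration only; the function itself works through the array API namespace `xp`):

    import numpy as np

    x = np.asarray([[1, 2], [0, 1]])   # any shape of integer indices
    num_classes = 3

    # (2, 2, 1) compared against (3,) broadcasts to a boolean (2, 2, 3) array.
    onehot = x[..., None] == np.arange(num_classes, dtype=x.dtype)
    onehot = onehot.astype(np.float64)  # cast to the requested dtype

    assert onehot.shape == (*x.shape, num_classes)
    assert onehot[0, 1, 2] == 1.0       # because x[0, 1] == 2
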
tests/test_funcs.py

Lines changed: 71 additions & 0 deletions

@@ -21,6 +21,7 @@
     isclose,
     kron,
     nunique,
+    one_hot,
     pad,
     setdiff1d,
     sinc,
@@ -44,6 +45,7 @@
 lazy_xp_function(expand_dims)
 lazy_xp_function(kron)
 lazy_xp_function(nunique)
+lazy_xp_function(one_hot)
 lazy_xp_function(pad)
 # FIXME calls in1d which calls xp.unique_values without size
 lazy_xp_function(setdiff1d, jax_jit=False)
@@ -448,6 +450,75 @@ def test_xp(self, xp: ModuleType):
         )
 
 
+@pytest.mark.skip_xp_backend(
+    Backend.SPARSE, reason="read-only backend without .at support"
+)
+class TestOneHot:
+    @pytest.mark.parametrize("n_dim", range(4))
+    @pytest.mark.parametrize("num_classes", [1, 3, 10])
+    def test_dims_and_classes(self, xp: ModuleType, n_dim: int, num_classes: int):
+        shape = tuple(range(2, 2 + n_dim))
+        rng = np.random.default_rng(2347823)
+        np_x = rng.integers(num_classes, size=shape)
+        x = xp.asarray(np_x)
+        y = one_hot(x, num_classes)
+        assert y.shape == (*x.shape, num_classes)
+        for *i_list, j in ndindex(*shape, num_classes):
+            i = tuple(i_list)
+            assert float(y[(*i, j)]) == (int(x[i]) == j)
+
+    def test_basic(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([0, 1, 2]), 3)
+        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+        actual = one_hot(xp.asarray([1, 2, 0]), 3)
+        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.skip_xp_backend(
+        Backend.TORCH_GPU, reason="Puts Pytorch into a bad state."
+    )
+    def test_out_of_bound(self, xp: ModuleType):
+        # Undefined behavior. Either return zero, or raise.
+        try:
+            actual = one_hot(xp.asarray([-1, 3]), 3)
+        except IndexError:
+            return
+        expected = xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        xp_assert_equal(actual, expected)
+
+    @pytest.mark.parametrize(
+        "int_dtype",
+        ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"],
+    )
+    def test_int_types(self, xp: ModuleType, int_dtype: str):
+        dtype = getattr(xp, int_dtype)
+        x = xp.asarray([0, 1, 2], dtype=dtype)
+        actual = one_hot(x, 3)
+        expected = xp.asarray([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+        xp_assert_equal(actual, expected)
+
+    def test_custom_dtype(self, xp: ModuleType):
+        actual = one_hot(xp.asarray([0, 1, 2], dtype=xp.int32), 3, dtype=xp.bool)
+        expected = xp.asarray(
+            [[True, False, False], [False, True, False], [False, False, True]]
+        )
+        xp_assert_equal(actual, expected)
+
+    def test_axis(self, xp: ModuleType):
+        expected = xp.asarray([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]).T
+        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=0)
+        xp_assert_equal(actual, expected)
+
+        actual = one_hot(xp.asarray([1, 2, 0]), 3, axis=-2)
+        xp_assert_equal(actual, expected)
+
+    def test_non_integer(self, xp: ModuleType):
+        with pytest.raises(TypeError):
+            _ = one_hot(xp.asarray([1.0]), 3)
+
+
 @pytest.mark.skip_xp_backend(
     Backend.SPARSE, reason="read-only backend without .at support"
 )
