Skip to content

Commit e30f3ed

Browse files
Terji PetersenTerji Petersen
Terji Petersen
authored and
Terji Petersen
committed
fail on float16, but allow np.exp(int8_arrays)
1 parent 378589d commit e30f3ed

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

pandas/core/indexes/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,9 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, **kwargs):
882882
# i.e. np.divmod, np.modf, np.frexp
883883
return tuple(self.__array_wrap__(x) for x in result)
884884

885+
if result.dtype == np.float16:
886+
result = result.astype(np.float32)
887+
885888
return self.__array_wrap__(result)
886889

887890
def __array_wrap__(self, result, context=None):

pandas/core/indexes/numeric.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ def _ensure_array(cls, data, dtype, copy: bool):
134134
Ensure we have a valid array to pass to _simple_new.
135135
"""
136136
cls._validate_dtype(dtype)
137+
if dtype == np.float16:
138+
139+
# float16 not supported (no indexing engine)
140+
raise NotImplementedError("float16 indexes are not supported")
137141

138142
if not isinstance(data, (np.ndarray, Index)):
139143
# Coerce to ndarray if not already ndarray or Index

pandas/tests/indexes/numeric/test_numeric.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -474,11 +474,39 @@ def test_coerce_list(self):
474474
class TestFloat16Index:
475475
# float 16 indexes not supported
476476
# GH 49535
477-
def test_array(self):
478-
arr = np.array([1, 2, 3], dtype=np.float16)
479-
msg = "float16 indexes are not implemented"
477+
_index_cls = NumericIndex
478+
479+
def test_constructor(self):
480+
index_cls = self._index_cls
481+
dtype = np.float16
482+
483+
msg = "float16 indexes are not supported"
484+
485+
# explicit construction
486+
with pytest.raises(NotImplementedError, match=msg):
487+
index_cls([1, 2, 3, 4, 5], dtype=dtype)
488+
489+
with pytest.raises(NotImplementedError, match=msg):
490+
index_cls(np.array([1, 2, 3, 4, 5]), dtype=dtype)
491+
492+
with pytest.raises(NotImplementedError, match=msg):
493+
index_cls([1.0, 2, 3, 4, 5], dtype=dtype)
494+
495+
with pytest.raises(NotImplementedError, match=msg):
496+
index_cls(np.array([1.0, 2, 3, 4, 5]), dtype=dtype)
497+
498+
with pytest.raises(NotImplementedError, match=msg):
499+
index_cls([1.0, 2, 3, 4, 5], dtype=dtype)
500+
501+
with pytest.raises(NotImplementedError, match=msg):
502+
index_cls(np.array([1.0, 2, 3, 4, 5]), dtype=dtype)
503+
504+
# nan handling
505+
with pytest.raises(NotImplementedError, match=msg):
506+
index_cls([np.nan, np.nan], dtype=dtype)
507+
480508
with pytest.raises(NotImplementedError, match=msg):
481-
NumericIndex(arr)
509+
index_cls(np.array([np.nan]), dtype=dtype)
482510

483511

484512
class TestUIntNumericIndex(NumericInt):

pandas/tests/indexes/test_numpy_compat.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def test_numpy_ufuncs_basic(index, func):
7777
# coerces to float (e.g. np.sin)
7878
with np.errstate(all="ignore"):
7979
result = func(index)
80-
exp = Index(func(index.values), name=index.name)
80+
arr_result = func(index.values)
81+
if arr_result.dtype == np.float16:
82+
arr_result = arr_result.astype(np.float32)
83+
exp = Index(arr_result, name=index.name)
8184

8285
tm.assert_index_equal(result, exp)
8386
if type(index) is not Index or index.dtype == bool:

0 commit comments

Comments
 (0)