Skip to content

Commit cf2dfa7

Browse files
authored
REF: Share NumericArray/NumericDtype methods (#45997)
1 parent bf97db3 commit cf2dfa7

File tree

7 files changed

+89
-111
lines changed

pandas/core/arrays/floating.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import numpy as np
44

55
from pandas._typing import DtypeObj
6-
from pandas.util._decorators import cache_readonly
76

7+
from pandas.core.dtypes.common import is_float_dtype
88
from pandas.core.dtypes.dtypes import register_extension_dtype
99

1010
from pandas.core.arrays.numeric import (
@@ -24,13 +24,7 @@ class FloatingDtype(NumericDtype):
2424
"""
2525

2626
_default_np_dtype = np.dtype(np.float64)
27-
28-
def __repr__(self) -> str:
29-
return f"{self.name}Dtype()"
30-
31-
@property
32-
def _is_numeric(self) -> bool:
33-
return True
27+
_checker = is_float_dtype
3428

3529
@classmethod
3630
def construct_array_type(cls) -> type[FloatingArray]:
@@ -58,18 +52,8 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
5852
return None
5953

6054
@classmethod
61-
def _standardize_dtype(cls, dtype) -> FloatingDtype:
62-
if isinstance(dtype, str) and dtype.startswith("Float"):
63-
# Avoid DeprecationWarning from NumPy about np.dtype("Float64")
64-
# https://github.com/numpy/numpy/pull/7476
65-
dtype = dtype.lower()
66-
67-
if not issubclass(type(dtype), FloatingDtype):
68-
try:
69-
dtype = FLOAT_STR_TO_DTYPE[str(np.dtype(dtype))]
70-
except KeyError as err:
71-
raise ValueError(f"invalid dtype specified {dtype}") from err
72-
return dtype
55+
def _str_to_dtype_mapping(cls):
56+
return FLOAT_STR_TO_DTYPE
7357

7458
@classmethod
7559
def _safe_cast(cls, values: np.ndarray, dtype: np.dtype, copy: bool) -> np.ndarray:
@@ -151,22 +135,6 @@ class FloatingArray(NumericArray):
151135
_truthy_value = 1.0
152136
_falsey_value = 0.0
153137

154-
@cache_readonly
155-
def dtype(self) -> FloatingDtype:
156-
return FLOAT_STR_TO_DTYPE[str(self._data.dtype)]
157-
158-
def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
159-
if not (isinstance(values, np.ndarray) and values.dtype.kind == "f"):
160-
raise TypeError(
161-
"values should be floating numpy array. Use "
162-
"the 'pd.array' function instead"
163-
)
164-
if values.dtype == np.float16:
165-
# If we don't raise here, then accessing self.dtype would raise
166-
raise TypeError("FloatingArray does not support np.float16 dtype.")
167-
168-
super().__init__(values, mask, copy=copy)
169-
170138

171139
_dtype_docstring = """
172140
An ExtensionDtype for {dtype} data.

pandas/core/arrays/integer.py

Lines changed: 16 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import numpy as np
44

55
from pandas._typing import DtypeObj
6-
from pandas.util._decorators import cache_readonly
76

87
from pandas.core.dtypes.base import register_extension_dtype
8+
from pandas.core.dtypes.common import is_integer_dtype
99

1010
from pandas.core.arrays.masked import BaseMaskedDtype
1111
from pandas.core.arrays.numeric import (
@@ -14,33 +14,18 @@
1414
)
1515

1616

17-
class _IntegerDtype(NumericDtype):
17+
class IntegerDtype(NumericDtype):
1818
"""
1919
An ExtensionDtype to hold a single size & kind of integer dtype.
2020
2121
These specific implementations are subclasses of the non-public
22-
_IntegerDtype. For example we have Int8Dtype to represent signed int 8s.
22+
IntegerDtype. For example we have Int8Dtype to represent signed int 8s.
2323
2424
The attributes name & type are set when these subclasses are created.
2525
"""
2626

2727
_default_np_dtype = np.dtype(np.int64)
28-
29-
def __repr__(self) -> str:
30-
sign = "U" if self.is_unsigned_integer else ""
31-
return f"{sign}Int{8 * self.itemsize}Dtype()"
32-
33-
@cache_readonly
34-
def is_signed_integer(self) -> bool:
35-
return self.kind == "i"
36-
37-
@cache_readonly
38-
def is_unsigned_integer(self) -> bool:
39-
return self.kind == "u"
40-
41-
@property
42-
def _is_numeric(self) -> bool:
43-
return True
28+
_checker = is_integer_dtype
4429

4530
@classmethod
4631
def construct_array_type(cls) -> type[IntegerArray]:
@@ -86,20 +71,8 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
8671
return None
8772

8873
@classmethod
89-
def _standardize_dtype(cls, dtype) -> _IntegerDtype:
90-
if isinstance(dtype, str) and (
91-
dtype.startswith("Int") or dtype.startswith("UInt")
92-
):
93-
# Avoid DeprecationWarning from NumPy about np.dtype("Int64")
94-
# https://github.com/numpy/numpy/pull/7476
95-
dtype = dtype.lower()
96-
97-
if not issubclass(type(dtype), _IntegerDtype):
98-
try:
99-
dtype = INT_STR_TO_DTYPE[str(np.dtype(dtype))]
100-
except KeyError as err:
101-
raise ValueError(f"invalid dtype specified {dtype}") from err
102-
return dtype
74+
def _str_to_dtype_mapping(cls):
75+
return INT_STR_TO_DTYPE
10376

10477
@classmethod
10578
def _safe_cast(cls, values: np.ndarray, dtype: np.dtype, copy: bool) -> np.ndarray:
@@ -189,26 +162,14 @@ class IntegerArray(NumericArray):
189162
Length: 3, dtype: UInt16
190163
"""
191164

192-
_dtype_cls = _IntegerDtype
165+
_dtype_cls = IntegerDtype
193166

194167
# The value used to fill '_data' to avoid upcasting
195168
_internal_fill_value = 1
196169
# Fill values used for any/all
197170
_truthy_value = 1
198171
_falsey_value = 0
199172

200-
@cache_readonly
201-
def dtype(self) -> _IntegerDtype:
202-
return INT_STR_TO_DTYPE[str(self._data.dtype)]
203-
204-
def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
205-
if not (isinstance(values, np.ndarray) and values.dtype.kind in ["i", "u"]):
206-
raise TypeError(
207-
"values should be integer numpy array. Use "
208-
"the 'pd.array' function instead"
209-
)
210-
super().__init__(values, mask, copy=copy)
211-
212173

213174
_dtype_docstring = """
214175
An ExtensionDtype for {dtype} integer data.
@@ -231,62 +192,62 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
231192

232193

233194
@register_extension_dtype
234-
class Int8Dtype(_IntegerDtype):
195+
class Int8Dtype(IntegerDtype):
235196
type = np.int8
236197
name = "Int8"
237198
__doc__ = _dtype_docstring.format(dtype="int8")
238199

239200

240201
@register_extension_dtype
241-
class Int16Dtype(_IntegerDtype):
202+
class Int16Dtype(IntegerDtype):
242203
type = np.int16
243204
name = "Int16"
244205
__doc__ = _dtype_docstring.format(dtype="int16")
245206

246207

247208
@register_extension_dtype
248-
class Int32Dtype(_IntegerDtype):
209+
class Int32Dtype(IntegerDtype):
249210
type = np.int32
250211
name = "Int32"
251212
__doc__ = _dtype_docstring.format(dtype="int32")
252213

253214

254215
@register_extension_dtype
255-
class Int64Dtype(_IntegerDtype):
216+
class Int64Dtype(IntegerDtype):
256217
type = np.int64
257218
name = "Int64"
258219
__doc__ = _dtype_docstring.format(dtype="int64")
259220

260221

261222
@register_extension_dtype
262-
class UInt8Dtype(_IntegerDtype):
223+
class UInt8Dtype(IntegerDtype):
263224
type = np.uint8
264225
name = "UInt8"
265226
__doc__ = _dtype_docstring.format(dtype="uint8")
266227

267228

268229
@register_extension_dtype
269-
class UInt16Dtype(_IntegerDtype):
230+
class UInt16Dtype(IntegerDtype):
270231
type = np.uint16
271232
name = "UInt16"
272233
__doc__ = _dtype_docstring.format(dtype="uint16")
273234

274235

275236
@register_extension_dtype
276-
class UInt32Dtype(_IntegerDtype):
237+
class UInt32Dtype(IntegerDtype):
277238
type = np.uint32
278239
name = "UInt32"
279240
__doc__ = _dtype_docstring.format(dtype="uint32")
280241

281242

282243
@register_extension_dtype
283-
class UInt64Dtype(_IntegerDtype):
244+
class UInt64Dtype(IntegerDtype):
284245
type = np.uint64
285246
name = "UInt64"
286247
__doc__ = _dtype_docstring.format(dtype="uint64")
287248

288249

289-
INT_STR_TO_DTYPE: dict[str, _IntegerDtype] = {
250+
INT_STR_TO_DTYPE: dict[str, IntegerDtype] = {
290251
"int8": Int8Dtype(),
291252
"int16": Int16Dtype(),
292253
"int32": Int32Dtype(),

pandas/core/arrays/numeric.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import numbers
44
from typing import (
55
TYPE_CHECKING,
6+
Any,
7+
Callable,
68
TypeVar,
79
)
810

@@ -17,6 +19,7 @@
1719
DtypeObj,
1820
)
1921
from pandas.errors import AbstractMethodError
22+
from pandas.util._decorators import cache_readonly
2023

2124
from pandas.core.dtypes.common import (
2225
is_bool_dtype,
@@ -41,6 +44,22 @@
4144

4245
class NumericDtype(BaseMaskedDtype):
4346
_default_np_dtype: np.dtype
47+
_checker: Callable[[Any], bool] # is_foo_dtype
48+
49+
def __repr__(self) -> str:
50+
return f"{self.name}Dtype()"
51+
52+
@cache_readonly
53+
def is_signed_integer(self) -> bool:
54+
return self.kind == "i"
55+
56+
@cache_readonly
57+
def is_unsigned_integer(self) -> bool:
58+
return self.kind == "u"
59+
60+
@property
61+
def _is_numeric(self) -> bool:
62+
return True
4463

4564
def __from_arrow__(
4665
self, array: pyarrow.Array | pyarrow.ChunkedArray
@@ -90,12 +109,27 @@ def __from_arrow__(
90109
else:
91110
return array_class._concat_same_type(results)
92111

112+
@classmethod
113+
def _str_to_dtype_mapping(cls):
114+
raise AbstractMethodError(cls)
115+
93116
@classmethod
94117
def _standardize_dtype(cls, dtype) -> NumericDtype:
95118
"""
96119
Convert a string representation or a numpy dtype to NumericDtype.
97120
"""
98-
raise AbstractMethodError(cls)
121+
if isinstance(dtype, str) and (dtype.startswith(("Int", "UInt", "Float"))):
122+
# Avoid DeprecationWarning from NumPy about np.dtype("Int64")
123+
# https://github.com/numpy/numpy/pull/7476
124+
dtype = dtype.lower()
125+
126+
if not issubclass(type(dtype), cls):
127+
mapping = cls._str_to_dtype_mapping()
128+
try:
129+
dtype = mapping[str(np.dtype(dtype))]
130+
except KeyError as err:
131+
raise ValueError(f"invalid dtype specified {dtype}") from err
132+
return dtype
99133

100134
@classmethod
101135
def _safe_cast(cls, values: np.ndarray, dtype: np.dtype, copy: bool) -> np.ndarray:
@@ -108,10 +142,7 @@ def _safe_cast(cls, values: np.ndarray, dtype: np.dtype, copy: bool) -> np.ndarr
108142

109143

110144
def _coerce_to_data_and_mask(values, mask, dtype, copy, dtype_cls, default_dtype):
111-
if default_dtype.kind == "f":
112-
checker = is_float_dtype
113-
else:
114-
checker = is_integer_dtype
145+
checker = dtype_cls._checker
115146

116147
inferred_type = None
117148

@@ -188,6 +219,29 @@ class NumericArray(BaseMaskedArray):
188219

189220
_dtype_cls: type[NumericDtype]
190221

222+
def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
223+
checker = self._dtype_cls._checker
224+
if not (isinstance(values, np.ndarray) and checker(values.dtype)):
225+
descr = (
226+
"floating"
227+
if self._dtype_cls.kind == "f" # type: ignore[comparison-overlap]
228+
else "integer"
229+
)
230+
raise TypeError(
231+
f"values should be {descr} numpy array. Use "
232+
"the 'pd.array' function instead"
233+
)
234+
if values.dtype == np.float16:
235+
# If we don't raise here, then accessing self.dtype would raise
236+
raise TypeError("FloatingArray does not support np.float16 dtype.")
237+
238+
super().__init__(values, mask, copy=copy)
239+
240+
@cache_readonly
241+
def dtype(self) -> NumericDtype:
242+
mapping = self._dtype_cls._str_to_dtype_mapping()
243+
return mapping[str(self._data.dtype)]
244+
191245
@classmethod
192246
def _coerce_to_array(
193247
cls, value, *, dtype: DtypeObj, copy: bool = False

pandas/core/arrays/string_.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from pandas.core.arrays.base import ExtensionArray
4747
from pandas.core.arrays.floating import FloatingDtype
48-
from pandas.core.arrays.integer import _IntegerDtype
48+
from pandas.core.arrays.integer import IntegerDtype
4949
from pandas.core.construction import extract_array
5050
from pandas.core.indexers import check_array_indexer
5151
from pandas.core.missing import isna
@@ -432,7 +432,7 @@ def astype(self, dtype, copy: bool = True):
432432
return self.copy()
433433
return self
434434

435-
elif isinstance(dtype, _IntegerDtype):
435+
elif isinstance(dtype, IntegerDtype):
436436
arr = self._ndarray.copy()
437437
mask = self.isna()
438438
arr[mask] = 0

0 commit comments

Comments
 (0)