diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py index 284dd31ffcb59..a947ab64f7380 100644 --- a/pandas/core/arrays/_mixins.py +++ b/pandas/core/arrays/_mixins.py @@ -2,6 +2,7 @@ import numpy as np +from pandas._libs import lib from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly, doc @@ -30,6 +31,12 @@ def _from_backing_data(self: _T, arr: np.ndarray) -> _T: """ raise AbstractMethodError(self) + def _box_func(self, x): + """ + Wrap numpy type in our dtype.type if necessary. + """ + return x + # ------------------------------------------------------------------------ def take( @@ -168,3 +175,22 @@ def _validate_setitem_key(self, key): def _validate_setitem_value(self, value): return value + + def __getitem__(self, key): + if lib.is_integer(key): + # fast-path + result = self._ndarray[key] + if self.ndim == 1: + return self._box_func(result) + return self._from_backing_data(result) + + key = self._validate_getitem_key(key) + result = self._ndarray[key] + if lib.is_scalar(result): + return self._box_func(result) + + result = self._from_backing_data(result) + return result + + def _validate_getitem_key(self, key): + return check_array_indexer(self, key) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 25073282ec0f6..347232d30bda8 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -1882,17 +1882,11 @@ def __getitem__(self, key): """ Return an item. """ - if isinstance(key, (int, np.integer)): - i = self._codes[key] - return self._box_func(i) - - key = check_array_indexer(self, key) - - result = self._codes[key] - if result.ndim > 1: + result = super().__getitem__(key) + if getattr(result, "ndim", 0) > 1: + result = result._ndarray deprecate_ndim_indexing(result) - return result - return self._from_backing_data(result) + return result def _validate_setitem_value(self, value): value = extract_array(value, extract_numpy=True) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index e8b1c12687584..f42d3bee6bbea 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -539,23 +539,11 @@ def __getitem__(self, key): This getitem defers to the underlying array, which by-definition can only handle list-likes, slices, and integer scalars """ - - if lib.is_integer(key): - # fast-path - result = self._ndarray[key] - if self.ndim == 1: - return self._box_func(result) - return self._from_backing_data(result) - - key = self._validate_getitem_key(key) - result = self._ndarray[key] + result = super().__getitem__(key) if lib.is_scalar(result): - return self._box_func(result) - - result = self._from_backing_data(result) + return result - freq = self._get_getitem_freq(key) - result._freq = freq + result._freq = self._get_getitem_freq(key) return result def _validate_getitem_key(self, key): @@ -572,7 +560,7 @@ def _validate_getitem_key(self, key): # this for now (would otherwise raise in check_array_indexer) pass else: - key = check_array_indexer(self, key) + key = super()._validate_getitem_key(key) return key def _get_getitem_freq(self, key): @@ -582,7 +570,10 @@ def _get_getitem_freq(self, key): is_period = is_period_dtype(self.dtype) if is_period: freq = self.freq + elif self.ndim != 1: + freq = None else: + key = self._validate_getitem_key(key) # maybe ndarray[bool] -> slice freq = None if isinstance(key, slice): if self.freq is not None and key.step is not None: diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index 61ffa28d31ba0..afcae2c5c8b43 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -19,7 +19,6 @@ from pandas.core.arrays._mixins import NDArrayBackedExtensionArray from pandas.core.arrays.base import ExtensionOpsMixin from pandas.core.construction import extract_array -from pandas.core.indexers import check_array_indexer from pandas.core.missing import backfill_1d, pad_1d @@ -248,16 +247,11 @@ def __array_ufunc__(self, ufunc, method: str, *inputs, **kwargs): # ------------------------------------------------------------------------ # Pandas ExtensionArray Interface - def __getitem__(self, item): - if isinstance(item, type(self)): - item = item._ndarray + def _validate_getitem_key(self, key): + if isinstance(key, type(self)): + key = key._ndarray - item = check_array_indexer(self, item) - - result = self._ndarray[item] - if not lib.is_scalar(item): - result = type(self)(result) - return result + return super()._validate_getitem_key(key) def _validate_setitem_value(self, value): value = extract_array(value, extract_numpy=True)