diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt
index d7feb6e547b22..a0f12179970c4 100644
--- a/doc/source/whatsnew/v0.24.0.txt
+++ b/doc/source/whatsnew/v0.24.0.txt
@@ -445,6 +445,7 @@ ExtensionType Changes
 - Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
 - The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)
 - Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
+- :meth:`~Series.shift` now dispatches to :meth:`ExtensionArray.shift` (:issue:`22386`)
 - :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
 - :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
 - :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index cb82625e818a1..7bf13fb2fecc0 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -59,6 +59,10 @@ class ExtensionArray(object):
     * factorize / _values_for_factorize
     * argsort / _values_for_argsort
 
+    The remaining methods implemented on this class should be performant,
+    as they only compose abstract methods. Still, a more efficient
+    implementation may be available, and these methods can be overridden.
+
     This class does not inherit from 'abc.ABCMeta' for performance reasons.
     Methods and properties required by the interface raise
     ``pandas.errors.AbstractMethodError`` and no ``register`` method is
@@ -400,6 +404,46 @@ def dropna(self):
 
         return self[~self.isna()]
 
+    def shift(self, periods=1):
+        # type: (int) -> ExtensionArray
+        """
+        Shift values by desired number.
+
+        Newly introduced missing values are filled with
+        ``self.dtype.na_value``.
+
+        The result always has the same length as ``self``: shifting by
+        ``len(self)`` or more positions yields an all-missing array.
+
+        .. versionadded:: 0.24.0
+
+        Parameters
+        ----------
+        periods : int, default 1
+            The number of periods to shift. Negative values are allowed
+            for shifting backwards.
+
+        Returns
+        -------
+        shifted : ExtensionArray
+        """
+        # Note: this implementation assumes that `self.dtype.na_value` can be
+        # stored in an instance of your ExtensionArray with `self.dtype`.
+        if periods == 0:
+            return self.copy()
+        # Never create more fill values than there are elements; otherwise
+        # shifting by more than len(self) would return a longer array.
+        n_fill = min(abs(periods), len(self))
+        empty = self._from_sequence([self.dtype.na_value] * n_fill,
+                                    dtype=self.dtype)
+        if periods > 0:
+            a = empty
+            b = self[:-periods]
+        else:
+            a = self[abs(periods):]
+            b = empty
+        return self._concat_same_type([a, b])
+
     def unique(self):
         """Compute the ExtensionArray of unique values.
 
diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 57d09ff33d8b4..e735b35653cd4 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -2068,6 +2068,18 @@ def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
                                     limit=limit),
             placement=self.mgr_locs)
 
+    def shift(self, periods, axis=0, mgr=None):
+        """
+        Shift the block by `periods`.
+
+        Dispatches to underlying ExtensionArray and re-boxes in an
+        ExtensionBlock.
+        """
+        # type: (int, Optional[BlockPlacement]) -> List[ExtensionBlock]
+        return [self.make_block_same_class(self.values.shift(periods=periods),
+                                           placement=self.mgr_locs,
+                                           ndim=self.ndim)]
+
 
 class NumericBlock(Block):
     __slots__ = ()
@@ -2691,10 +2703,6 @@ def _try_coerce_result(self, result):
 
         return result
 
-    def shift(self, periods, axis=0, mgr=None):
-        return self.make_block_same_class(values=self.values.shift(periods),
-                                          placement=self.mgr_locs)
-
     def to_dense(self):
         # Categorical.get_values returns a DatetimeIndex for datetime
         # categories, so we can't simply use `np.asarray(self.values)` like
diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py
index c660687f16590..c8656808739c4 100644
--- a/pandas/tests/extension/base/methods.py
+++ b/pandas/tests/extension/base/methods.py
@@ -138,3 +138,28 @@ def test_combine_add(self, data_repeated):
         expected = pd.Series(
             orig_data1._from_sequence([a + val for a in list(orig_data1)]))
         self.assert_series_equal(result, expected)
+
+    @pytest.mark.parametrize('frame', [True, False])
+    @pytest.mark.parametrize('periods, indices', [
+        (-2, [2, 3, 4, -1, -1]),
+        (0, [0, 1, 2, 3, 4]),
+        (2, [-1, -1, 0, 1, 2]),
+    ])
+    def test_container_shift(self, data, frame, periods, indices):
+        # https://github.com/pandas-dev/pandas/issues/22386
+        subset = data[:5]
+        data = pd.Series(subset, name='A')
+        expected = pd.Series(subset.take(indices, allow_fill=True), name='A')
+
+        if frame:
+            result = data.to_frame(name='A').assign(B=1).shift(periods)
+            expected = pd.concat([
+                expected,
+                pd.Series([1] * 5, name='B').shift(periods)
+            ], axis=1)
+            compare = self.assert_frame_equal
+        else:
+            result = data.shift(periods)
+            compare = self.assert_series_equal
+
+        compare(result, expected)