diff --git a/pandas/core/frame.py b/pandas/core/frame.py index a1582a57e9a71..396108bab47b7 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -4572,20 +4572,23 @@ def shift( if axis == 1 and periods != 0 and fill_value is lib.no_default and ncols > 0: # We will infer fill_value to match the closest column + # Use a column that we know is valid for our column's dtype GH#38434 + label = self.columns[0] + if periods > 0: result = self.iloc[:, :-periods] for col in range(min(ncols, abs(periods))): # TODO(EA2D): doing this in a loop unnecessary with 2D EAs # Define filler inside loop so we get a copy filler = self.iloc[:, 0].shift(len(self)) - result.insert(0, col, filler, allow_duplicates=True) + result.insert(0, label, filler, allow_duplicates=True) else: result = self.iloc[:, -periods:] for col in range(min(ncols, abs(periods))): # Define filler inside loop so we get a copy filler = self.iloc[:, -1].shift(len(self)) result.insert( - len(result.columns), col, filler, allow_duplicates=True + len(result.columns), label, filler, allow_duplicates=True ) result.columns = self.columns.copy() diff --git a/pandas/tests/frame/methods/test_shift.py b/pandas/tests/frame/methods/test_shift.py index 2e21ce8ec2256..40b3f1e89c015 100644 --- a/pandas/tests/frame/methods/test_shift.py +++ b/pandas/tests/frame/methods/test_shift.py @@ -2,7 +2,7 @@ import pytest import pandas as pd -from pandas import DataFrame, Index, Series, date_range, offsets +from pandas import CategoricalIndex, DataFrame, Index, Series, date_range, offsets import pandas._testing as tm @@ -292,3 +292,25 @@ def test_shift_dt64values_int_fill_deprecated(self): expected = DataFrame({"A": [pd.Timestamp(0), pd.Timestamp(0)], "B": df2["A"]}) tm.assert_frame_equal(result, expected) + + def test_shift_axis1_categorical_columns(self): + # GH#38434 + ci = CategoricalIndex(["a", "b", "c"]) + df = DataFrame( + {"a": [1, 3], "b": [2, 4], "c": [5, 6]}, index=ci[:-1], columns=ci + ) + result = df.shift(axis=1) + + expected = DataFrame( + {"a": [np.nan, np.nan], "b": [1, 3], "c": [2, 4]}, index=ci[:-1], columns=ci + ) + tm.assert_frame_equal(result, expected) + + # periods != 1 + result = df.shift(2, axis=1) + expected = DataFrame( + {"a": [np.nan, np.nan], "b": [np.nan, np.nan], "c": [1, 3]}, + index=ci[:-1], + columns=ci, + ) + tm.assert_frame_equal(result, expected)