diff --git a/pandas/_typing.py b/pandas/_typing.py index 9d64842373573..3d9872c55ca2d 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -131,9 +131,6 @@ # Series is passed into a function, a Series is always returned and if a DataFrame is # passed in, a DataFrame is always returned. NDFrameT = TypeVar("NDFrameT", bound="NDFrame") -# same as NDFrameT, needed when binding two pairs of parameters to potentially -# separate NDFrame-subclasses (see NDFrame.align) -NDFrameTb = TypeVar("NDFrameTb", bound="NDFrame") NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index") diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 17e4a4c142f66..9b5cc4bb83649 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -68,10 +68,10 @@ Manager, NaPosition, NDFrameT, - NDFrameTb, RandomState, Renamer, Scalar, + Self, SortKind, StorageOptions, Suffixes, @@ -9317,8 +9317,8 @@ def compare( @doc(**_shared_doc_kwargs) def align( - self: NDFrameT, - other: NDFrameTb, + self, + other: NDFrameT, join: AlignJoin = "outer", axis: Axis | None = None, level: Level = None, @@ -9328,7 +9328,7 @@ def align( limit: int | None = None, fill_axis: Axis = 0, broadcast_axis: Axis | None = None, - ) -> tuple[NDFrameT, NDFrameTb]: + ) -> tuple[Self, NDFrameT]: """ Align two objects on their axes with the specified join method. @@ -9450,7 +9450,7 @@ def align( {c: self for c in other.columns}, **other._construct_axes_dict() ) # error: Incompatible return value type (got "Tuple[DataFrame, - # DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]") + # DataFrame]", expected "Tuple[Self, NDFrameT]") return df._align_frame( # type: ignore[return-value] other, # type: ignore[arg-type] join=join, @@ -9461,7 +9461,7 @@ def align( method=method, limit=limit, fill_axis=fill_axis, - ) + )[:2] elif isinstance(other, ABCSeries): # this means self is a DataFrame, and we need to broadcast # other @@ -9470,7 +9470,7 @@ def align( {c: other for c in self.columns}, **self._construct_axes_dict() ) # error: Incompatible return value type (got "Tuple[NDFrameT, - # DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]") + # DataFrame]", expected "Tuple[Self, NDFrameT]") return self._align_frame( # type: ignore[return-value] df, join=join, @@ -9481,14 +9481,13 @@ def align( method=method, limit=limit, fill_axis=fill_axis, - ) + )[:2] + _right: DataFrame | Series if axis is not None: axis = self._get_axis_number(axis) if isinstance(other, ABCDataFrame): - # error: Incompatible return value type (got "Tuple[NDFrameT, DataFrame]", - # expected "Tuple[NDFrameT, NDFrameTb]") - return self._align_frame( # type: ignore[return-value] + left, _right, join_index = self._align_frame( other, join=join, axis=axis, @@ -9499,10 +9498,9 @@ def align( limit=limit, fill_axis=fill_axis, ) + elif isinstance(other, ABCSeries): - # error: Incompatible return value type (got "Tuple[NDFrameT, Series]", - # expected "Tuple[NDFrameT, NDFrameTb]") - return self._align_series( # type: ignore[return-value] + left, _right, join_index = self._align_series( other, join=join, axis=axis, @@ -9516,9 +9514,27 @@ def align( else: # pragma: no cover raise TypeError(f"unsupported type: {type(other)}") + right = cast(NDFrameT, _right) + if self.ndim == 1 or axis == 0: + # If we are aligning timezone-aware DatetimeIndexes and the timezones + # do not match, convert both to UTC. + if is_datetime64tz_dtype(left.index.dtype): + if left.index.tz != right.index.tz: + if join_index is not None: + # GH#33671 copy to ensure we don't change the index on + # our original Series + left = left.copy(deep=False) + right = right.copy(deep=False) + left.index = join_index + right.index = join_index + + left = left.__finalize__(self) + right = right.__finalize__(other) + return left, right + @final def _align_frame( - self: NDFrameT, + self, other: DataFrame, join: AlignJoin = "outer", axis: Axis | None = None, @@ -9528,7 +9544,7 @@ def _align_frame( method=None, limit=None, fill_axis: Axis = 0, - ) -> tuple[NDFrameT, DataFrame]: + ) -> tuple[Self, DataFrame, Index | None]: # defaults join_index, join_columns = None, None ilidx, iridx = None, None @@ -9567,22 +9583,14 @@ def _align_frame( ) if method is not None: - _left = left.fillna(method=method, axis=fill_axis, limit=limit) - assert _left is not None # needed for mypy - left = _left + left = left.fillna(method=method, axis=fill_axis, limit=limit) right = right.fillna(method=method, axis=fill_axis, limit=limit) - # if DatetimeIndex have different tz, convert to UTC - left, right = _align_as_utc(left, right, join_index) - - return ( - left.__finalize__(self), - right.__finalize__(other), - ) + return left, right, join_index @final def _align_series( - self: NDFrameT, + self, other: Series, join: AlignJoin = "outer", axis: Axis | None = None, @@ -9592,7 +9600,7 @@ def _align_series( method=None, limit=None, fill_axis: Axis = 0, - ) -> tuple[NDFrameT, Series]: + ) -> tuple[Self, Series, Index | None]: is_series = isinstance(self, ABCSeries) if copy and using_copy_on_write(): copy = False @@ -9654,14 +9662,7 @@ def _align_series( left = left.fillna(fill_value, method=method, limit=limit, axis=fill_axis) right = right.fillna(fill_value, method=method, limit=limit) - # if DatetimeIndex have different tz, convert to UTC - if is_series or (not is_series and axis == 0): - left, right = _align_as_utc(left, right, join_index) - - return ( - left.__finalize__(self), - right.__finalize__(other), - ) + return left, right, join_index @final def _where( @@ -12824,23 +12825,3 @@ def _doc_params(cls): The required number of valid values to perform the operation. If fewer than ``min_count`` non-NA values are present the result will be NA. """ - - -def _align_as_utc( - left: NDFrameT, right: NDFrameTb, join_index: Index | None -) -> tuple[NDFrameT, NDFrameTb]: - """ - If we are aligning timezone-aware DatetimeIndexes and the timezones - do not match, convert both to UTC. - """ - if is_datetime64tz_dtype(left.index.dtype): - if left.index.tz != right.index.tz: - if join_index is not None: - # GH#33671 ensure we don't change the index on - # our original Series (NB: by default deep=False) - left = left.copy() - right = right.copy() - left.index = join_index - right.index = join_index - - return left, right diff --git a/pandas/tests/frame/methods/test_align.py b/pandas/tests/frame/methods/test_align.py index ec7d75ef4debb..a50aed401e82c 100644 --- a/pandas/tests/frame/methods/test_align.py +++ b/pandas/tests/frame/methods/test_align.py @@ -101,6 +101,7 @@ def test_align_float(self, float_frame, using_copy_on_write): with pytest.raises(ValueError, match=msg): float_frame.align(af.iloc[0, :3], join="inner", axis=2) + def test_align_frame_with_series(self, float_frame): # align dataframe to series with broadcast or not idx = float_frame.index s = Series(range(len(idx)), index=idx) @@ -118,6 +119,7 @@ def test_align_float(self, float_frame, using_copy_on_write): ) tm.assert_frame_equal(right, expected) + def test_align_series_condition(self): # see gh-9558 df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) result = df[df["a"] == 2]