Skip to content

TYP: core.reshape #52531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def _maybe_add_join_keys(
result_dtype = find_common_type([lvals.dtype, rvals.dtype])

if result._is_label_reference(name):
result[name] = Series(
result[name] = result._constructor_sliced(
key_col, dtype=result_dtype, index=result.index
)
elif result._is_level_reference(name):
Expand Down
27 changes: 19 additions & 8 deletions pandas/core/reshape/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _add_margins(
rows,
cols,
aggfunc,
observed=None,
observed: bool,
margins_name: Hashable = "All",
fill_value=None,
):
Expand Down Expand Up @@ -292,7 +292,7 @@ def _add_margins(
if not values and isinstance(table, ABCSeries):
# If there are no values and the table is a series, then there is only
# one column in the data. Compute grand margin and return it.
return table._append(Series({key: grand_margin[margins_name]}))
return table._append(table._constructor({key: grand_margin[margins_name]}))

elif values:
marginal_result_set = _generate_marginal_results(
Expand Down Expand Up @@ -364,8 +364,16 @@ def _compute_grand_margin(


def _generate_marginal_results(
table, data, values, rows, cols, aggfunc, observed, margins_name: Hashable = "All"
table,
data: DataFrame,
values,
rows,
cols,
aggfunc,
observed: bool,
margins_name: Hashable = "All",
):
margin_keys: list | Index
if len(cols) > 0:
# need to "interleave" the margins
table_pieces = []
Expand Down Expand Up @@ -433,23 +441,24 @@ def _all_key(key):
new_order = [len(cols)] + list(range(len(cols)))
row_margin.index = row_margin.index.reorder_levels(new_order)
else:
row_margin = Series(np.nan, index=result.columns)
row_margin = data._constructor_sliced(np.nan, index=result.columns)

return result, margin_keys, row_margin


def _generate_marginal_results_without_values(
table: DataFrame,
data,
data: DataFrame,
rows,
cols,
aggfunc,
observed,
observed: bool,
margins_name: Hashable = "All",
):
margin_keys: list | Index
if len(cols) > 0:
# need to "interleave" the margins
margin_keys: list | Index = []
margin_keys = []

def _all_key():
if len(cols) == 1:
Expand Down Expand Up @@ -535,7 +544,9 @@ def pivot(
data.index.get_level_values(i) for i in range(data.index.nlevels)
]
else:
index_list = [Series(data.index, name=data.index.name)]
index_list = [
data._constructor_sliced(data.index, name=data.index.name)
]
else:
index_list = [data[idx] for idx in com.convert_to_list_like(index)]

Expand Down
22 changes: 12 additions & 10 deletions pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@
)

if TYPE_CHECKING:
from pandas._typing import npt
from pandas._typing import (
Level,
npt,
)

from pandas.core.arrays import ExtensionArray
from pandas.core.indexes.frozen import FrozenList
Expand Down Expand Up @@ -98,9 +101,7 @@ class _Unstacker:
unstacked : DataFrame
"""

def __init__(self, index: MultiIndex, level=-1, constructor=None) -> None:
if constructor is None:
constructor = DataFrame
def __init__(self, index: MultiIndex, level: Level, constructor) -> None:
self.constructor = constructor

self.index = index.remove_unused_levels()
Expand Down Expand Up @@ -374,13 +375,14 @@ def new_index(self) -> MultiIndex:
)


def _unstack_multiple(data, clocs, fill_value=None):
def _unstack_multiple(data: Series | DataFrame, clocs, fill_value=None):
if len(clocs) == 0:
return data

# NOTE: This doesn't deal with hierarchical columns yet

index = data.index
index = cast(MultiIndex, index) # caller is responsible for checking

# GH 19966 Make sure if MultiIndexed index has tuple name, they will be
# recognised as a whole
Expand Down Expand Up @@ -433,10 +435,10 @@ def _unstack_multiple(data, clocs, fill_value=None):
return result

# GH#42579 deep=False to avoid consolidating
dummy = data.copy(deep=False)
dummy.index = dummy_index
dummy_df = data.copy(deep=False)
dummy_df.index = dummy_index

unstacked = dummy.unstack("__placeholder__", fill_value=fill_value)
unstacked = dummy_df.unstack("__placeholder__", fill_value=fill_value)
if isinstance(unstacked, Series):
unstcols = unstacked.index
else:
Expand Down Expand Up @@ -497,7 +499,7 @@ def unstack(obj: Series | DataFrame, level, fill_value=None):
)


def _unstack_frame(obj: DataFrame, level, fill_value=None):
def _unstack_frame(obj: DataFrame, level, fill_value=None) -> DataFrame:
assert isinstance(obj.index, MultiIndex) # checked by caller
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)

Expand Down Expand Up @@ -617,7 +619,7 @@ def factorize(index):
return frame._constructor_sliced(new_values, index=new_index)


def stack_multiple(frame, level, dropna: bool = True):
def stack_multiple(frame: DataFrame, level, dropna: bool = True):
# If all passed levels match up to column names, no
# ambiguity about what to do
if all(lev in frame.columns.names for lev in level):
Expand Down
28 changes: 19 additions & 9 deletions pandas/core/reshape/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import isna

Expand All @@ -47,7 +50,10 @@
import pandas.core.algorithms as algos

if TYPE_CHECKING:
from pandas._typing import IntervalLeftRight
from pandas._typing import (
DtypeObj,
IntervalLeftRight,
)


def cut(
Expand Down Expand Up @@ -399,7 +405,7 @@ def _bins_to_cuts(
labels=None,
precision: int = 3,
include_lowest: bool = False,
dtype=None,
dtype: DtypeObj | None = None,
duplicates: str = "raise",
ordered: bool = True,
):
Expand Down Expand Up @@ -481,7 +487,7 @@ def _coerce_to_type(x):
this method converts it to numeric so that cut or qcut method can
handle it
"""
dtype = None
dtype: DtypeObj | None = None

if is_datetime64tz_dtype(x.dtype):
dtype = x.dtype
Expand All @@ -508,7 +514,7 @@ def _coerce_to_type(x):
return x, dtype


def _convert_bin_to_numeric_type(bins, dtype):
def _convert_bin_to_numeric_type(bins, dtype: DtypeObj | None):
"""
if the passed bin is of datetime/timedelta type,
this method converts it to integer
Expand Down Expand Up @@ -542,7 +548,7 @@ def _convert_bin_to_numeric_type(bins, dtype):
return bins


def _convert_bin_to_datelike_type(bins, dtype):
def _convert_bin_to_datelike_type(bins, dtype: DtypeObj | None):
"""
Convert bins to a DatetimeIndex or TimedeltaIndex if the original dtype is
datelike
Expand All @@ -557,22 +563,26 @@ def _convert_bin_to_datelike_type(bins, dtype):
bins : Array-like of bins, DatetimeIndex or TimedeltaIndex if dtype is
datelike
"""
if is_datetime64tz_dtype(dtype):
if isinstance(dtype, DatetimeTZDtype):
bins = to_datetime(bins.astype(np.int64), utc=True).tz_convert(dtype.tz)
elif is_datetime_or_timedelta_dtype(dtype):
bins = Index(bins.astype(np.int64), dtype=dtype)
return bins


def _format_labels(
bins, precision: int, right: bool = True, include_lowest: bool = False, dtype=None
bins,
precision: int,
right: bool = True,
include_lowest: bool = False,
dtype: DtypeObj | None = None,
):
"""based on the dtype, return our labels"""
closed: IntervalLeftRight = "right" if right else "left"

formatter: Callable[[Any], Timestamp] | Callable[[Any], Timedelta]

if is_datetime64tz_dtype(dtype):
if isinstance(dtype, DatetimeTZDtype):
formatter = lambda x: Timestamp(x, tz=dtype.tz)
adjust = lambda x: x - Timedelta("1ns")
elif is_datetime64_dtype(dtype):
Expand Down