Skip to content

TYP: reshape.merge #48590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 70 additions & 23 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ArrayLike,
DtypeObj,
IndexLabel,
Shape,
Suffixes,
npt,
)
Expand Down Expand Up @@ -625,6 +626,9 @@ class _MergeOperation:
copy: bool
indicator: bool
validate: str | None
join_names: list[Hashable]
right_join_keys: list[AnyArrayLike]
left_join_keys: list[AnyArrayLike]

def __init__(
self,
Expand Down Expand Up @@ -960,9 +964,9 @@ def _maybe_add_join_keys(
rvals = result[name]._values
else:
# TODO: can we pin down take_right's type earlier?
take_right = extract_array(take_right, extract_numpy=True)
rfill = na_value_for_dtype(take_right.dtype)
rvals = algos.take_nd(take_right, right_indexer, fill_value=rfill)
taker = extract_array(take_right, extract_numpy=True)
rfill = na_value_for_dtype(taker.dtype)
rvals = algos.take_nd(taker, right_indexer, fill_value=rfill)

# if we have an all missing left_indexer
# make sure to just use the right values or vice-versa
Expand Down Expand Up @@ -1098,7 +1102,9 @@ def _create_join_index(
index = index.append(Index([fill_value]))
return index.take(indexer)

def _get_merge_keys(self):
def _get_merge_keys(
self,
) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:
"""
Note: has side effects (copy/delete key columns)

Expand All @@ -1117,8 +1123,8 @@ def _get_merge_keys(self):
left_keys: list[AnyArrayLike] = []
right_keys: list[AnyArrayLike] = []
join_names: list[Hashable] = []
right_drop = []
left_drop = []
right_drop: list[Hashable] = []
left_drop: list[Hashable] = []

left, right = self.left, self.right

Expand Down Expand Up @@ -1168,6 +1174,7 @@ def _get_merge_keys(self):
right_keys.append(right.index)
if lk is not None and lk == rk: # FIXME: what about other NAs?
# avoid key upcast in corner case (length-0)
lk = cast(Hashable, lk)
if len(left) > 0:
right_drop.append(rk)
else:
Expand Down Expand Up @@ -1260,6 +1267,8 @@ def _maybe_coerce_merge_keys(self) -> None:
# if either left or right is a categorical
# then they must match exactly in categories & ordered
if lk_is_cat and rk_is_cat:
lk = cast(Categorical, lk)
rk = cast(Categorical, rk)
if lk._categories_match_up_to_permutation(rk):
continue

Expand All @@ -1286,7 +1295,22 @@ def _maybe_coerce_merge_keys(self) -> None:
elif is_integer_dtype(rk.dtype) and is_float_dtype(lk.dtype):
# GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int
with np.errstate(invalid="ignore"):
if not (lk == lk.astype(rk.dtype))[~np.isnan(lk)].all():
# error: Argument 1 to "astype" of "ndarray" has incompatible
# type "Union[ExtensionDtype, Any, dtype[Any]]"; expected
# "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]"
casted = lk.astype(rk.dtype) # type: ignore[arg-type]

# Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has
# incompatible type "Union[ExtensionArray, ndarray[Any, Any],
# Index, Series]"; expected "Union[_SupportsArray[dtype[Any]],
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int,
# float, complex, str, bytes, _NestedSequence[Union[bool,
# int, float, complex, str, bytes]]]"
mask = ~np.isnan(lk) # type: ignore[arg-type]
match = lk == casted
# error: Item "ExtensionArray" of "Union[ExtensionArray,
# ndarray[Any, Any], Any]" has no attribute "all"
if not match[mask].all(): # type: ignore[union-attr]
warnings.warn(
"You are merging on int and float "
"columns where the float values "
Expand All @@ -1299,7 +1323,22 @@ def _maybe_coerce_merge_keys(self) -> None:
elif is_float_dtype(rk.dtype) and is_integer_dtype(lk.dtype):
# GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int
with np.errstate(invalid="ignore"):
if not (rk == rk.astype(lk.dtype))[~np.isnan(rk)].all():
# error: Argument 1 to "astype" of "ndarray" has incompatible
# type "Union[ExtensionDtype, Any, dtype[Any]]"; expected
# "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]"
casted = rk.astype(lk.dtype) # type: ignore[arg-type]

# Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has
# incompatible type "Union[ExtensionArray, ndarray[Any, Any],
# Index, Series]"; expected "Union[_SupportsArray[dtype[Any]],
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int,
# float, complex, str, bytes, _NestedSequence[Union[bool,
# int, float, complex, str, bytes]]]"
mask = ~np.isnan(rk) # type: ignore[arg-type]
match = rk == casted
# error: Item "ExtensionArray" of "Union[ExtensionArray,
# ndarray[Any, Any], Any]" has no attribute "all"
if not match[mask].all(): # type: ignore[union-attr]
warnings.warn(
"You are merging on int and float "
"columns where the float values "
Expand Down Expand Up @@ -1370,11 +1409,11 @@ def _maybe_coerce_merge_keys(self) -> None:
# columns, and end up trying to merge
# incompatible dtypes. See GH 16900.
if name in self.left.columns:
typ = lk.categories.dtype if lk_is_cat else object
typ = cast(Categorical, lk).categories.dtype if lk_is_cat else object
self.left = self.left.copy()
self.left[name] = self.left[name].astype(typ)
if name in self.right.columns:
typ = rk.categories.dtype if rk_is_cat else object
typ = cast(Categorical, rk).categories.dtype if rk_is_cat else object
self.right = self.right.copy()
self.right[name] = self.right[name].astype(typ)

Expand Down Expand Up @@ -1592,7 +1631,7 @@ def get_join_indexers(
llab, rlab, shape = (list(x) for x in zipped)

# get flat i8 keys from label lists
lkey, rkey = _get_join_keys(llab, rlab, shape, sort)
lkey, rkey = _get_join_keys(llab, rlab, tuple(shape), sort)

# factorize keys to a dense i8 space
# `count` is the num. of unique keys
Expand Down Expand Up @@ -1922,7 +1961,9 @@ def _validate_left_right_on(self, left_on, right_on):

return left_on, right_on

def _get_merge_keys(self):
def _get_merge_keys(
self,
) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]:

# note this function has side effects
(left_join_keys, right_join_keys, join_names) = super()._get_merge_keys()
Expand Down Expand Up @@ -1954,7 +1995,8 @@ def _get_merge_keys(self):
if self.tolerance is not None:

if self.left_index:
lt = self.left.index
# Actually more specifically an Index
lt = cast(AnyArrayLike, self.left.index)
else:
lt = left_join_keys[-1]

Expand Down Expand Up @@ -2069,21 +2111,21 @@ def injection(obj):

# get tuple representation of values if more than one
if len(left_by_values) == 1:
left_by_values = left_by_values[0]
right_by_values = right_by_values[0]
lbv = left_by_values[0]
rbv = right_by_values[0]
else:
# We get here with non-ndarrays in test_merge_by_col_tz_aware
# and test_merge_groupby_multiple_column_with_categorical_column
left_by_values = flip(left_by_values)
right_by_values = flip(right_by_values)
lbv = flip(left_by_values)
rbv = flip(right_by_values)

# upcast 'by' parameter because HashTable is limited
by_type = _get_cython_type_upcast(left_by_values.dtype)
by_type = _get_cython_type_upcast(lbv.dtype)
by_type_caster = _type_casters[by_type]
# error: Cannot call function of unknown type
left_by_values = by_type_caster(left_by_values) # type: ignore[operator]
left_by_values = by_type_caster(lbv) # type: ignore[operator]
# error: Cannot call function of unknown type
right_by_values = by_type_caster(right_by_values) # type: ignore[operator]
right_by_values = by_type_caster(rbv) # type: ignore[operator]

# choose appropriate function by type
func = _asof_by_function(self.direction)
Expand Down Expand Up @@ -2139,7 +2181,7 @@ def _get_multiindex_indexer(
rcodes[i][mask] = shape[i] - 1

# get flat i8 join keys
lkey, rkey = _get_join_keys(lcodes, rcodes, shape, sort)
lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort)

# factorize keys to a dense i8 space
lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort)
Expand Down Expand Up @@ -2377,7 +2419,12 @@ def _sort_labels(
return new_left, new_right


def _get_join_keys(llab, rlab, shape, sort: bool):
def _get_join_keys(
llab: list[npt.NDArray[np.int64 | np.intp]],
rlab: list[npt.NDArray[np.int64 | np.intp]],
shape: Shape,
sort: bool,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]:

# how many levels can be done without overflow
nlev = next(
Expand Down Expand Up @@ -2405,7 +2452,7 @@ def _get_join_keys(llab, rlab, shape, sort: bool):

llab = [lkey] + llab[nlev:]
rlab = [rkey] + rlab[nlev:]
shape = [count] + shape[nlev:]
shape = (count,) + shape[nlev:]

return _get_join_keys(llab, rlab, shape, sort)

Expand Down