diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 524b26ff07769..d9b6cab518164 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -30,6 +30,7 @@ ArrayLike, DtypeObj, IndexLabel, + Shape, Suffixes, npt, ) @@ -625,6 +626,9 @@ class _MergeOperation: copy: bool indicator: bool validate: str | None + join_names: list[Hashable] + right_join_keys: list[AnyArrayLike] + left_join_keys: list[AnyArrayLike] def __init__( self, @@ -960,9 +964,9 @@ def _maybe_add_join_keys( rvals = result[name]._values else: # TODO: can we pin down take_right's type earlier? - take_right = extract_array(take_right, extract_numpy=True) - rfill = na_value_for_dtype(take_right.dtype) - rvals = algos.take_nd(take_right, right_indexer, fill_value=rfill) + taker = extract_array(take_right, extract_numpy=True) + rfill = na_value_for_dtype(taker.dtype) + rvals = algos.take_nd(taker, right_indexer, fill_value=rfill) # if we have an all missing left_indexer # make sure to just use the right values or vice-versa @@ -1098,7 +1102,9 @@ def _create_join_index( index = index.append(Index([fill_value])) return index.take(indexer) - def _get_merge_keys(self): + def _get_merge_keys( + self, + ) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]: """ Note: has side effects (copy/delete key columns) @@ -1117,8 +1123,8 @@ def _get_merge_keys(self): left_keys: list[AnyArrayLike] = [] right_keys: list[AnyArrayLike] = [] join_names: list[Hashable] = [] - right_drop = [] - left_drop = [] + right_drop: list[Hashable] = [] + left_drop: list[Hashable] = [] left, right = self.left, self.right @@ -1168,6 +1174,7 @@ def _get_merge_keys(self): right_keys.append(right.index) if lk is not None and lk == rk: # FIXME: what about other NAs? # avoid key upcast in corner case (length-0) + lk = cast(Hashable, lk) if len(left) > 0: right_drop.append(rk) else: @@ -1260,6 +1267,8 @@ def _maybe_coerce_merge_keys(self) -> None: # if either left or right is a categorical # then the must match exactly in categories & ordered if lk_is_cat and rk_is_cat: + lk = cast(Categorical, lk) + rk = cast(Categorical, rk) if lk._categories_match_up_to_permutation(rk): continue @@ -1286,7 +1295,22 @@ def _maybe_coerce_merge_keys(self) -> None: elif is_integer_dtype(rk.dtype) and is_float_dtype(lk.dtype): # GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int with np.errstate(invalid="ignore"): - if not (lk == lk.astype(rk.dtype))[~np.isnan(lk)].all(): + # error: Argument 1 to "astype" of "ndarray" has incompatible + # type "Union[ExtensionDtype, Any, dtype[Any]]"; expected + # "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]" + casted = lk.astype(rk.dtype) # type: ignore[arg-type] + + # Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has + # incompatible type "Union[ExtensionArray, ndarray[Any, Any], + # Index, Series]"; expected "Union[_SupportsArray[dtype[Any]], + # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, + # float, complex, str, bytes, _NestedSequence[Union[bool, + # int, float, complex, str, bytes]]]" + mask = ~np.isnan(lk) # type: ignore[arg-type] + match = lk == casted + # error: Item "ExtensionArray" of "Union[ExtensionArray, + # ndarray[Any, Any], Any]" has no attribute "all" + if not match[mask].all(): # type: ignore[union-attr] warnings.warn( "You are merging on int and float " "columns where the float values " @@ -1299,7 +1323,22 @@ def _maybe_coerce_merge_keys(self) -> None: elif is_float_dtype(rk.dtype) and is_integer_dtype(lk.dtype): # GH 47391 numpy > 1.24 will raise a RuntimeError for nan -> int with np.errstate(invalid="ignore"): - if not (rk == rk.astype(lk.dtype))[~np.isnan(rk)].all(): + # error: Argument 1 to "astype" of "ndarray" has incompatible + # type "Union[ExtensionDtype, Any, dtype[Any]]"; expected + # "Union[dtype[Any], Type[Any], _SupportsDType[dtype[Any]]]" + casted = rk.astype(lk.dtype) # type: ignore[arg-type] + + # Argument 1 to "__call__" of "_UFunc_Nin1_Nout1" has + # incompatible type "Union[ExtensionArray, ndarray[Any, Any], + # Index, Series]"; expected "Union[_SupportsArray[dtype[Any]], + # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, + # float, complex, str, bytes, _NestedSequence[Union[bool, + # int, float, complex, str, bytes]]]" + mask = ~np.isnan(rk) # type: ignore[arg-type] + match = rk == casted + # error: Item "ExtensionArray" of "Union[ExtensionArray, + # ndarray[Any, Any], Any]" has no attribute "all" + if not match[mask].all(): # type: ignore[union-attr] warnings.warn( "You are merging on int and float " "columns where the float values " @@ -1370,11 +1409,11 @@ def _maybe_coerce_merge_keys(self) -> None: # columns, and end up trying to merge # incompatible dtypes. See GH 16900. if name in self.left.columns: - typ = lk.categories.dtype if lk_is_cat else object + typ = cast(Categorical, lk).categories.dtype if lk_is_cat else object self.left = self.left.copy() self.left[name] = self.left[name].astype(typ) if name in self.right.columns: - typ = rk.categories.dtype if rk_is_cat else object + typ = cast(Categorical, rk).categories.dtype if rk_is_cat else object self.right = self.right.copy() self.right[name] = self.right[name].astype(typ) @@ -1592,7 +1631,7 @@ def get_join_indexers( llab, rlab, shape = (list(x) for x in zipped) # get flat i8 keys from label lists - lkey, rkey = _get_join_keys(llab, rlab, shape, sort) + lkey, rkey = _get_join_keys(llab, rlab, tuple(shape), sort) # factorize keys to a dense i8 space # `count` is the num. of unique keys @@ -1922,7 +1961,9 @@ def _validate_left_right_on(self, left_on, right_on): return left_on, right_on - def _get_merge_keys(self): + def _get_merge_keys( + self, + ) -> tuple[list[AnyArrayLike], list[AnyArrayLike], list[Hashable]]: # note this function has side effects (left_join_keys, right_join_keys, join_names) = super()._get_merge_keys() @@ -1954,7 +1995,8 @@ def _get_merge_keys(self): if self.tolerance is not None: if self.left_index: - lt = self.left.index + # Actually more specifically an Index + lt = cast(AnyArrayLike, self.left.index) else: lt = left_join_keys[-1] @@ -2069,21 +2111,21 @@ def injection(obj): # get tuple representation of values if more than one if len(left_by_values) == 1: - left_by_values = left_by_values[0] - right_by_values = right_by_values[0] + lbv = left_by_values[0] + rbv = right_by_values[0] else: # We get here with non-ndarrays in test_merge_by_col_tz_aware # and test_merge_groupby_multiple_column_with_categorical_column - left_by_values = flip(left_by_values) - right_by_values = flip(right_by_values) + lbv = flip(left_by_values) + rbv = flip(right_by_values) # upcast 'by' parameter because HashTable is limited - by_type = _get_cython_type_upcast(left_by_values.dtype) + by_type = _get_cython_type_upcast(lbv.dtype) by_type_caster = _type_casters[by_type] # error: Cannot call function of unknown type - left_by_values = by_type_caster(left_by_values) # type: ignore[operator] + left_by_values = by_type_caster(lbv) # type: ignore[operator] # error: Cannot call function of unknown type - right_by_values = by_type_caster(right_by_values) # type: ignore[operator] + right_by_values = by_type_caster(rbv) # type: ignore[operator] # choose appropriate function by type func = _asof_by_function(self.direction) @@ -2139,7 +2181,7 @@ def _get_multiindex_indexer( rcodes[i][mask] = shape[i] - 1 # get flat i8 join keys - lkey, rkey = _get_join_keys(lcodes, rcodes, shape, sort) + lkey, rkey = _get_join_keys(lcodes, rcodes, tuple(shape), sort) # factorize keys to a dense i8 space lkey, rkey, count = _factorize_keys(lkey, rkey, sort=sort) @@ -2377,7 +2419,12 @@ def _sort_labels( return new_left, new_right -def _get_join_keys(llab, rlab, shape, sort: bool): +def _get_join_keys( + llab: list[npt.NDArray[np.int64 | np.intp]], + rlab: list[npt.NDArray[np.int64 | np.intp]], + shape: Shape, + sort: bool, +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: # how many levels can be done without overflow nlev = next( @@ -2405,7 +2452,7 @@ def _get_join_keys(llab, rlab, shape, sort: bool): llab = [lkey] + llab[nlev:] rlab = [rkey] + rlab[nlev:] - shape = [count] + shape[nlev:] + shape = (count,) + shape[nlev:] return _get_join_keys(llab, rlab, shape, sort)