Skip to content

BUG/PERF: MultiIndex.difference losing EA dtype #48722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions asv_bench/benchmarks/multiindex_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class SetOperations:
params = [
("monotonic", "non_monotonic"),
("datetime", "int", "string", "ea_int"),
("intersection", "union", "symmetric_difference"),
("intersection", "union", "difference", "symmetric_difference"),
]
param_names = ["index_structure", "dtype", "method"]

Expand Down Expand Up @@ -268,7 +268,11 @@ def setup(self, index_structure, dtype, method):
if index_structure == "non_monotonic":
data = {k: mi[::-1] for k, mi in data.items()}

data = {k: {"left": mi, "right": mi[:-1]} for k, mi in data.items()}
# left and right each have values not in the other
for k, mi in data.items():
n = len(mi) // 10
data[k] = {"left": mi[:-n], "right": mi[n:]}

self.left = data[dtype]["left"]
self.right = data[dtype]["right"]

Expand Down
6 changes: 4 additions & 2 deletions doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ Performance improvements
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
- Performance improvement in :meth:`DataFrame.join` when joining on a subset of a :class:`MultiIndex` (:issue:`48611`)
- Performance improvement for :meth:`MultiIndex.intersection` (:issue:`48604`)
- Performance improvement for :meth:`MultiIndex.difference` (:issue:`48722`)
- Performance improvement in ``var`` for nullable dtypes (:issue:`48379`).
- Performance improvement to :func:`read_sas` with ``blank_missing=True`` (:issue:`48502`)
-
Expand Down Expand Up @@ -211,10 +212,11 @@ MultiIndex
^^^^^^^^^^
- Bug in :class:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
- Bug in :meth:`MultiIndex.intersection` losing extension array (:issue:`48604`)
- Bug in :meth:`MultiIndex.intersection` losing extension array dtype (:issue:`48604`)
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`)
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`)
- Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48722`)
- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array dtype (:issue:`48607`)
-

I/O
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3677,7 +3677,7 @@ def difference(self, other, sort=None):
return self._wrap_difference_result(other, result)

def _difference(self, other, sort):
# overridden by RangeIndex
# overridden by MultiIndex, RangeIndex

this = self.unique()

Expand Down
49 changes: 26 additions & 23 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3654,6 +3654,15 @@ def equal_levels(self, other: MultiIndex) -> bool:
# --------------------------------------------------------------------
# Set Methods

def _difference(self, other, sort) -> MultiIndex:
other, result_names = self._convert_can_do_setop(other)
this = self.unique()
indexer = this.get_indexer_for(other)
indexer = indexer.take((indexer != -1).nonzero()[0])
label_diff = np.setdiff1d(np.arange(len(this)), indexer, assume_unique=True)
result = this.take(label_diff)
return _maybe_try_sort(result, sort)

def _union(self, other, sort) -> MultiIndex:
other, result_names = self._convert_can_do_setop(other)
if (
Expand All @@ -3678,19 +3687,7 @@ def _union(self, other, sort) -> MultiIndex:
result = self.append(other.take(right_missing))
else:
result = self._get_reconciled_name_object(other)

if sort is None:
try:
result = result.sort_values()
except TypeError:
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
RuntimeWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
pass
return result
return _maybe_try_sort(result, sort)

def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
return is_object_dtype(dtype)
Expand Down Expand Up @@ -3730,16 +3727,7 @@ def _wrap_intersection_result(self, other, result) -> MultiIndex:

def _wrap_difference_result(self, other, result) -> MultiIndex:
_, result_names = self._convert_can_do_setop(other)

if len(result) == 0:
return MultiIndex(
levels=[[]] * self.nlevels,
codes=[[]] * self.nlevels,
names=result_names,
verify_integrity=False,
)
else:
return MultiIndex.from_tuples(result, sortorder=0, names=result_names)
return result.set_names(result_names)

def _convert_can_do_setop(self, other):
result_names = self.names
Expand Down Expand Up @@ -4023,3 +4011,18 @@ def _require_listlike(level, arr, arrname: str):
if not is_list_like(arr) or not is_list_like(arr[0]):
raise TypeError(f"{arrname} must be list of lists-like")
return level, arr


def _maybe_try_sort(result: MultiIndex, sort) -> MultiIndex:
if sort is None:
try:
result = result.sort_values()
except TypeError:
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
RuntimeWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
pass
return result
18 changes: 17 additions & 1 deletion pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ def test_difference_sort_incomparable():

other = MultiIndex.from_product([[3, pd.Timestamp("2000"), 4], ["c", "d"]])
# sort=None, the default
msg = "sort order is undefined for incomparable objects"
msg = (
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning."
)
with tm.assert_produces_warning(RuntimeWarning, match=msg):
result = idx.difference(other)
tm.assert_index_equal(result, idx)
Expand Down Expand Up @@ -646,3 +649,16 @@ def test_intersection_keep_ea_dtypes(val, any_numeric_ea_dtype):
result = midx.intersection(midx2)
expected = MultiIndex.from_arrays([Series([2], dtype=any_numeric_ea_dtype), [1]])
tm.assert_index_equal(result, expected)


@pytest.mark.parametrize("val", [pd.NA, 100])
def test_difference_keep_ea_dtypes(val, any_numeric_ea_dtype):
midx = MultiIndex.from_arrays(
[Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
)
midx2 = MultiIndex.from_arrays(
[Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
)
result = midx.difference(midx2)
expected = MultiIndex.from_arrays([Series([1], dtype=any_numeric_ea_dtype), [2]])
tm.assert_index_equal(result, expected)