Skip to content

Commit d9c5cf4

Browse files
committed
BUG: Multiindex.equals not commutative for ea dtype
1 parent 1bd193e commit d9c5cf4

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

doc/source/whatsnew/v1.5.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ Missing
347347

348348
MultiIndex
349349
^^^^^^^^^^
350-
-
350+
- Bug in :class:`MultiIndex.equals` not commutative when only one side has extension array dtype (:issue:`46026`)
351351
-
352352

353353
I/O

pandas/core/indexes/multi.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3224,8 +3224,8 @@ def convert_indexer(start, stop, step, indexer=indexer, codes=level_codes):
32243224
return convert_indexer(start, stop + 1, step)
32253225
else:
32263226
# sorted, so can return slice object -> view
3227-
i = algos.searchsorted(level_codes, start, side="left")
3228-
j = algos.searchsorted(level_codes, stop, side="right")
3227+
i = level_codes.searchsorted(start, side="left")
3228+
j = level_codes.searchsorted(stop, side="right")
32293229
return slice(i, j, step)
32303230

32313231
else:
@@ -3248,12 +3248,12 @@ def convert_indexer(start, stop, step, indexer=indexer, codes=level_codes):
32483248

32493249
if isinstance(idx, slice):
32503250
# e.g. test_partial_string_timestamp_multiindex
3251-
start = algos.searchsorted(level_codes, idx.start, side="left")
3251+
start = level_codes.searchsorted(idx.start, side="left")
32523252
# NB: "left" here bc of slice semantics
3253-
end = algos.searchsorted(level_codes, idx.stop, side="left")
3253+
end = level_codes.searchsorted(idx.stop, side="left")
32543254
else:
3255-
start = algos.searchsorted(level_codes, idx, side="left")
3256-
end = algos.searchsorted(level_codes, idx, side="right")
3255+
start = level_codes.searchsorted(idx, side="left")
3256+
end = level_codes.searchsorted(idx, side="right")
32573257

32583258
if start == end:
32593259
# The label is present in self.levels[level] but unused:
@@ -3304,10 +3304,10 @@ def get_locs(self, seq):
33043304
)
33053305

33063306
n = len(self)
3307-
# indexer is the list of all positions that we want to take; it
3308-
# is created on the first entry in seq and narrowed down as we
3309-
# look at remaining entries
3310-
indexer = None
3307+
# indexer is the list of all positions that we want to take; we
3308+
# start with it being everything and narrow it down as we look at each
3309+
# entry in `seq`
3310+
indexer = Index(np.arange(n))
33113311

33123312
if any(x is Ellipsis for x in seq):
33133313
raise NotImplementedError(
@@ -3330,9 +3330,7 @@ def _convert_to_indexer(r) -> Int64Index:
33303330
r = r.nonzero()[0]
33313331
return Int64Index(r)
33323332

3333-
def _update_indexer(idxr: Index, indexer: Index | None) -> Index:
3334-
if indexer is None:
3335-
return idxr
3333+
def _update_indexer(idxr: Index, indexer: Index) -> Index:
33363334
indexer_intersection = indexer.intersection(idxr)
33373335
if indexer_intersection.empty and not idxr.empty and not indexer.empty:
33383336
raise KeyError(seq)
@@ -3417,8 +3415,7 @@ def _update_indexer(idxr: Index, indexer: Index | None) -> Index:
34173415

34183416
elif com.is_null_slice(k):
34193417
# empty slice
3420-
if indexer is None:
3421-
indexer = Index(np.arange(n))
3418+
pass
34223419

34233420
elif isinstance(k, slice):
34243421

@@ -3611,6 +3608,10 @@ def equals(self, other: object) -> bool:
36113608
# i.e. ExtensionArray
36123609
if not self_values.equals(other_values):
36133610
return False
3611+
elif not isinstance(other_values, np.ndarray):
3612+
# i.e. other is ExtensionArray
3613+
if not other_values.equals(self_values):
3614+
return False
36143615
else:
36153616
if not array_equivalent(self_values, other_values):
36163617
return False

pandas/tests/indexes/multi/test_equivalence.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,11 @@ def test_multiindex_compare():
288288
expected = Series([False, False])
289289
result = Series(midx > midx)
290290
tm.assert_series_equal(result, expected)
291+
292+
293+
def test_equals_ea_int_regular_int():
294+
# GH#46026
295+
mi1 = MultiIndex.from_arrays([Index([1, 2], dtype="Int64"), [3, 4]])
296+
mi2 = MultiIndex.from_arrays([[1, 2], [3, 4]])
297+
assert not mi1.equals(mi2)
298+
assert not mi2.equals(mi1)

0 commit comments

Comments
 (0)