Skip to content

Commit 1113779

Browse files
authored
BUG: IntervalArray.__cmp__(pd.NA) (#44830)
1 parent f08f574 commit 1113779

File tree

3 files changed

+50
-19
lines changed

3 files changed

+50
-19
lines changed

pandas/_testing/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
is_unsigned_integer_dtype,
3939
pandas_dtype,
4040
)
41+
from pandas.core.dtypes.dtypes import IntervalDtype
4142

4243
import pandas as pd
4344
from pandas import (
@@ -282,6 +283,10 @@ def to_array(obj):
282283
return DatetimeArray._from_sequence(obj)
283284
elif is_timedelta64_dtype(dtype):
284285
return TimedeltaArray._from_sequence(obj)
286+
elif isinstance(obj, pd.core.arrays.BooleanArray):
287+
return obj
288+
elif isinstance(dtype, IntervalDtype):
289+
return pd.core.arrays.IntervalArray(obj)
285290
else:
286291
return np.array(obj)
287292

pandas/core/arrays/interval.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,13 @@ def _cmp_method(self, other, op):
685685
other = pd_array(other)
686686
elif not isinstance(other, Interval):
687687
# non-interval scalar -> no matches
688+
if other is NA:
689+
# GH#31882
690+
from pandas.core.arrays import BooleanArray
691+
692+
arr = np.empty(self.shape, dtype=bool)
693+
mask = np.ones(self.shape, dtype=bool)
694+
return BooleanArray(arr, mask)
688695
return invalid_comparison(self, other, op)
689696

690697
# determine the dtype of the elements we want to compare
@@ -743,7 +750,8 @@ def _cmp_method(self, other, op):
743750
if obj is NA:
744751
# comparison with np.nan returns NA
745752
# github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
746-
result[i] = op is operator.ne
753+
result = result.astype(object)
754+
result[i] = NA
747755
else:
748756
raise
749757
return result

pandas/tests/arithmetic/test_interval.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
timedelta_range,
2121
)
2222
import pandas._testing as tm
23-
from pandas.core.arrays import IntervalArray
23+
from pandas.core.arrays import (
24+
BooleanArray,
25+
IntervalArray,
26+
)
27+
from pandas.tests.arithmetic.common import get_upcast_box
2428

2529

2630
@pytest.fixture(
@@ -129,18 +133,37 @@ def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed):
129133
expected = self.elementwise_comparison(op, interval_array, other)
130134
tm.assert_numpy_array_equal(result, expected)
131135

132-
def test_compare_scalar_na(self, op, interval_array, nulls_fixture, request):
133-
result = op(interval_array, nulls_fixture)
134-
expected = self.elementwise_comparison(op, interval_array, nulls_fixture)
136+
def test_compare_scalar_na(
137+
self, op, interval_array, nulls_fixture, box_with_array, request
138+
):
139+
box = box_with_array
140+
141+
if box is pd.DataFrame:
142+
if interval_array.dtype.subtype.kind not in "iuf":
143+
mark = pytest.mark.xfail(
144+
reason="raises on DataFrame.transpose (would be fixed by EA2D)"
145+
)
146+
request.node.add_marker(mark)
147+
148+
obj = tm.box_expected(interval_array, box)
149+
result = op(obj, nulls_fixture)
150+
151+
if nulls_fixture is pd.NA:
152+
# GH#31882
153+
exp = np.ones(interval_array.shape, dtype=bool)
154+
expected = BooleanArray(exp, exp)
155+
else:
156+
expected = self.elementwise_comparison(op, interval_array, nulls_fixture)
157+
158+
if not (box is Index and nulls_fixture is pd.NA):
159+
# don't cast expected from BooleanArray to ndarray[object]
160+
xbox = get_upcast_box(obj, nulls_fixture, True)
161+
expected = tm.box_expected(expected, xbox)
135162

136-
if nulls_fixture is pd.NA and interval_array.dtype.subtype != "int64":
137-
mark = pytest.mark.xfail(
138-
raises=AssertionError,
139-
reason="broken for non-integer IntervalArray; see GH 31882",
140-
)
141-
request.node.add_marker(mark)
163+
tm.assert_equal(result, expected)
142164

143-
tm.assert_numpy_array_equal(result, expected)
165+
rev = op(nulls_fixture, obj)
166+
tm.assert_equal(rev, expected)
144167

145168
@pytest.mark.parametrize(
146169
"other",
@@ -214,17 +237,12 @@ def test_compare_list_like_object(self, op, interval_array, other):
214237
expected = self.elementwise_comparison(op, interval_array, other)
215238
tm.assert_numpy_array_equal(result, expected)
216239

217-
def test_compare_list_like_nan(self, op, interval_array, nulls_fixture, request):
240+
def test_compare_list_like_nan(self, op, interval_array, nulls_fixture):
218241
other = [nulls_fixture] * 4
219242
result = op(interval_array, other)
220243
expected = self.elementwise_comparison(op, interval_array, other)
221244

222-
if nulls_fixture is pd.NA and interval_array.dtype.subtype != "i8":
223-
reason = "broken for non-integer IntervalArray; see GH 31882"
224-
mark = pytest.mark.xfail(raises=AssertionError, reason=reason)
225-
request.node.add_marker(mark)
226-
227-
tm.assert_numpy_array_equal(result, expected)
245+
tm.assert_equal(result, expected)
228246

229247
@pytest.mark.parametrize(
230248
"other",

0 commit comments

Comments
 (0)