|
20 | 20 | timedelta_range,
|
21 | 21 | )
|
22 | 22 | import pandas._testing as tm
|
23 |
| -from pandas.core.arrays import IntervalArray |
| 23 | +from pandas.core.arrays import ( |
| 24 | + BooleanArray, |
| 25 | + IntervalArray, |
| 26 | +) |
| 27 | +from pandas.tests.arithmetic.common import get_upcast_box |
24 | 28 |
|
25 | 29 |
|
26 | 30 | @pytest.fixture(
|
@@ -129,18 +133,37 @@ def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed):
|
129 | 133 | expected = self.elementwise_comparison(op, interval_array, other)
|
130 | 134 | tm.assert_numpy_array_equal(result, expected)
|
131 | 135 |
|
132 |
| - def test_compare_scalar_na(self, op, interval_array, nulls_fixture, request): |
133 |
| - result = op(interval_array, nulls_fixture) |
134 |
| - expected = self.elementwise_comparison(op, interval_array, nulls_fixture) |
| 136 | + def test_compare_scalar_na( |
| 137 | + self, op, interval_array, nulls_fixture, box_with_array, request |
| 138 | + ): |
| 139 | + box = box_with_array |
| 140 | + |
| 141 | + if box is pd.DataFrame: |
| 142 | + if interval_array.dtype.subtype.kind not in "iuf": |
| 143 | + mark = pytest.mark.xfail( |
| 144 | + reason="raises on DataFrame.transpose (would be fixed by EA2D)" |
| 145 | + ) |
| 146 | + request.node.add_marker(mark) |
| 147 | + |
| 148 | + obj = tm.box_expected(interval_array, box) |
| 149 | + result = op(obj, nulls_fixture) |
| 150 | + |
| 151 | + if nulls_fixture is pd.NA: |
| 152 | + # GH#31882 |
| 153 | + exp = np.ones(interval_array.shape, dtype=bool) |
| 154 | + expected = BooleanArray(exp, exp) |
| 155 | + else: |
| 156 | + expected = self.elementwise_comparison(op, interval_array, nulls_fixture) |
| 157 | + |
| 158 | + if not (box is Index and nulls_fixture is pd.NA): |
| 159 | + # don't cast expected from BooleanArray to ndarray[object] |
| 160 | + xbox = get_upcast_box(obj, nulls_fixture, True) |
| 161 | + expected = tm.box_expected(expected, xbox) |
135 | 162 |
|
136 |
| - if nulls_fixture is pd.NA and interval_array.dtype.subtype != "int64": |
137 |
| - mark = pytest.mark.xfail( |
138 |
| - raises=AssertionError, |
139 |
| - reason="broken for non-integer IntervalArray; see GH 31882", |
140 |
| - ) |
141 |
| - request.node.add_marker(mark) |
| 163 | + tm.assert_equal(result, expected) |
142 | 164 |
|
143 |
| - tm.assert_numpy_array_equal(result, expected) |
| 165 | + rev = op(nulls_fixture, obj) |
| 166 | + tm.assert_equal(rev, expected) |
144 | 167 |
|
145 | 168 | @pytest.mark.parametrize(
|
146 | 169 | "other",
|
@@ -214,17 +237,12 @@ def test_compare_list_like_object(self, op, interval_array, other):
|
214 | 237 | expected = self.elementwise_comparison(op, interval_array, other)
|
215 | 238 | tm.assert_numpy_array_equal(result, expected)
|
216 | 239 |
|
217 |
| - def test_compare_list_like_nan(self, op, interval_array, nulls_fixture, request): |
| 240 | + def test_compare_list_like_nan(self, op, interval_array, nulls_fixture): |
218 | 241 | other = [nulls_fixture] * 4
|
219 | 242 | result = op(interval_array, other)
|
220 | 243 | expected = self.elementwise_comparison(op, interval_array, other)
|
221 | 244 |
|
222 |
| - if nulls_fixture is pd.NA and interval_array.dtype.subtype != "i8": |
223 |
| - reason = "broken for non-integer IntervalArray; see GH 31882" |
224 |
| - mark = pytest.mark.xfail(raises=AssertionError, reason=reason) |
225 |
| - request.node.add_marker(mark) |
226 |
| - |
227 |
| - tm.assert_numpy_array_equal(result, expected) |
| 245 | + tm.assert_equal(result, expected) |
228 | 246 |
|
229 | 247 | @pytest.mark.parametrize(
|
230 | 248 | "other",
|
|
0 commit comments