Skip to content

Commit 6db33da

Browse files
authored
TST: port Dim2CompatTests (#39880)
1 parent 316f5ac commit 6db33da

File tree

6 files changed

+231
-1
lines changed

6 files changed

+231
-1
lines changed

pandas/core/ops/mask_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,6 @@ def kleene_and(
179179
return result, mask
180180

181181

182-
def raise_for_nan(value, method):
182+
def raise_for_nan(value, method: str):
183183
if lib.is_float(value) and np.isnan(value):
184184
raise ValueError(f"Cannot perform logical '{method}' with floating NaN")

pandas/tests/extension/base/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TestMyDtype(BaseDtypeTests):
4343
"""
4444
from pandas.tests.extension.base.casting import BaseCastingTests # noqa
4545
from pandas.tests.extension.base.constructors import BaseConstructorsTests # noqa
46+
from pandas.tests.extension.base.dim2 import Dim2CompatTests # noqa
4647
from pandas.tests.extension.base.dtype import BaseDtypeTests # noqa
4748
from pandas.tests.extension.base.getitem import BaseGetitemTests # noqa
4849
from pandas.tests.extension.base.groupby import BaseGroupbyTests # noqa

pandas/tests/extension/base/dim2.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""
2+
Tests for 2D compatibility.
3+
"""
4+
import numpy as np
5+
import pytest
6+
7+
from pandas.compat import np_version_under1p17
8+
9+
import pandas as pd
10+
from pandas.core.arrays import (
11+
FloatingArray,
12+
IntegerArray,
13+
)
14+
from pandas.tests.extension.base.base import BaseExtensionTests
15+
16+
17+
def maybe_xfail_masked_reductions(arr, request):
18+
if (
19+
isinstance(arr, (FloatingArray, IntegerArray))
20+
and np_version_under1p17
21+
and arr.ndim == 2
22+
):
23+
mark = pytest.mark.xfail(reason="masked_reductions does not implement")
24+
request.node.add_marker(mark)
25+
26+
27+
class Dim2CompatTests(BaseExtensionTests):
28+
def test_take_2d(self, data):
29+
arr2d = data.reshape(-1, 1)
30+
31+
result = arr2d.take([0, 0, -1], axis=0)
32+
33+
expected = data.take([0, 0, -1]).reshape(-1, 1)
34+
self.assert_extension_array_equal(result, expected)
35+
36+
def test_repr_2d(self, data):
37+
# this could fail in a corner case where an element contained the name
38+
res = repr(data.reshape(1, -1))
39+
assert res.count(f"<{type(data).__name__}") == 1
40+
41+
res = repr(data.reshape(-1, 1))
42+
assert res.count(f"<{type(data).__name__}") == 1
43+
44+
def test_reshape(self, data):
45+
arr2d = data.reshape(-1, 1)
46+
assert arr2d.shape == (data.size, 1)
47+
assert len(arr2d) == len(data)
48+
49+
arr2d = data.reshape((-1, 1))
50+
assert arr2d.shape == (data.size, 1)
51+
assert len(arr2d) == len(data)
52+
53+
with pytest.raises(ValueError):
54+
data.reshape((data.size, 2))
55+
with pytest.raises(ValueError):
56+
data.reshape(data.size, 2)
57+
58+
def test_getitem_2d(self, data):
59+
arr2d = data.reshape(1, -1)
60+
61+
result = arr2d[0]
62+
self.assert_extension_array_equal(result, data)
63+
64+
with pytest.raises(IndexError):
65+
arr2d[1]
66+
67+
with pytest.raises(IndexError):
68+
arr2d[-2]
69+
70+
result = arr2d[:]
71+
self.assert_extension_array_equal(result, arr2d)
72+
73+
result = arr2d[:, :]
74+
self.assert_extension_array_equal(result, arr2d)
75+
76+
result = arr2d[:, 0]
77+
expected = data[[0]]
78+
self.assert_extension_array_equal(result, expected)
79+
80+
# dimension-expanding getitem on 1D
81+
result = data[:, np.newaxis]
82+
self.assert_extension_array_equal(result, arr2d.T)
83+
84+
def test_iter_2d(self, data):
85+
arr2d = data.reshape(1, -1)
86+
87+
objs = list(iter(arr2d))
88+
assert len(objs) == arr2d.shape[0]
89+
90+
for obj in objs:
91+
assert isinstance(obj, type(data))
92+
assert obj.dtype == data.dtype
93+
assert obj.ndim == 1
94+
assert len(obj) == arr2d.shape[1]
95+
96+
def test_concat_2d(self, data):
97+
left = data.reshape(-1, 1)
98+
right = left.copy()
99+
100+
# axis=0
101+
result = left._concat_same_type([left, right], axis=0)
102+
expected = data._concat_same_type([data, data]).reshape(-1, 1)
103+
self.assert_extension_array_equal(result, expected)
104+
105+
# axis=1
106+
result = left._concat_same_type([left, right], axis=1)
107+
expected = data.repeat(2).reshape(-1, 2)
108+
self.assert_extension_array_equal(result, expected)
109+
110+
# axis > 1 -> invalid
111+
with pytest.raises(ValueError):
112+
left._concat_same_type([left, right], axis=2)
113+
114+
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
115+
def test_reductions_2d_axis_none(self, data, method, request):
116+
if not hasattr(data, method):
117+
pytest.skip("test is not applicable for this type/dtype")
118+
119+
arr2d = data.reshape(1, -1)
120+
maybe_xfail_masked_reductions(arr2d, request)
121+
122+
err_expected = None
123+
err_result = None
124+
try:
125+
expected = getattr(data, method)()
126+
except Exception as err:
127+
# if the 1D reduction is invalid, the 2D reduction should be as well
128+
err_expected = err
129+
try:
130+
result = getattr(arr2d, method)(axis=None)
131+
except Exception as err2:
132+
err_result = err2
133+
134+
else:
135+
result = getattr(arr2d, method)(axis=None)
136+
137+
if err_result is not None or err_expected is not None:
138+
assert type(err_result) == type(err_expected)
139+
return
140+
141+
assert result == expected # TODO: or matching NA
142+
143+
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
144+
def test_reductions_2d_axis0(self, data, method, request):
145+
if not hasattr(data, method):
146+
pytest.skip("test is not applicable for this type/dtype")
147+
148+
arr2d = data.reshape(1, -1)
149+
maybe_xfail_masked_reductions(arr2d, request)
150+
151+
kwargs = {}
152+
if method == "std":
153+
# pass ddof=0 so we get all-zero std instead of all-NA std
154+
kwargs["ddof"] = 0
155+
156+
try:
157+
result = getattr(arr2d, method)(axis=0, **kwargs)
158+
except Exception as err:
159+
try:
160+
getattr(data, method)()
161+
except Exception as err2:
162+
assert type(err) == type(err2)
163+
return
164+
else:
165+
raise AssertionError("Both reductions should raise or neither")
166+
167+
if method in ["mean", "median", "sum", "prod"]:
168+
# std and var are not dtype-preserving
169+
expected = data
170+
if method in ["sum", "prod"] and data.dtype.kind in ["i", "u"]:
171+
# FIXME: kludge
172+
if data.dtype.kind == "i":
173+
dtype = pd.Int64Dtype
174+
else:
175+
dtype = pd.UInt64Dtype
176+
177+
expected = data.astype(dtype)
178+
if type(expected) != type(data):
179+
mark = pytest.mark.xfail(
180+
reason="IntegerArray.astype is broken GH#38983"
181+
)
182+
request.node.add_marker(mark)
183+
assert type(expected) == type(data), type(expected)
184+
assert dtype == expected.dtype
185+
186+
self.assert_extension_array_equal(result, expected)
187+
elif method == "std":
188+
self.assert_extension_array_equal(result, data - data)
189+
# punt on method == "var"
190+
191+
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
192+
def test_reductions_2d_axis1(self, data, method, request):
193+
if not hasattr(data, method):
194+
pytest.skip("test is not applicable for this type/dtype")
195+
196+
arr2d = data.reshape(1, -1)
197+
maybe_xfail_masked_reductions(arr2d, request)
198+
199+
try:
200+
result = getattr(arr2d, method)(axis=1)
201+
except Exception as err:
202+
try:
203+
getattr(data, method)()
204+
except Exception as err2:
205+
assert type(err) == type(err2)
206+
return
207+
else:
208+
raise AssertionError("Both reductions should raise or neither")
209+
210+
# not necesarrily type/dtype-preserving, so weaker assertions
211+
assert result.shape == (1,)
212+
expected_scalar = getattr(data, method)()
213+
if pd.isna(result[0]):
214+
# TODO: require matching NA
215+
assert pd.isna(expected_scalar), expected_scalar
216+
else:
217+
assert result[0] == expected_scalar

pandas/tests/extension/test_datetime.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,7 @@ class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests):
235235

236236
class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
237237
pass
238+
239+
240+
class Test2DCompat(BaseDatetimeTests, base.Dim2CompatTests):
241+
pass

pandas/tests/extension/test_numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,7 @@ def test_setitem_loc_iloc_slice(self, data):
415415
@skip_nested
416416
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
417417
pass
418+
419+
420+
class Test2DCompat(BaseNumPyTests, base.Dim2CompatTests):
421+
pass

pandas/tests/extension/test_period.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,7 @@ class TestParsing(BasePeriodTests, base.BaseParsingTests):
184184
@pytest.mark.parametrize("engine", ["c", "python"])
185185
def test_EA_types(self, engine, data):
186186
super().test_EA_types(engine, data)
187+
188+
189+
class Test2DCompat(BasePeriodTests, base.Dim2CompatTests):
190+
pass

0 commit comments

Comments
 (0)