Skip to content

Commit cc3099a

Browse files
authored
TST: fixturize in test_coercion (#39471)
1 parent 097ff0c commit cc3099a

File tree

1 file changed

+65
-71
lines changed

1 file changed

+65
-71
lines changed

pandas/tests/indexing/test_coercion.py

Lines changed: 65 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,6 @@ class CoercionBase:
6565
def method(self):
6666
raise NotImplementedError(self)
6767

68-
def _assert(self, left, right, dtype):
69-
# explicitly check dtype to avoid any unexpected result
70-
if isinstance(left, pd.Series):
71-
tm.assert_series_equal(left, right)
72-
elif isinstance(left, pd.Index):
73-
tm.assert_index_equal(left, right)
74-
else:
75-
raise NotImplementedError
76-
assert left.dtype == dtype
77-
assert right.dtype == dtype
78-
7968

8069
class TestSetitemCoercion(CoercionBase):
8170

@@ -91,6 +80,7 @@ def _assert_setitem_series_conversion(
9180
# check dtype explicitly for sure
9281
assert temp.dtype == expected_dtype
9382

83+
# FIXME: dont leave commented-out
9484
# .loc works different rule, temporary disable
9585
# temp = original_series.copy()
9686
# temp.loc[1] = loc_value
@@ -565,7 +555,8 @@ def _assert_where_conversion(
565555
""" test coercion triggered by where """
566556
target = original.copy()
567557
res = target.where(cond, values)
568-
self._assert(res, expected, expected_dtype)
558+
tm.assert_equal(res, expected)
559+
assert res.dtype == expected_dtype
569560

570561
@pytest.mark.parametrize(
571562
"fill_val,exp_dtype",
@@ -588,7 +579,7 @@ def test_where_object(self, index_or_series, fill_val, exp_dtype):
588579
if fill_val is True:
589580
values = klass([True, False, True, True])
590581
else:
591-
values = klass(fill_val * x for x in [5, 6, 7, 8])
582+
values = klass(x * fill_val for x in [5, 6, 7, 8])
592583

593584
exp = klass(["a", values[1], "c", values[3]])
594585
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
@@ -647,38 +638,40 @@ def test_where_float64(self, index_or_series, fill_val, exp_dtype):
647638
],
648639
)
649640
def test_where_series_complex128(self, fill_val, exp_dtype):
650-
obj = pd.Series([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
641+
klass = pd.Series
642+
obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
651643
assert obj.dtype == np.complex128
652-
cond = pd.Series([True, False, True, False])
644+
cond = klass([True, False, True, False])
653645

654-
exp = pd.Series([1 + 1j, fill_val, 3 + 3j, fill_val])
646+
exp = klass([1 + 1j, fill_val, 3 + 3j, fill_val])
655647
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
656648

657649
if fill_val is True:
658-
values = pd.Series([True, False, True, True])
650+
values = klass([True, False, True, True])
659651
else:
660-
values = pd.Series(x * fill_val for x in [5, 6, 7, 8])
661-
exp = pd.Series([1 + 1j, values[1], 3 + 3j, values[3]])
652+
values = klass(x * fill_val for x in [5, 6, 7, 8])
653+
exp = klass([1 + 1j, values[1], 3 + 3j, values[3]])
662654
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
663655

664656
@pytest.mark.parametrize(
665657
"fill_val,exp_dtype",
666658
[(1, object), (1.1, object), (1 + 1j, object), (True, np.bool_)],
667659
)
668660
def test_where_series_bool(self, fill_val, exp_dtype):
661+
klass = pd.Series
669662

670-
obj = pd.Series([True, False, True, False])
663+
obj = klass([True, False, True, False])
671664
assert obj.dtype == np.bool_
672-
cond = pd.Series([True, False, True, False])
665+
cond = klass([True, False, True, False])
673666

674-
exp = pd.Series([True, fill_val, True, fill_val])
667+
exp = klass([True, fill_val, True, fill_val])
675668
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)
676669

677670
if fill_val is True:
678-
values = pd.Series([True, False, True, True])
671+
values = klass([True, False, True, True])
679672
else:
680-
values = pd.Series(x * fill_val for x in [5, 6, 7, 8])
681-
exp = pd.Series([True, values[1], True, values[3]])
673+
values = klass(x * fill_val for x in [5, 6, 7, 8])
674+
exp = klass([True, values[1], True, values[3]])
682675
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
683676

684677
@pytest.mark.parametrize(
@@ -871,7 +864,8 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
871864
""" test coercion triggered by fillna """
872865
target = original.copy()
873866
res = target.fillna(value)
874-
self._assert(res, expected, expected_dtype)
867+
tm.assert_equal(res, expected)
868+
assert res.dtype == expected_dtype
875869

876870
@pytest.mark.parametrize(
877871
"fill_val, fill_dtype",
@@ -1040,10 +1034,12 @@ class TestReplaceSeriesCoercion(CoercionBase):
10401034

10411035
rep["timedelta64[ns]"] = [pd.Timedelta("1 day"), pd.Timedelta("2 day")]
10421036

1043-
@pytest.mark.parametrize("how", ["dict", "series"])
1044-
@pytest.mark.parametrize(
1045-
"to_key",
1046-
[
1037+
@pytest.fixture(params=["dict", "series"])
1038+
def how(self, request):
1039+
return request.param
1040+
1041+
@pytest.fixture(
1042+
params=[
10471043
"object",
10481044
"int64",
10491045
"float64",
@@ -1053,34 +1049,52 @@ class TestReplaceSeriesCoercion(CoercionBase):
10531049
"datetime64[ns, UTC]",
10541050
"datetime64[ns, US/Eastern]",
10551051
"timedelta64[ns]",
1056-
],
1057-
ids=[
1052+
]
1053+
)
1054+
def from_key(self, request):
1055+
return request.param
1056+
1057+
@pytest.fixture(
1058+
params=[
10581059
"object",
10591060
"int64",
10601061
"float64",
10611062
"complex128",
10621063
"bool",
1063-
"datetime64",
1064-
"datetime64tz",
1065-
"datetime64tz",
1066-
"timedelta64",
1064+
"datetime64[ns]",
1065+
"datetime64[ns, UTC]",
1066+
"datetime64[ns, US/Eastern]",
1067+
"timedelta64[ns]",
10671068
],
1068-
)
1069-
@pytest.mark.parametrize(
1070-
"from_key",
1071-
[
1069+
ids=[
10721070
"object",
10731071
"int64",
10741072
"float64",
10751073
"complex128",
10761074
"bool",
1077-
"datetime64[ns]",
1078-
"datetime64[ns, UTC]",
1079-
"datetime64[ns, US/Eastern]",
1080-
"timedelta64[ns]",
1075+
"datetime64",
1076+
"datetime64tz",
1077+
"datetime64tz",
1078+
"timedelta64",
10811079
],
10821080
)
1083-
def test_replace_series(self, how, to_key, from_key):
1081+
def to_key(self, request):
1082+
return request.param
1083+
1084+
@pytest.fixture
1085+
def replacer(self, how, from_key, to_key):
1086+
"""
1087+
Object we will pass to `Series.replace`
1088+
"""
1089+
if how == "dict":
1090+
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
1091+
elif how == "series":
1092+
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
1093+
else:
1094+
raise ValueError
1095+
return replacer
1096+
1097+
def test_replace_series(self, how, to_key, from_key, replacer):
10841098
index = pd.Index([3, 4], name="xxx")
10851099
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
10861100
assert obj.dtype == from_key
@@ -1092,13 +1106,6 @@ def test_replace_series(self, how, to_key, from_key):
10921106
# tested below
10931107
return
10941108

1095-
if how == "dict":
1096-
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
1097-
elif how == "series":
1098-
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
1099-
else:
1100-
raise ValueError
1101-
11021109
result = obj.replace(replacer)
11031110

11041111
if (from_key == "float64" and to_key in ("int64")) or (
@@ -1117,53 +1124,40 @@ def test_replace_series(self, how, to_key, from_key):
11171124

11181125
tm.assert_series_equal(result, exp)
11191126

1120-
@pytest.mark.parametrize("how", ["dict", "series"])
11211127
@pytest.mark.parametrize(
11221128
"to_key",
11231129
["timedelta64[ns]", "bool", "object", "complex128", "float64", "int64"],
1130+
indirect=True,
11241131
)
11251132
@pytest.mark.parametrize(
1126-
"from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"]
1133+
"from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"], indirect=True
11271134
)
1128-
def test_replace_series_datetime_tz(self, how, to_key, from_key):
1135+
def test_replace_series_datetime_tz(self, how, to_key, from_key, replacer):
11291136
index = pd.Index([3, 4], name="xyz")
11301137
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
11311138
assert obj.dtype == from_key
11321139

1133-
if how == "dict":
1134-
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
1135-
elif how == "series":
1136-
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
1137-
else:
1138-
raise ValueError
1139-
11401140
result = obj.replace(replacer)
11411141
exp = pd.Series(self.rep[to_key], index=index, name="yyy")
11421142
assert exp.dtype == to_key
11431143

11441144
tm.assert_series_equal(result, exp)
11451145

1146-
@pytest.mark.parametrize("how", ["dict", "series"])
11471146
@pytest.mark.parametrize(
11481147
"to_key",
11491148
["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"],
1149+
indirect=True,
11501150
)
11511151
@pytest.mark.parametrize(
11521152
"from_key",
11531153
["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"],
1154+
indirect=True,
11541155
)
1155-
def test_replace_series_datetime_datetime(self, how, to_key, from_key):
1156+
def test_replace_series_datetime_datetime(self, how, to_key, from_key, replacer):
11561157
index = pd.Index([3, 4], name="xyz")
11571158
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
11581159
assert obj.dtype == from_key
11591160

1160-
if how == "dict":
1161-
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
1162-
elif how == "series":
1163-
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
1164-
else:
1165-
raise ValueError
1166-
11671161
result = obj.replace(replacer)
11681162
exp = pd.Series(self.rep[to_key], index=index, name="yyy")
11691163
assert exp.dtype == to_key

0 commit comments

Comments
 (0)