Skip to content

TST: fixturize in test_coercion #39471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 29, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 65 additions & 71 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,6 @@ class CoercionBase:
def method(self):
raise NotImplementedError(self)

def _assert(self, left, right, dtype):
# explicitly check dtype to avoid any unexpected result
if isinstance(left, pd.Series):
tm.assert_series_equal(left, right)
elif isinstance(left, pd.Index):
tm.assert_index_equal(left, right)
else:
raise NotImplementedError
assert left.dtype == dtype
assert right.dtype == dtype


class TestSetitemCoercion(CoercionBase):

Expand All @@ -91,6 +80,7 @@ def _assert_setitem_series_conversion(
# check dtype explicitly for sure
assert temp.dtype == expected_dtype

# FIXME: dont leave commented-out
# .loc works different rule, temporary disable
# temp = original_series.copy()
# temp.loc[1] = loc_value
Expand Down Expand Up @@ -565,7 +555,8 @@ def _assert_where_conversion(
""" test coercion triggered by where """
target = original.copy()
res = target.where(cond, values)
self._assert(res, expected, expected_dtype)
tm.assert_equal(res, expected)
assert res.dtype == expected_dtype

@pytest.mark.parametrize(
"fill_val,exp_dtype",
Expand All @@ -588,7 +579,7 @@ def test_where_object(self, index_or_series, fill_val, exp_dtype):
if fill_val is True:
values = klass([True, False, True, True])
else:
values = klass(fill_val * x for x in [5, 6, 7, 8])
values = klass(x * fill_val for x in [5, 6, 7, 8])

exp = klass(["a", values[1], "c", values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)
Expand Down Expand Up @@ -647,38 +638,40 @@ def test_where_float64(self, index_or_series, fill_val, exp_dtype):
],
)
def test_where_series_complex128(self, fill_val, exp_dtype):
obj = pd.Series([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
klass = pd.Series
obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
assert obj.dtype == np.complex128
cond = pd.Series([True, False, True, False])
cond = klass([True, False, True, False])

exp = pd.Series([1 + 1j, fill_val, 3 + 3j, fill_val])
exp = klass([1 + 1j, fill_val, 3 + 3j, fill_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = pd.Series([True, False, True, True])
values = klass([True, False, True, True])
else:
values = pd.Series(x * fill_val for x in [5, 6, 7, 8])
exp = pd.Series([1 + 1j, values[1], 3 + 3j, values[3]])
values = klass(x * fill_val for x in [5, 6, 7, 8])
exp = klass([1 + 1j, values[1], 3 + 3j, values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize(
"fill_val,exp_dtype",
[(1, object), (1.1, object), (1 + 1j, object), (True, np.bool_)],
)
def test_where_series_bool(self, fill_val, exp_dtype):
klass = pd.Series

obj = pd.Series([True, False, True, False])
obj = klass([True, False, True, False])
assert obj.dtype == np.bool_
cond = pd.Series([True, False, True, False])
cond = klass([True, False, True, False])

exp = pd.Series([True, fill_val, True, fill_val])
exp = klass([True, fill_val, True, fill_val])
self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype)

if fill_val is True:
values = pd.Series([True, False, True, True])
values = klass([True, False, True, True])
else:
values = pd.Series(x * fill_val for x in [5, 6, 7, 8])
exp = pd.Series([True, values[1], True, values[3]])
values = klass(x * fill_val for x in [5, 6, 7, 8])
exp = klass([True, values[1], True, values[3]])
self._assert_where_conversion(obj, cond, values, exp, exp_dtype)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -871,7 +864,8 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype):
""" test coercion triggered by fillna """
target = original.copy()
res = target.fillna(value)
self._assert(res, expected, expected_dtype)
tm.assert_equal(res, expected)
assert res.dtype == expected_dtype

@pytest.mark.parametrize(
"fill_val, fill_dtype",
Expand Down Expand Up @@ -1040,10 +1034,12 @@ class TestReplaceSeriesCoercion(CoercionBase):

rep["timedelta64[ns]"] = [pd.Timedelta("1 day"), pd.Timedelta("2 day")]

@pytest.mark.parametrize("how", ["dict", "series"])
@pytest.mark.parametrize(
"to_key",
[
@pytest.fixture(params=["dict", "series"])
def how(self, request):
return request.param

@pytest.fixture(
params=[
"object",
"int64",
"float64",
Expand All @@ -1053,34 +1049,52 @@ class TestReplaceSeriesCoercion(CoercionBase):
"datetime64[ns, UTC]",
"datetime64[ns, US/Eastern]",
"timedelta64[ns]",
],
ids=[
]
)
def from_key(self, request):
return request.param

@pytest.fixture(
params=[
"object",
"int64",
"float64",
"complex128",
"bool",
"datetime64",
"datetime64tz",
"datetime64tz",
"timedelta64",
"datetime64[ns]",
"datetime64[ns, UTC]",
"datetime64[ns, US/Eastern]",
"timedelta64[ns]",
],
)
@pytest.mark.parametrize(
"from_key",
[
ids=[
"object",
"int64",
"float64",
"complex128",
"bool",
"datetime64[ns]",
"datetime64[ns, UTC]",
"datetime64[ns, US/Eastern]",
"timedelta64[ns]",
"datetime64",
"datetime64tz",
"datetime64tz",
"timedelta64",
],
)
def test_replace_series(self, how, to_key, from_key):
def to_key(self, request):
return request.param

@pytest.fixture
def replacer(self, how, from_key, to_key):
"""
Object we will pass to `Series.replace`
"""
if how == "dict":
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == "series":
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
else:
raise ValueError
return replacer

def test_replace_series(self, how, to_key, from_key, replacer):
index = pd.Index([3, 4], name="xxx")
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
assert obj.dtype == from_key
Expand All @@ -1092,13 +1106,6 @@ def test_replace_series(self, how, to_key, from_key):
# tested below
return

if how == "dict":
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == "series":
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
else:
raise ValueError

result = obj.replace(replacer)

if (from_key == "float64" and to_key in ("int64")) or (
Expand All @@ -1117,53 +1124,40 @@ def test_replace_series(self, how, to_key, from_key):

tm.assert_series_equal(result, exp)

@pytest.mark.parametrize("how", ["dict", "series"])
@pytest.mark.parametrize(
"to_key",
["timedelta64[ns]", "bool", "object", "complex128", "float64", "int64"],
indirect=True,
)
@pytest.mark.parametrize(
"from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"]
"from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"], indirect=True
)
def test_replace_series_datetime_tz(self, how, to_key, from_key):
def test_replace_series_datetime_tz(self, how, to_key, from_key, replacer):
index = pd.Index([3, 4], name="xyz")
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
assert obj.dtype == from_key

if how == "dict":
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == "series":
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
else:
raise ValueError

result = obj.replace(replacer)
exp = pd.Series(self.rep[to_key], index=index, name="yyy")
assert exp.dtype == to_key

tm.assert_series_equal(result, exp)

@pytest.mark.parametrize("how", ["dict", "series"])
@pytest.mark.parametrize(
"to_key",
["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"],
indirect=True,
)
@pytest.mark.parametrize(
"from_key",
["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"],
indirect=True,
)
def test_replace_series_datetime_datetime(self, how, to_key, from_key):
def test_replace_series_datetime_datetime(self, how, to_key, from_key, replacer):
index = pd.Index([3, 4], name="xyz")
obj = pd.Series(self.rep[from_key], index=index, name="yyy")
assert obj.dtype == from_key

if how == "dict":
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == "series":
replacer = pd.Series(self.rep[to_key], index=self.rep[from_key])
else:
raise ValueError

result = obj.replace(replacer)
exp = pd.Series(self.rep[to_key], index=index, name="yyy")
assert exp.dtype == to_key
Expand Down