diff --git a/pandas/tests/indexing/test_coercion.py b/pandas/tests/indexing/test_coercion.py index b342942cd9f5f..6c8b1622e76aa 100644 --- a/pandas/tests/indexing/test_coercion.py +++ b/pandas/tests/indexing/test_coercion.py @@ -65,17 +65,6 @@ class CoercionBase: def method(self): raise NotImplementedError(self) - def _assert(self, left, right, dtype): - # explicitly check dtype to avoid any unexpected result - if isinstance(left, pd.Series): - tm.assert_series_equal(left, right) - elif isinstance(left, pd.Index): - tm.assert_index_equal(left, right) - else: - raise NotImplementedError - assert left.dtype == dtype - assert right.dtype == dtype - class TestSetitemCoercion(CoercionBase): @@ -91,6 +80,7 @@ def _assert_setitem_series_conversion( # check dtype explicitly for sure assert temp.dtype == expected_dtype + # FIXME: dont leave commented-out # .loc works different rule, temporary disable # temp = original_series.copy() # temp.loc[1] = loc_value @@ -565,7 +555,8 @@ def _assert_where_conversion( """ test coercion triggered by where """ target = original.copy() res = target.where(cond, values) - self._assert(res, expected, expected_dtype) + tm.assert_equal(res, expected) + assert res.dtype == expected_dtype @pytest.mark.parametrize( "fill_val,exp_dtype", @@ -588,7 +579,7 @@ def test_where_object(self, index_or_series, fill_val, exp_dtype): if fill_val is True: values = klass([True, False, True, True]) else: - values = klass(fill_val * x for x in [5, 6, 7, 8]) + values = klass(x * fill_val for x in [5, 6, 7, 8]) exp = klass(["a", values[1], "c", values[3]]) self._assert_where_conversion(obj, cond, values, exp, exp_dtype) @@ -647,18 +638,19 @@ def test_where_float64(self, index_or_series, fill_val, exp_dtype): ], ) def test_where_series_complex128(self, fill_val, exp_dtype): - obj = pd.Series([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j]) + klass = pd.Series + obj = klass([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j]) assert obj.dtype == np.complex128 - cond = pd.Series([True, False, True, False]) + cond = klass([True, False, True, False]) - exp = pd.Series([1 + 1j, fill_val, 3 + 3j, fill_val]) + exp = klass([1 + 1j, fill_val, 3 + 3j, fill_val]) self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype) if fill_val is True: - values = pd.Series([True, False, True, True]) + values = klass([True, False, True, True]) else: - values = pd.Series(x * fill_val for x in [5, 6, 7, 8]) - exp = pd.Series([1 + 1j, values[1], 3 + 3j, values[3]]) + values = klass(x * fill_val for x in [5, 6, 7, 8]) + exp = klass([1 + 1j, values[1], 3 + 3j, values[3]]) self._assert_where_conversion(obj, cond, values, exp, exp_dtype) @pytest.mark.parametrize( @@ -666,19 +658,20 @@ def test_where_series_complex128(self, fill_val, exp_dtype): [(1, object), (1.1, object), (1 + 1j, object), (True, np.bool_)], ) def test_where_series_bool(self, fill_val, exp_dtype): + klass = pd.Series - obj = pd.Series([True, False, True, False]) + obj = klass([True, False, True, False]) assert obj.dtype == np.bool_ - cond = pd.Series([True, False, True, False]) + cond = klass([True, False, True, False]) - exp = pd.Series([True, fill_val, True, fill_val]) + exp = klass([True, fill_val, True, fill_val]) self._assert_where_conversion(obj, cond, fill_val, exp, exp_dtype) if fill_val is True: - values = pd.Series([True, False, True, True]) + values = klass([True, False, True, True]) else: - values = pd.Series(x * fill_val for x in [5, 6, 7, 8]) - exp = pd.Series([True, values[1], True, values[3]]) + values = klass(x * fill_val for x in [5, 6, 7, 8]) + exp = klass([True, values[1], True, values[3]]) self._assert_where_conversion(obj, cond, values, exp, exp_dtype) @pytest.mark.parametrize( @@ -871,7 +864,8 @@ def _assert_fillna_conversion(self, original, value, expected, expected_dtype): """ test coercion triggered by fillna """ target = original.copy() res = target.fillna(value) - self._assert(res, expected, expected_dtype) + tm.assert_equal(res, expected) + assert res.dtype == expected_dtype @pytest.mark.parametrize( "fill_val, fill_dtype", @@ -1040,10 +1034,12 @@ class TestReplaceSeriesCoercion(CoercionBase): rep["timedelta64[ns]"] = [pd.Timedelta("1 day"), pd.Timedelta("2 day")] - @pytest.mark.parametrize("how", ["dict", "series"]) - @pytest.mark.parametrize( - "to_key", - [ + @pytest.fixture(params=["dict", "series"]) + def how(self, request): + return request.param + + @pytest.fixture( + params=[ "object", "int64", "float64", @@ -1053,34 +1049,52 @@ class TestReplaceSeriesCoercion(CoercionBase): "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]", "timedelta64[ns]", - ], - ids=[ + ] + ) + def from_key(self, request): + return request.param + + @pytest.fixture( + params=[ "object", "int64", "float64", "complex128", "bool", - "datetime64", - "datetime64tz", - "datetime64tz", - "timedelta64", + "datetime64[ns]", + "datetime64[ns, UTC]", + "datetime64[ns, US/Eastern]", + "timedelta64[ns]", ], - ) - @pytest.mark.parametrize( - "from_key", - [ + ids=[ "object", "int64", "float64", "complex128", "bool", - "datetime64[ns]", - "datetime64[ns, UTC]", - "datetime64[ns, US/Eastern]", - "timedelta64[ns]", + "datetime64", + "datetime64tz", + "datetime64tz", + "timedelta64", ], ) - def test_replace_series(self, how, to_key, from_key): + def to_key(self, request): + return request.param + + @pytest.fixture + def replacer(self, how, from_key, to_key): + """ + Object we will pass to `Series.replace` + """ + if how == "dict": + replacer = dict(zip(self.rep[from_key], self.rep[to_key])) + elif how == "series": + replacer = pd.Series(self.rep[to_key], index=self.rep[from_key]) + else: + raise ValueError + return replacer + + def test_replace_series(self, how, to_key, from_key, replacer): index = pd.Index([3, 4], name="xxx") obj = pd.Series(self.rep[from_key], index=index, name="yyy") assert obj.dtype == from_key @@ -1092,13 +1106,6 @@ def test_replace_series(self, how, to_key, from_key): # tested below return - if how == "dict": - replacer = dict(zip(self.rep[from_key], self.rep[to_key])) - elif how == "series": - replacer = pd.Series(self.rep[to_key], index=self.rep[from_key]) - else: - raise ValueError - result = obj.replace(replacer) if (from_key == "float64" and to_key in ("int64")) or ( @@ -1117,53 +1124,40 @@ def test_replace_series(self, how, to_key, from_key): tm.assert_series_equal(result, exp) - @pytest.mark.parametrize("how", ["dict", "series"]) @pytest.mark.parametrize( "to_key", ["timedelta64[ns]", "bool", "object", "complex128", "float64", "int64"], + indirect=True, ) @pytest.mark.parametrize( - "from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"] + "from_key", ["datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"], indirect=True ) - def test_replace_series_datetime_tz(self, how, to_key, from_key): + def test_replace_series_datetime_tz(self, how, to_key, from_key, replacer): index = pd.Index([3, 4], name="xyz") obj = pd.Series(self.rep[from_key], index=index, name="yyy") assert obj.dtype == from_key - if how == "dict": - replacer = dict(zip(self.rep[from_key], self.rep[to_key])) - elif how == "series": - replacer = pd.Series(self.rep[to_key], index=self.rep[from_key]) - else: - raise ValueError - result = obj.replace(replacer) exp = pd.Series(self.rep[to_key], index=index, name="yyy") assert exp.dtype == to_key tm.assert_series_equal(result, exp) - @pytest.mark.parametrize("how", ["dict", "series"]) @pytest.mark.parametrize( "to_key", ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"], + indirect=True, ) @pytest.mark.parametrize( "from_key", ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, US/Eastern]"], + indirect=True, ) - def test_replace_series_datetime_datetime(self, how, to_key, from_key): + def test_replace_series_datetime_datetime(self, how, to_key, from_key, replacer): index = pd.Index([3, 4], name="xyz") obj = pd.Series(self.rep[from_key], index=index, name="yyy") assert obj.dtype == from_key - if how == "dict": - replacer = dict(zip(self.rep[from_key], self.rep[to_key])) - elif how == "series": - replacer = pd.Series(self.rep[to_key], index=self.rep[from_key]) - else: - raise ValueError - result = obj.replace(replacer) exp = pd.Series(self.rep[to_key], index=index, name="yyy") assert exp.dtype == to_key