diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index ea9017da8a2f9..4655968eb07b5 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -233,6 +233,7 @@ Other enhancements - Add keyword ``sort`` to :func:`pivot_table` to allow non-sorting of the result (:issue:`39143`) - Add keyword ``dropna`` to :meth:`DataFrame.value_counts` to allow counting rows that include ``NA`` values (:issue:`41325`) - :meth:`Series.replace` will now cast results to ``PeriodDtype`` where possible instead of ``object`` dtype (:issue:`41526`) +- Improved error message in ``corr`` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/window/common.py b/pandas/core/window/common.py index d85aa20de5ab4..e0720c5d86df1 100644 --- a/pandas/core/window/common.py +++ b/pandas/core/window/common.py @@ -1,7 +1,6 @@ """Common utility functions for rolling operations""" from collections import defaultdict from typing import cast -import warnings import numpy as np @@ -15,17 +14,7 @@ def flex_binary_moment(arg1, arg2, f, pairwise=False): - if not ( - isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame)) - and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame)) - ): - raise TypeError( - "arguments to moment function must be of type np.ndarray/Series/DataFrame" - ) - - if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance( - arg2, (np.ndarray, ABCSeries) - ): + if isinstance(arg1, ABCSeries) and isinstance(arg2, ABCSeries): X, Y = prep_binary(arg1, arg2) return f(X, Y) @@ -43,7 +32,7 @@ def dataframe_from_int_dict(data, frame_template): if pairwise is False: if arg1 is arg2: # special case in order to handle duplicate column names - for i, col in enumerate(arg1.columns): + for i in range(len(arg1.columns)): 
results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i]) return dataframe_from_int_dict(results, arg1) else: @@ -51,23 +40,17 @@ def dataframe_from_int_dict(data, frame_template): raise ValueError("'arg1' columns are not unique") if not arg2.columns.is_unique: raise ValueError("'arg2' columns are not unique") - with warnings.catch_warnings(record=True): - warnings.simplefilter("ignore", RuntimeWarning) - X, Y = arg1.align(arg2, join="outer") - X = X + 0 * Y - Y = Y + 0 * X - - with warnings.catch_warnings(record=True): - warnings.simplefilter("ignore", RuntimeWarning) - res_columns = arg1.columns.union(arg2.columns) + X, Y = arg1.align(arg2, join="outer") + X, Y = prep_binary(X, Y) + res_columns = arg1.columns.union(arg2.columns) for col in res_columns: if col in X and col in Y: results[col] = f(X[col], Y[col]) return DataFrame(results, index=X.index, columns=res_columns) elif pairwise is True: results = defaultdict(dict) - for i, k1 in enumerate(arg1.columns): - for j, k2 in enumerate(arg2.columns): + for i in range(len(arg1.columns)): + for j in range(len(arg2.columns)): if j < i and arg2 is arg1: # Symmetric case results[i][j] = results[j][i] @@ -85,10 +68,10 @@ def dataframe_from_int_dict(data, frame_template): result = concat( [ concat( - [results[i][j] for j, c in enumerate(arg2.columns)], + [results[i][j] for j in range(len(arg2.columns))], ignore_index=True, ) - for i, c in enumerate(arg1.columns) + for i in range(len(arg1.columns)) ], ignore_index=True, axis=1, @@ -135,13 +118,10 @@ def dataframe_from_int_dict(data, frame_template): ) return result - - else: - raise ValueError("'pairwise' is not True/False") else: results = { i: f(*prep_binary(arg1.iloc[:, i], arg2)) - for i, col in enumerate(arg1.columns) + for i in range(len(arg1.columns)) } return dataframe_from_int_dict(results, arg1) @@ -165,11 +145,7 @@ def zsqrt(x): def prep_binary(arg1, arg2): - if not isinstance(arg2, type(arg1)): - raise Exception("Input arrays must be of the same type!") - # mask out 
values, this also makes a common index... X = arg1 + 0 * arg2 Y = arg2 + 0 * arg1 - return X, Y diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index dfb74b38cd9cf..2d5f148a6437a 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -472,6 +472,8 @@ def _apply_pairwise( other = target # only default unset pairwise = True if pairwise is None else pairwise + elif not isinstance(other, (ABCDataFrame, ABCSeries)): + raise ValueError("other must be a DataFrame or Series") return flex_binary_moment(target, other, func, pairwise=bool(pairwise)) diff --git a/pandas/tests/window/moments/test_moments_consistency_ewm.py b/pandas/tests/window/moments/test_moments_consistency_ewm.py index a36091ab8934e..c79d02fd3237e 100644 --- a/pandas/tests/window/moments/test_moments_consistency_ewm.py +++ b/pandas/tests/window/moments/test_moments_consistency_ewm.py @@ -64,9 +64,9 @@ def test_different_input_array_raise_exception(name): A = Series(np.random.randn(50), index=np.arange(50)) A[:10] = np.NaN - msg = "Input arrays must be of the same type!" 
+ msg = "other must be a DataFrame or Series" # exception raised is Exception - with pytest.raises(Exception, match=msg): + with pytest.raises(ValueError, match=msg): getattr(A.ewm(com=20, min_periods=5), name)(np.random.randn(50)) diff --git a/pandas/tests/window/moments/test_moments_consistency_rolling.py b/pandas/tests/window/moments/test_moments_consistency_rolling.py index 28fd5633de02e..7ec5846ef4acf 100644 --- a/pandas/tests/window/moments/test_moments_consistency_rolling.py +++ b/pandas/tests/window/moments/test_moments_consistency_rolling.py @@ -13,7 +13,6 @@ Series, ) import pandas._testing as tm -from pandas.core.window.common import flex_binary_moment def _rolling_consistency_cases(): @@ -133,14 +132,6 @@ def test_rolling_corr_with_zero_variance(window): assert s.rolling(window=window).corr(other=other).isna().all() -def test_flex_binary_moment(): - # GH3155 - # don't blow the stack - msg = "arguments to moment function must be of type np.ndarray/Series/DataFrame" - with pytest.raises(TypeError, match=msg): - flex_binary_moment(5, 6, None) - - def test_corr_sanity(): # GH 3155 df = DataFrame(