Skip to content

Commit 67fd434

Browse files
authored
feat: add DataFrames.corr() method (#379)
* feat: add `DataFrames.corr()` method * support multi-indices * fix mypy * support non-numeric col * fix doc * fix system 3.9 * fix doctest
1 parent 234b61c commit 67fd434

File tree

6 files changed

+160
-4
lines changed

6 files changed

+160
-4
lines changed

bigframes/core/blocks.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ def __init__(
102102
):
103103
"""Construct a block object, will create default index if no index columns specified."""
104104
index_columns = list(index_columns)
105-
if index_labels:
105+
if index_labels is not None:
106106
index_labels = list(index_labels)
107107
if len(index_labels) != len(index_columns):
108108
raise ValueError(
109-
"'index_columns' and 'index_labels' must have equal length"
109+
f"'index_columns' (size {len(index_columns)}) and 'index_labels' (size {len(index_labels)}) must have equal length"
110110
)
111111
if len(index_columns) == 0:
112112
new_index_col_id = guid.generate_guid()
@@ -1089,6 +1089,46 @@ def summarize(
10891089
labels = self._get_labels_for_columns(column_ids)
10901090
return Block(expr, column_labels=labels, index_columns=[label_col_id])
10911091

1092+
def corr(self):
    """Build a block that computes the correlation between every pair of value columns.

    One ``CorrOp`` binary aggregation is created for each ordered
    (left, right) pair of this block's value columns. The aggregated
    expression is then unpivoted, using the value columns' labels as row
    labels, and wrapped in a new ``Block`` that reuses this block's column
    labels and takes its index labels from the names of those column labels.
    """
    # Pairs are emitted left-major: every pairing of the first value column,
    # then every pairing of the second, and so on.
    pair_aggregations = []
    for lhs_id in self.value_columns:
        for rhs_id in self.value_columns:
            pair_expr = ex.BinaryAggregation(
                agg_ops.CorrOp(), ex.free_var(lhs_id), ex.free_var(rhs_id)
            )
            pair_aggregations.append((pair_expr, f"{lhs_id}-{rhs_id}"))
    aggregated = self.expr.aggregate(pair_aggregations)

    # One freshly generated index column id per level of the column labels.
    index_col_ids = [guid.generate_guid() for _ in range(self.column_labels.nlevels)]

    # Each unpivot output column gets a new id plus one contiguous run of
    # `width` aggregated column ids.
    # NOTE(review): presumably relies on aggregate() exposing column_ids in
    # the same order the aggregations were supplied - confirm.
    width = len(self.value_columns)
    unpivot_specs = []
    for position in range(width):
        output_id = guid.generate_guid()
        start = width * position
        source_ids = tuple(aggregated.column_ids[start : start + width])
        unpivot_specs.append((output_id, source_ids))

    row_labels = self._get_labels_for_columns(self.value_columns)
    unpivoted = aggregated.unpivot(
        row_labels=row_labels,
        index_col_ids=index_col_ids,
        unpivot_columns=tuple(unpivot_specs),
    )

    return Block(
        unpivoted,
        column_labels=self.column_labels,
        index_columns=index_col_ids,
        index_labels=self.column_labels.names,
    )
1131+
10921132
def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.UnaryAggregateOp]:
10931133
"""
10941134
Gets a standard set of stats to preemptively fetch for a column if
@@ -1889,7 +1929,7 @@ def to_pandas(self) -> pd.Index:
18891929
df = expr.session._rows_to_dataframe(results, dtypes)
18901930
df = df.set_index(index_columns)
18911931
index = df.index
1892-
index.names = list(self._block._index_labels)
1932+
index.names = list(self._block._index_labels) # type:ignore
18931933
return index
18941934

18951935
def resolve_level(self, level: LevelsType) -> typing.Sequence[str]:

bigframes/dataframe.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,27 @@ def combine(
10171017
def combine_first(self, other: DataFrame):
10181018
return self._apply_dataframe_binop(other, ops.fillna_op)
10191019

1020+
def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFrame:
    # Compute the pairwise correlation of this frame's columns.
    #
    # method: only "pearson" is accepted; anything else raises
    #     NotImplementedError.
    # min_periods: any truthy value raises NotImplementedError.
    # numeric_only: when False the frame goes through
    #     `_raise_on_non_numeric("corr")`; when True it goes through
    #     `_drop_non_numeric()` instead.
    #
    # Returns a new DataFrame wrapping the block produced by `Block.corr()`.
    #
    # NOTE(review): documented with comments rather than a docstring in case
    # the user-facing docstring is supplied by the vendored pandas base
    # class - confirm.
    if method != "pearson":
        raise NotImplementedError(
            f"Only Pearson correlation is currently supported. {constants.FEEDBACK_LINK}"
        )
    if min_periods:
        raise NotImplementedError(
            f"min_periods not yet supported. {constants.FEEDBACK_LINK}"
        )

    # The limit is inclusive: a frame with exactly `max_columns` columns is
    # accepted, so the message says "at most" to agree with the check.
    # NOTE(review): presumably this caps the number of pairwise aggregations,
    # which grows with the square of the column count - confirm.
    max_columns = 30
    column_count = len(self.columns)
    if column_count > max_columns:
        raise NotImplementedError(
            f"Only dataframes with at most {max_columns} columns are supported. Current: {column_count}. {constants.FEEDBACK_LINK}"
        )

    if not numeric_only:
        frame = self._raise_on_non_numeric("corr")
    else:
        frame = self._drop_non_numeric()

    return DataFrame(frame._block.corr())
1040+
10201041
def to_pandas(
10211042
self,
10221043
max_download_size: Optional[int] = None,

tests/system/small/test_dataframe.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,6 +1783,46 @@ def test_combine_first(
17831783
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
17841784

17851785

1786+
@pytest.mark.parametrize(
    ("columns", "numeric_only"),
    [
        (["bool_col", "int64_col", "float64_col"], True),
        (["bool_col", "int64_col", "float64_col"], False),
        (["bool_col", "int64_col", "float64_col", "string_col"], True),
        pytest.param(
            ["bool_col", "int64_col", "float64_col", "string_col"],
            False,
            marks=pytest.mark.xfail(raises=NotImplementedError),
        ),
    ],
)
def test_corr_w_numeric_only(scalars_dfs, columns, numeric_only):
    """DataFrame.corr() matches pandas for each column set / numeric_only combination."""
    bf_df, pd_df = scalars_dfs

    # The BigFrames call must stay first: the xfail case expects it to raise
    # NotImplementedError before pandas is ever asked for a result.
    actual = bf_df[columns].corr(numeric_only=numeric_only).to_pandas()
    expected = pd_df[columns].corr(numeric_only=numeric_only)

    # BigFrames and pandas differ in their data type handling:
    # - Column types: BigFrames uses Float64, pandas uses float64.
    # - Index types: BigFrames uses string, pandas uses object.
    comparison_options = {"check_dtype": False, "check_index_type": False}
    pd.testing.assert_frame_equal(actual, expected, **comparison_options)
1813+
1814+
1815+
def test_corr_w_invalid_parameters(scalars_dfs):
    """DataFrame.corr() raises NotImplementedError for unsupported arguments."""
    columns = ["int64_too", "int64_col", "float64_col"]
    bf_df, _ = scalars_dfs

    # Each entry is a keyword-argument set that corr() does not support yet.
    unsupported_calls = (
        {"method": "kendall"},
        {"min_periods": 1},
    )
    for unsupported_kwargs in unsupported_calls:
        with pytest.raises(NotImplementedError):
            bf_df[columns].corr(**unsupported_kwargs)
1824+
1825+
17861826
@pytest.mark.parametrize(
17871827
("op"),
17881828
[

tests/system/small/test_multiindex.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,27 @@ def test_column_multi_index_w_na_stack(scalars_df_index, scalars_pandas_df_index
880880
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
881881

882882

883+
def test_corr_w_multi_index(scalars_df_index, scalars_pandas_df_index):
    """DataFrame.corr() matches pandas when the columns are a two-level MultiIndex."""
    columns = ["int64_too", "float64_col", "int64_col"]
    # Two-level column labels; the last two columns share the same label.
    multi_columns = pandas.MultiIndex.from_tuples([("a", 1), ("b", 2), ("b", 2)])

    bf_df = scalars_df_index[columns].copy()
    bf_df.columns = multi_columns

    pd_df = scalars_pandas_df_index[columns].copy()
    pd_df.columns = multi_columns

    actual = bf_df.corr(numeric_only=True).to_pandas()
    expected = pd_df.corr(numeric_only=True)

    # BigFrames and pandas differ in their data type handling:
    # - Column types: BigFrames uses Float64, pandas uses float64.
    # - Index types: BigFrames uses string, pandas uses object.
    comparison_options = {"check_dtype": False, "check_index_type": False}
    pandas.testing.assert_frame_equal(actual, expected, **comparison_options)
902+
903+
883904
@pytest.mark.parametrize(
884905
("index_names",),
885906
[

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,6 +2805,40 @@ def combine_first(self, other) -> DataFrame:
28052805
"""
28062806
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
28072807

2808+
def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFrame:
    """
    Compute pairwise correlation of columns, excluding NA/null values.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None

        >>> df = bpd.DataFrame({'A': [1, 2, 3],
        ...                    'B': [400, 500, 600],
        ...                    'C': [0.8, 0.4, 0.9]})
        >>> df.corr(numeric_only=True)
                  A         B         C
        A       1.0       1.0  0.188982
        B       1.0       1.0  0.188982
        C  0.188982  0.188982       1.0
        <BLANKLINE>
        [3 rows x 3 columns]

    Args:
        method (string, default "pearson"):
            Correlation method to use - currently only "pearson" is supported.
        min_periods (int, default None):
            The minimum number of observations needed to return a result. Non-default values
            are not yet supported, so a result will be returned for at least two observations.
        numeric_only (bool, default False):
            Include only float, int, boolean, decimal data.

    Returns:
        DataFrame: Correlation matrix.
    """
    # Abstract placeholder: the defaults in the signature mirror the ones
    # documented above so that implementations and callers agree on them.
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
2841+
28082842
def update(
28092843
self, other, join: str = "left", overwrite: bool = True, filter_func=None
28102844
) -> DataFrame:

third_party/bigframes_vendored/pandas/core/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def corr(self, other, method="pearson", min_periods=None) -> float:
840840
float: Will return NaN if there are fewer than two numeric pairs, either series has a
841841
variance or covariance of zero, or any input value is infinite.
842842
"""
843-
raise NotImplementedError("abstract method")
843+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
844844

845845
def cov(
846846
self,

0 commit comments

Comments
 (0)