Skip to content

Commit 0b43290

Browse files
committed
feat: add DataFrames.corr() method
1 parent ffb0d15 commit 0b43290

File tree

5 files changed

+105
-1
lines changed

5 files changed

+105
-1
lines changed

bigframes/core/blocks.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,39 @@ def summarize(
10891089
labels = self._get_labels_for_columns(column_ids)
10901090
return Block(expr, column_labels=labels, index_columns=[label_col_id])
10911091

1092+
def corr(self):
1093+
"""Returns a block object to compute the self-correlation on this block."""
1094+
aggregations = [
1095+
(
1096+
ex.BinaryAggregation(
1097+
agg_ops.CorrOp(), ex.free_var(left_col), ex.free_var(right_col)
1098+
),
1099+
f"{left_col}-{right_col}",
1100+
)
1101+
for left_col in self.value_columns
1102+
for right_col in self.value_columns
1103+
]
1104+
expr = self.expr.aggregate(aggregations)
1105+
1106+
label_col_id = guid.generate_guid()
1107+
input_count = len(self.value_columns)
1108+
unpivot_columns = tuple(
1109+
(
1110+
guid.generate_guid(),
1111+
tuple(expr.column_ids[input_count * i : input_count * (i + 1)]),
1112+
)
1113+
for i in range(input_count)
1114+
)
1115+
labels = self._get_labels_for_columns(self.value_columns)
1116+
1117+
expr = expr.unpivot(
1118+
row_labels=labels,
1119+
index_col_ids=[label_col_id],
1120+
unpivot_columns=unpivot_columns,
1121+
)
1122+
1123+
return Block(expr, column_labels=labels, index_columns=[label_col_id])
1124+
10921125
def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.UnaryAggregateOp]:
10931126
"""
10941127
Gets a standard set of stats to preemptively fetch for a column if

bigframes/dataframe.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,27 @@ def combine(
10171017
def combine_first(self, other: DataFrame):
10181018
return self._apply_dataframe_binop(other, ops.fillna_op)
10191019

1020+
def corr(self, method="pearson", min_periods=None, numeric_only=False) -> DataFrame:
1021+
if method != "pearson":
1022+
raise NotImplementedError(
1023+
f"Only Pearson correlation is currently supported. {constants.FEEDBACK_LINK}"
1024+
)
1025+
if min_periods:
1026+
raise NotImplementedError(
1027+
f"min_periods not yet supported. {constants.FEEDBACK_LINK}"
1028+
)
1029+
# TODO(chelsealin): Support non-numeric columns correlation.
1030+
if not numeric_only:
1031+
raise NotImplementedError(
1032+
f"Only numeric columns' correlation is currently supported. {constants.FEEDBACK_LINK}"
1033+
)
1034+
if len(self.columns) > 30:
1035+
raise NotImplementedError(
1036+
f"Only work with dataframes containing fewer than 30 columns. Current: {self.columns}. {constants.FEEDBACK_LINK}"
1037+
)
1038+
# TODO(chelsealin): Support multi-index dataframes' correlation.
1039+
return DataFrame(self._block.corr())
1040+
10201041
def to_pandas(
10211042
self,
10221043
max_download_size: Optional[int] = None,

tests/system/small/test_dataframe.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,6 +1783,21 @@ def test_combine_first(
17831783
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
17841784

17851785

1786+
def test_corr_w_numeric_only(scalars_dfs):
1787+
columns = ["int64_too", "int64_col", "float64_col"]
1788+
scalars_df, scalars_pandas_df = scalars_dfs
1789+
1790+
bf_result = scalars_df[columns].corr(numeric_only=True).to_pandas()
1791+
pd_result = scalars_pandas_df[columns].corr(numeric_only=True)
1792+
1793+
# BigFrames and Pandas differ in their data type handling:
1794+
# - Column types: BigFrames uses Float64, Pandas uses float64.
1795+
# - Index types: BigFrames uses strign, Pandas uses object.
1796+
pd.testing.assert_frame_equal(
1797+
bf_result, pd_result, check_dtype=False, check_index_type=False
1798+
)
1799+
1800+
17861801
@pytest.mark.parametrize(
17871802
("op"),
17881803
[

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,6 +2805,41 @@ def combine_first(self, other) -> DataFrame:
28052805
"""
28062806
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
28072807

2808+
def corr(self, method, min_periods, numeric_only) -> DataFrame:
2809+
"""
2810+
Compute pairwise correlation of columns, excluding NA/null values.
2811+
2812+
**Examples:**
2813+
2814+
>>> import bigframes.pandas as bpd
2815+
>>> bpd.options.display.progress_bar = None
2816+
2817+
>>> df = bpd.DataFrame({'A': [1, 2, 3],
2818+
... 'B': [400, 500, 600],
2819+
... 'C': [0.8, 0.4, 0.9]})
2820+
>>> df.corr(numeric_only=True)
2821+
>>> df
2822+
A B C
2823+
A 1.0 1.0 0.188982
2824+
B 1.0 1.0 0.188982
2825+
C 0.188982 0.188982 1.0
2826+
<BLANKLINE>
2827+
[3 rows x 3 columns]
2828+
2829+
Args:
2830+
method (string, default "pearson"):
2831+
Correlation method to use - currently only "pearson" is supported.
2832+
min_periods (int, default None):
2833+
The minimum number of observations needed to return a result. Non-default values
2834+
are not yet supported, so a result will be returned for at least two observations.
2835+
numeric_only(bool, default False):
2836+
Include only float, int or boolean data. - currently numeric only is supported
2837+
2838+
Returns:
2839+
DataFrame: Correlation matrix.
2840+
"""
2841+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
2842+
28082843
def update(
28092844
self, other, join: str = "left", overwrite: bool = True, filter_func=None
28102845
) -> DataFrame:

third_party/bigframes_vendored/pandas/core/series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def corr(self, other, method="pearson", min_periods=None) -> float:
840840
float: Will return NaN if there are fewer than two numeric pairs, either series has a
841841
variance or covariance of zero, or any input value is infinite.
842842
"""
843-
raise NotImplementedError("abstract method")
843+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
844844

845845
def cov(
846846
self,

0 commit comments

Comments
 (0)