Skip to content

Commit 93835ef

Browse files
refactor: generalize aggregation to handle 0, 1, or 2 inputs (#360)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent 1866a26 commit 93835ef

File tree

15 files changed

+338
-225
lines changed

15 files changed

+338
-225
lines changed

bigframes/core/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def row_count(self) -> ArrayValue:
118118
# Operations
119119
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
120120
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
121-
predicate = ex.free_var(predicate_id)
121+
predicate: ex.Expression = ex.free_var(predicate_id)
122122
if keep_null:
123123
predicate = ops.fillna_op.as_expr(predicate, ex.const(True))
124124
return self.filter(predicate)
@@ -241,7 +241,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
241241

242242
def aggregate(
243243
self,
244-
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.AggregateOp, str]],
244+
aggregations: typing.Sequence[typing.Tuple[ex.Aggregation, str]],
245245
by_column_ids: typing.Sequence[str] = (),
246246
dropna: bool = True,
247247
) -> ArrayValue:
@@ -270,14 +270,23 @@ def corr_aggregate(
270270
Arguments:
271271
corr_aggregations: left_column_id, right_column_id, output_column_id tuples
272272
"""
273+
aggregations = tuple(
274+
(
275+
ex.BinaryAggregation(
276+
agg_ops.CorrOp(), ex.free_var(agg[0]), ex.free_var(agg[1])
277+
),
278+
agg[2],
279+
)
280+
for agg in corr_aggregations
281+
)
273282
return ArrayValue(
274-
nodes.CorrNode(child=self.node, corr_aggregations=tuple(corr_aggregations))
283+
nodes.AggregateNode(child=self.node, aggregations=aggregations)
275284
)
276285

277286
def project_window_op(
278287
self,
279288
column_name: str,
280-
op: agg_ops.WindowOp,
289+
op: agg_ops.UnaryWindowOp,
281290
window_spec: WindowSpec,
282291
output_name=None,
283292
*,

bigframes/core/blocks.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def filter(self, column_id: str, keep_null: bool = False):
860860

861861
def aggregate_all_and_stack(
862862
self,
863-
operation: agg_ops.AggregateOp,
863+
operation: agg_ops.UnaryAggregateOp,
864864
*,
865865
axis: int | str = 0,
866866
value_col_id: str = "values",
@@ -872,7 +872,8 @@ def aggregate_all_and_stack(
872872
axis_n = utils.get_axis_number(axis)
873873
if axis_n == 0:
874874
aggregations = [
875-
(col_id, operation, col_id) for col_id in self.value_columns
875+
(ex.UnaryAggregation(operation, ex.free_var(col_id)), col_id)
876+
for col_id in self.value_columns
876877
]
877878
index_col_ids = [
878879
guid.generate_guid() for i in range(self.column_labels.nlevels)
@@ -902,10 +903,13 @@ def aggregate_all_and_stack(
902903
dtype=dtype,
903904
)
904905
index_aggregations = [
905-
(col_id, agg_ops.AnyValueOp(), col_id)
906+
(ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.free_var(col_id)), col_id)
906907
for col_id in [*self.index_columns]
907908
]
908-
main_aggregation = (value_col_id, operation, value_col_id)
909+
main_aggregation = (
910+
ex.UnaryAggregation(operation, ex.free_var(value_col_id)),
911+
value_col_id,
912+
)
909913
result_expr = stacked_expr.aggregate(
910914
[*index_aggregations, main_aggregation],
911915
by_column_ids=[offset_col],
@@ -966,7 +970,7 @@ def remap_f(x):
966970
def aggregate(
967971
self,
968972
by_column_ids: typing.Sequence[str] = (),
969-
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.AggregateOp]] = (),
973+
aggregations: typing.Sequence[typing.Tuple[str, agg_ops.UnaryAggregateOp]] = (),
970974
*,
971975
dropna: bool = True,
972976
) -> typing.Tuple[Block, typing.Sequence[str]]:
@@ -979,10 +983,13 @@ def aggregate(
979983
dropna: whether null keys should be dropped
980984
"""
981985
agg_specs = [
982-
(input_id, operation, guid.generate_guid())
986+
(
987+
ex.UnaryAggregation(operation, ex.free_var(input_id)),
988+
guid.generate_guid(),
989+
)
983990
for input_id, operation in aggregations
984991
]
985-
output_col_ids = [agg_spec[2] for agg_spec in agg_specs]
992+
output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
986993
result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)
987994

988995
aggregate_labels = self._get_labels_for_columns(
@@ -1004,7 +1011,7 @@ def aggregate(
10041011
output_col_ids,
10051012
)
10061013

1007-
def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
1014+
def get_stat(self, column_id: str, stat: agg_ops.UnaryAggregateOp):
10081015
"""Gets aggregates immediately, and caches it"""
10091016
if stat.name in self._stats_cache[column_id]:
10101017
return self._stats_cache[column_id][stat.name]
@@ -1014,7 +1021,10 @@ def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
10141021
standard_stats = self._standard_stats(column_id)
10151022
stats_to_fetch = standard_stats if stat in standard_stats else [stat]
10161023

1017-
aggregations = [(column_id, stat, stat.name) for stat in stats_to_fetch]
1024+
aggregations = [
1025+
(ex.UnaryAggregation(stat, ex.free_var(column_id)), stat.name)
1026+
for stat in stats_to_fetch
1027+
]
10181028
expr = self.expr.aggregate(aggregations)
10191029
offset_index_id = guid.generate_guid()
10201030
expr = expr.promote_offsets(offset_index_id)
@@ -1054,13 +1064,13 @@ def get_corr_stat(self, column_id_left: str, column_id_right: str):
10541064
def summarize(
10551065
self,
10561066
column_ids: typing.Sequence[str],
1057-
stats: typing.Sequence[agg_ops.AggregateOp],
1067+
stats: typing.Sequence[agg_ops.UnaryAggregateOp],
10581068
):
10591069
"""Get a list of stats as a deferred block object."""
10601070
label_col_id = guid.generate_guid()
10611071
labels = [stat.name for stat in stats]
10621072
aggregations = [
1063-
(col_id, stat, f"{col_id}-{stat.name}")
1073+
(ex.UnaryAggregation(stat, ex.free_var(col_id)), f"{col_id}-{stat.name}")
10641074
for stat in stats
10651075
for col_id in column_ids
10661076
]
@@ -1076,7 +1086,7 @@ def summarize(
10761086
labels = self._get_labels_for_columns(column_ids)
10771087
return Block(expr, column_labels=labels, index_columns=[label_col_id])
10781088

1079-
def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
1089+
def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.UnaryAggregateOp]:
10801090
"""
10811091
Gets a standard set of stats to preemptively fetch for a column if
10821092
any other stat is fetched.
@@ -1087,7 +1097,7 @@ def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
10871097
"""
10881098
# TODO: annotate aggregations themself with this information
10891099
dtype = self.expr.get_column_type(column_id)
1090-
stats: list[agg_ops.AggregateOp] = [agg_ops.count_op]
1100+
stats: list[agg_ops.UnaryAggregateOp] = [agg_ops.count_op]
10911101
if dtype not in bigframes.dtypes.UNORDERED_DTYPES:
10921102
stats += [agg_ops.min_op, agg_ops.max_op]
10931103
if dtype in bigframes.dtypes.NUMERIC_BIGFRAMES_TYPES_PERMISSIVE:

0 commit comments

Comments (0)