@@ -860,7 +860,7 @@ def filter(self, column_id: str, keep_null: bool = False):
     def aggregate_all_and_stack(
         self,
-        operation: agg_ops.AggregateOp,
+        operation: agg_ops.UnaryAggregateOp,
         *,
         axis: int | str = 0,
         value_col_id: str = "values",
@@ -872,7 +872,8 @@ def aggregate_all_and_stack(
         axis_n = utils.get_axis_number(axis)
         if axis_n == 0:
             aggregations = [
-                (col_id, operation, col_id) for col_id in self.value_columns
+                (ex.UnaryAggregation(operation, ex.free_var(col_id)), col_id)
+                for col_id in self.value_columns
             ]
             index_col_ids = [
                 guid.generate_guid() for i in range(self.column_labels.nlevels)
@@ -902,10 +903,13 @@ def aggregate_all_and_stack(
                 dtype=dtype,
             )
             index_aggregations = [
-                (col_id, agg_ops.AnyValueOp(), col_id)
+                (ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.free_var(col_id)), col_id)
                 for col_id in [*self.index_columns]
             ]
-            main_aggregation = (value_col_id, operation, value_col_id)
+            main_aggregation = (
+                ex.UnaryAggregation(operation, ex.free_var(value_col_id)),
+                value_col_id,
+            )
             result_expr = stacked_expr.aggregate(
                 [*index_aggregations, main_aggregation],
                 by_column_ids=[offset_col],
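
The pattern across these hunks: the old 3-tuple aggregation specs `(input_id, operation, output_id)` become 2-tuples pairing an expression object with an output id, with the input column now captured inside the expression via `ex.free_var`. A minimal standalone sketch of the shape change, using mock stand-ins (assumptions, not the real bigframes classes) for `ex.UnaryAggregation` and `ex.free_var`:

```python
# Standalone sketch (mock stand-ins, NOT the real bigframes classes) of the
# aggregation-spec shape change made in this diff.
from dataclasses import dataclass


@dataclass(frozen=True)
class FreeVar:
    """Stand-in for ex.free_var(col_id): an unbound column reference."""
    col_id: str


@dataclass(frozen=True)
class UnaryAggregation:
    """Stand-in for ex.UnaryAggregation: an op applied to one input expression."""
    op: str  # the real code passes an agg_ops.UnaryAggregateOp instance
    arg: FreeVar


value_columns = ["col_a", "col_b"]

# Old shape: 3-tuples of (input_id, operation, output_id).
old_specs = [(col_id, "sum", col_id) for col_id in value_columns]

# New shape: 2-tuples of (expression, output_id); the input column now
# lives inside the expression instead of being a separate tuple element.
new_specs = [
    (UnaryAggregation("sum", FreeVar(col_id)), col_id) for col_id in value_columns
]

print(new_specs[0])
```
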
@@ -966,7 +970,7 @@ def remap_f(x):
     def aggregate(
         self,
         by_column_ids: typing.Sequence[str] = (),
-        aggregations: typing.Sequence[typing.Tuple[str, agg_ops.AggregateOp]] = (),
+        aggregations: typing.Sequence[typing.Tuple[str, agg_ops.UnaryAggregateOp]] = (),
         *,
         dropna: bool = True,
     ) -> typing.Tuple[Block, typing.Sequence[str]]:
@@ -979,10 +983,13 @@ def aggregate(
             dropna: whether null keys should be dropped
         """
         agg_specs = [
-            (input_id, operation, guid.generate_guid())
+            (
+                ex.UnaryAggregation(operation, ex.free_var(input_id)),
+                guid.generate_guid(),
+            )
             for input_id, operation in aggregations
         ]
-        output_col_ids = [agg_spec[2] for agg_spec in agg_specs]
+        output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
         result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)
 
         aggregate_labels = self._get_labels_for_columns(
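
Because the spec tuples shrink from three elements to two, the generated output id moves from index 2 to index 1, which is what the `output_col_ids` change in this hunk reflects. A tiny illustration with a hypothetical spec:

```python
# Hypothetical 2-tuple spec: (aggregation expression, generated output id).
agg_specs = [("UnaryAggregation(sum, free_var('col_a'))", "bfuid_7f3a")]

# The output id is now element 1 of the 2-tuple; it was element 2 of the
# old 3-tuple (input_id, operation, output_id).
output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
assert output_col_ids == ["bfuid_7f3a"]
```
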
@@ -1004,7 +1011,7 @@ def aggregate(
             output_col_ids,
         )
 
-    def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
+    def get_stat(self, column_id: str, stat: agg_ops.UnaryAggregateOp):
         """Gets aggregates immediately, and caches it"""
         if stat.name in self._stats_cache[column_id]:
             return self._stats_cache[column_id][stat.name]
@@ -1014,7 +1021,10 @@ def get_stat(self, column_id: str, stat: agg_ops.AggregateOp):
         standard_stats = self._standard_stats(column_id)
         stats_to_fetch = standard_stats if stat in standard_stats else [stat]
 
-        aggregations = [(column_id, stat, stat.name) for stat in stats_to_fetch]
+        aggregations = [
+            (ex.UnaryAggregation(stat, ex.free_var(column_id)), stat.name)
+            for stat in stats_to_fetch
+        ]
         expr = self.expr.aggregate(aggregations)
         offset_index_id = guid.generate_guid()
         expr = expr.promote_offsets(offset_index_id)
@@ -1054,13 +1064,13 @@ def get_corr_stat(self, column_id_left: str, column_id_right: str):
     def summarize(
         self,
         column_ids: typing.Sequence[str],
-        stats: typing.Sequence[agg_ops.AggregateOp],
+        stats: typing.Sequence[agg_ops.UnaryAggregateOp],
     ):
         """Get a list of stats as a deferred block object."""
         label_col_id = guid.generate_guid()
         labels = [stat.name for stat in stats]
         aggregations = [
-            (col_id, stat, f"{col_id}-{stat.name}")
+            (ex.UnaryAggregation(stat, ex.free_var(col_id)), f"{col_id}-{stat.name}")
             for stat in stats
             for col_id in column_ids
         ]
@@ -1076,7 +1086,7 @@ def summarize(
         labels = self._get_labels_for_columns(column_ids)
         return Block(expr, column_labels=labels, index_columns=[label_col_id])
 
-    def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
+    def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.UnaryAggregateOp]:
         """
         Gets a standard set of stats to preemptively fetch for a column if
         any other stat is fetched.
@@ -1087,7 +1097,7 @@ def _standard_stats(self, column_id) -> typing.Sequence[agg_ops.AggregateOp]:
         """
         # TODO: annotate aggregations themself with this information
         dtype = self.expr.get_column_type(column_id)
-        stats: list[agg_ops.AggregateOp] = [agg_ops.count_op]
+        stats: list[agg_ops.UnaryAggregateOp] = [agg_ops.count_op]
         if dtype not in bigframes.dtypes.UNORDERED_DTYPES:
             stats += [agg_ops.min_op, agg_ops.max_op]
         if dtype in bigframes.dtypes.NUMERIC_BIGFRAMES_TYPES_PERMISSIVE:
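
The `_standard_stats` hunk above gates which stats are prefetched by dtype: count always, min/max only for ordered dtypes, with numeric dtypes getting additional stats in the branch the diff truncates before. A hedged sketch of that gating logic, with illustrative dtype sets standing in for the `bigframes.dtypes` constants (the numeric extras shown are assumptions):

```python
# Sketch of the _standard_stats dtype gating. The dtype sets and the
# numeric extras are illustrative stand-ins, not the real bigframes lists.
UNORDERED_DTYPES = {"geography"}
NUMERIC_DTYPES = {"int64", "float64"}


def standard_stats(dtype: str) -> list[str]:
    stats = ["count"]  # always prefetched
    if dtype not in UNORDERED_DTYPES:
        stats += ["min", "max"]  # only meaningful for ordered dtypes
    if dtype in NUMERIC_DTYPES:
        stats += ["mean", "std"]  # assumed extras; the real branch differs
    return stats


print(standard_stats("int64"))      # ['count', 'min', 'max', 'mean', 'std']
print(standard_stats("geography"))  # ['count']
```
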