diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
index 603e8a096c..e2bd6b8382 100644
--- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
+++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
@@ -111,6 +111,160 @@ def _(
     return apply_window_if_present(sge.func("COUNT", column.expr), window)
 
 
+@UNARY_OP_REGISTRATION.register(agg_ops.CutOp)
+def _(
+    op: agg_ops.CutOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    """Compile ``CutOp`` into a ``CASE`` expression assigning each row a bin.
+
+    An integer ``op.bins`` is compiled as that many equal-width bins over the
+    column's MIN/MAX range; any other value is interpreted as an iterable of
+    explicit ``(left, right)`` intervals.
+    """
+    if isinstance(op.bins, int):
+        return _cut_ops_w_int_bins(op, column, op.bins, window)
+    # Anything that is not an int is interpreted as explicit intervals.
+    return _cut_ops_w_intervals(op, column, op.bins, window)
+
+
+def _cut_explicit_labels(
+    op: agg_ops.CutOp,
+) -> typing.Optional[typing.List[typing.Any]]:
+    """Return ``op.labels`` as a list when it is an iterable, else ``None``.
+
+    Materializing once lets the per-bin loops index the labels without
+    rebuilding the list on every iteration, and keeps a one-shot iterable from
+    being exhausted after the first bin.
+    """
+    if isinstance(op.labels, typing.Iterable):
+        return list(op.labels)
+    return None
+
+
+def _cut_interval_struct(
+    left: sge.Expression,
+    right: sge.Expression,
+    closed_right: bool,
+) -> sge.Struct:
+    """Build the ``STRUCT`` that labels a bin by its two edge expressions.
+
+    The field names record which edge is closed: ``left_exclusive`` and
+    ``right_inclusive`` when ``closed_right`` is true, otherwise
+    ``left_inclusive`` and ``right_exclusive``.
+    """
+    if closed_right:
+        left_name, right_name = "left_exclusive", "right_inclusive"
+    else:
+        left_name, right_name = "left_inclusive", "right_exclusive"
+    return sge.Struct(
+        expressions=[
+            sge.PropertyEQ(
+                this=sge.Identifier(this=left_name, quoted=True), expression=left
+            ),
+            sge.PropertyEQ(
+                this=sge.Identifier(this=right_name, quoted=True), expression=right
+            ),
+        ]
+    )
+
+
+def _cut_ops_w_int_bins(
+    op: agg_ops.CutOp,
+    column: typed_expr.TypedExpr,
+    bins: int,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Case:
+    """Build the ``CASE`` expression for ``bins`` equal-width bins.
+
+    Bin edges are ``MIN(col) + k * (MAX(col) - MIN(col)) / bins``, where MIN and
+    MAX are evaluated over ``window`` (a default ``WindowSpec()`` when it is
+    ``None``). Each bin maps to its index when ``op.labels is False``, to the
+    matching user-supplied label when ``op.labels`` is an iterable, and
+    otherwise to a struct holding the bin's two edges.
+    """
+    case_expr = sge.Case()
+    agg_window = window or window_spec.WindowSpec()
+    col_min = apply_window_if_present(sge.func("MIN", column.expr), agg_window)
+    col_max = apply_window_if_present(sge.func("MAX", column.expr), agg_window)
+    # 0.1% of the value range; widens the outermost struct edge on its open
+    # side. NOTE(review): presumably mirrors pandas.cut -- confirm.
+    adj: sge.Expression = sge.Sub(this=col_max, expression=col_min) * sge.convert(0.001)
+    bin_width: sge.Expression = sge.func(
+        "IEEE_DIVIDE",
+        sge.Sub(this=col_max, expression=col_min),
+        sge.convert(bins),
+    )
+    labels = _cut_explicit_labels(op)
+
+    for this_bin in range(bins):
+        is_last = this_bin == bins - 1
+        lower_edge = col_min + sge.convert(this_bin) * bin_width
+        upper_edge = col_min + sge.convert(this_bin + 1) * bin_width
+
+        value: sge.Expression
+        if op.labels is False:
+            value = ir._literal(this_bin, dtypes.INT_DTYPE)
+        elif labels is not None:
+            value = ir._literal(labels[this_bin], dtypes.STRING_DTYPE)
+        else:
+            left_adj = adj if this_bin == 0 and op.right else sge.convert(0)
+            right_adj = adj if is_last and not op.right else sge.convert(0)
+            value = _cut_interval_struct(
+                lower_edge - left_adj, upper_edge + right_adj, op.right
+            )
+
+        condition: sge.Expression
+        if is_last:
+            # Every remaining non-NULL value belongs to the final bin.
+            condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
+        else:
+            compare = sge.LTE if op.right else sge.LT
+            condition = compare(this=column.expr, expression=upper_edge)
+        case_expr = case_expr.when(condition, value)
+    return case_expr
+
+
+def _cut_ops_w_intervals(
+    op: agg_ops.CutOp,
+    column: typed_expr.TypedExpr,
+    bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Case:
+    """Build the ``CASE`` expression for explicit ``(left, right)`` intervals.
+
+    One ``WHEN`` branch is emitted per interval and no default branch is added.
+    ``window`` is accepted for signature parity with ``_cut_ops_w_int_bins`` but
+    is not used here: the edges are literals, so no aggregate is computed.
+    """
+    case_expr = sge.Case()
+    labels = _cut_explicit_labels(op)
+    # ``right`` selects (left, right] intervals; otherwise [left, right).
+    lower_cmp, upper_cmp = (sge.GT, sge.LTE) if op.right else (sge.GTE, sge.LT)
+    for this_bin, interval in enumerate(bins):
+        left: sge.Expression = ir._literal(
+            interval[0], dtypes.infer_literal_type(interval[0])
+        )
+        right: sge.Expression = ir._literal(
+            interval[1], dtypes.infer_literal_type(interval[1])
+        )
+        condition: sge.Expression = sge.And(
+            this=lower_cmp(this=column.expr, expression=left),
+            expression=upper_cmp(this=column.expr, expression=right),
+        )
+
+        value: sge.Expression
+        if op.labels is False:
+            value = ir._literal(this_bin, dtypes.INT_DTYPE)
+        elif labels is not None:
+            value = ir._literal(labels[this_bin], dtypes.STRING_DTYPE)
+        else:
+            value = _cut_interval_struct(left, right, op.right)
+        case_expr = case_expr.when(condition, value)
+    return case_expr
+
+
 @UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
 def _(
     op: agg_ops.DateSeriesDiffOp,
diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql
new file mode 100644
index 0000000000..015ac32799
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql
@@ -0,0 +1,55 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    CASE
+      WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
+        1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+      )
+      THEN STRUCT(
+        (
+          MIN(`int64_col`) OVER () + (
+            0 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+          )
+        ) - (
+          (
+            MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER ()
+          ) * 0.001
+        ) AS `left_exclusive`,
+        MIN(`int64_col`) OVER () + (
+          1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+        ) + 0 AS `right_inclusive`
+      )
+      WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
+        2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+      )
+      THEN STRUCT(
+        (
+          MIN(`int64_col`) OVER () + (
+            1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+          )
+        ) - 0 AS `left_exclusive`,
+        MIN(`int64_col`) OVER () + (
+          2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+        ) + 0 AS `right_inclusive`
+      )
+      WHEN `int64_col` IS NOT NULL
+      THEN STRUCT(
+        (
+          MIN(`int64_col`) OVER () + (
+            2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+          )
+        ) - 0 AS `left_exclusive`,
+        MIN(`int64_col`) OVER () + (
+          3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+        ) + 0 AS `right_inclusive`
+      )
+    END AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `int_bins`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql
new file mode 100644
index 0000000000..c98682f2b8
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql
@@ -0,0 +1,24 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    CASE
+      WHEN `int64_col` < MIN(`int64_col`) OVER () + (
+        1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+      )
+      THEN 'a'
+      WHEN `int64_col` < MIN(`int64_col`) OVER () + (
+        2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
+      )
+      THEN 'b'
+      WHEN `int64_col` IS NOT NULL
+      THEN 'c'
+    END AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `int_bins_labels`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql
new file mode 100644
index 0000000000..a3e689b11e
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    CASE
+      WHEN `int64_col` > 0 AND `int64_col` <= 1
+      THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`)
+      WHEN `int64_col` > 1 AND `int64_col` <= 2
+      THEN STRUCT(1 AS `left_exclusive`, 2 AS `right_inclusive`)
+    END AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `interval_bins`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql
new file mode 100644
index 0000000000..1a8a92e38e
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql
@@ -0,0 +1,18 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    CASE
+      WHEN `int64_col` > 0 AND `int64_col` <= 1
+      THEN 0
+      WHEN `int64_col` > 1 AND `int64_col` <= 2
+      THEN 1
+    END AS `bfcol_1`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_1` AS `interval_bins_labels`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
index a21c753896..ab9f7febbf 100644
--- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
+++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
@@ -174,6 +174,35 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
 
 
+def test_cut(scalar_types_df: bpd.DataFrame, snapshot):
+    col_name = "int64_col"
+    bf_df = scalar_types_df[[col_name]]
+    agg_ops_map = {
+        "int_bins": agg_exprs.UnaryAggregation(
+            agg_ops.CutOp(bins=3, right=True, labels=None), expression.deref(col_name)
+        ),
+        "interval_bins": agg_exprs.UnaryAggregation(
+            agg_ops.CutOp(bins=((0, 1), (1, 2)), right=True, labels=None),
+            expression.deref(col_name),
+        ),
+        "int_bins_labels": agg_exprs.UnaryAggregation(
+            agg_ops.CutOp(bins=3, labels=("a", "b", "c"), right=False),
+            expression.deref(col_name),
+        ),
+        "interval_bins_labels": agg_exprs.UnaryAggregation(
+            agg_ops.CutOp(bins=((0, 1), (1, 2)), labels=False, right=True),
+            expression.deref(col_name),
+        ),
+    }
+    window = window_spec.WindowSpec()
+
+    # Loop through the aggregation map items
+    for test_name, agg_expr in agg_ops_map.items():
+        sql = _apply_unary_window_op(bf_df, agg_expr, window, test_name)
+
+        snapshot.assert_match(sql, f"{test_name}.sql")
+
+
 def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "int64_col"
     bf_df = scalar_types_df[[col_name]]