From 89afa5b81981da74afb2f820b496128ddd104237 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 13 Nov 2025 23:23:10 +0000 Subject: [PATCH 1/3] refactor: add agg_ops.CutOp to the sqlglot compiler --- .../sqlglot/aggregations/unary_compiler.py | 158 ++++++++++++++++++ .../test_unary_compiler/test_cut/out.sql | 81 +++++++++ .../aggregations/test_unary_compiler.py | 27 +++ 3 files changed, 266 insertions(+) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 603e8a096c..4e12374bf3 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -111,6 +111,164 @@ def _( return apply_window_if_present(sge.func("COUNT", column.expr), window) +@UNARY_OP_REGISTRATION.register(agg_ops.CutOp) +def _( + op: agg_ops.CutOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + if isinstance(op.bins, int): + case_expr = _cut_ops_w_int_bins(op, column, op.bins, window) + else: # Interpret as intervals + case_expr = _cut_ops_w_intervals(op, column, op.bins, window) + return apply_window_if_present(case_expr, window) + + +def _cut_ops_w_int_bins( + op: agg_ops.CutOp, + column: typed_expr.TypedExpr, + bins: int, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Case: + case_expr = sge.Case() + col_min = apply_window_if_present( + sge.func("MIN", column.expr), window or window_spec.WindowSpec() + ) + col_max = apply_window_if_present( + sge.func("MAX", column.expr), window or window_spec.WindowSpec() + ) + adj: sge.Expression = sge.Sub(this=col_max, expression=col_min) * sge.convert(0.001) + bin_width: sge.Expression = sge.func( + "IEEE_DIVIDE", + sge.Sub(this=col_max, expression=col_min), + sge.convert(bins), + ) + + for 
this_bin in range(bins): + value: sge.Expression + if op.labels is False: + value = ir._literal(this_bin, dtypes.INT_DTYPE) + elif isinstance(op.labels, typing.Iterable): + value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) + else: + left_adj: sge.Expression = ( + adj if this_bin == 0 and op.right else sge.convert(0) + ) + right_adj: sge.Expression = ( + adj if this_bin == bins - 1 and not op.right else sge.convert(0) + ) + + left: sge.Expression = ( + col_min + sge.convert(this_bin) * bin_width - left_adj + ) + right: sge.Expression = ( + col_min + sge.convert(this_bin + 1) * bin_width + right_adj + ) + if op.right: + value = sge.Struct( + expressions=[ + sge.PropertyEQ( + this=sge.Identifier(this="left_exclusive", quoted=True), + expression=left, + ), + sge.PropertyEQ( + this=sge.Identifier(this="right_inclusive", quoted=True), + expression=right, + ), + ] + ) + else: + value = sge.Struct( + expressions=[ + sge.PropertyEQ( + this=sge.Identifier(this="left_inclusive", quoted=True), + expression=left, + ), + sge.PropertyEQ( + this=sge.Identifier(this="right_exclusive", quoted=True), + expression=right, + ), + ] + ) + + condition: sge.Expression + if this_bin == bins - 1: + condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null())) + else: + if op.right: + condition = sge.LTE( + this=column.expr, + expression=(col_min + sge.convert(this_bin + 1) * bin_width), + ) + else: + condition = sge.LT( + this=column.expr, + expression=(col_min + sge.convert(this_bin + 1) * bin_width), + ) + case_expr = case_expr.when(condition, value) + return case_expr + + +def _cut_ops_w_intervals( + op: agg_ops.CutOp, + column: typed_expr.TypedExpr, + bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]], + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Case: + case_expr = sge.Case() + for this_bin, interval in enumerate(bins): + left: sge.Expression = ir._literal( + interval[0], dtypes.infer_literal_type(interval[0]) + ) + right: 
sge.Expression = ir._literal( + interval[1], dtypes.infer_literal_type(interval[1]) + ) + condition: sge.Expression + if op.right: + condition = sge.And( + this=sge.GT(this=column.expr, expression=left), + expression=sge.LTE(this=column.expr, expression=right), + ) + else: + condition = sge.And( + this=sge.GTE(this=column.expr, expression=left), + expression=sge.LT(this=column.expr, expression=right), + ) + + value: sge.Expression + if op.labels is False: + value = ir._literal(this_bin, dtypes.INT_DTYPE) + elif isinstance(op.labels, typing.Iterable): + value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) + else: + if op.right: + value = sge.Struct( + expressions=[ + sge.PropertyEQ( + this=sge.Identifier(this="left_exclusive"), expression=left + ), + sge.PropertyEQ( + this=sge.Identifier(this="right_inclusive"), + expression=right, + ), + ] + ) + else: + value = sge.Struct( + expressions=[ + sge.PropertyEQ( + this=sge.Identifier(this="left_inclusive"), expression=left + ), + sge.PropertyEQ( + this=sge.Identifier(this="right_exclusive"), + expression=right, + ), + ] + ) + case_expr = case_expr.when(condition, value) + return case_expr + + @UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp) def _( op: agg_ops.DateSeriesDiffOp, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql new file mode 100644 index 0000000000..6eb91d356b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql @@ -0,0 +1,81 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + CASE + WHEN `int64_col` <= MIN(`int64_col`) OVER () + ( + 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + THEN STRUCT( + ( + MIN(`int64_col`) OVER () + ( + 0 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - 
MIN(`int64_col`) OVER (), 3) + ) + ) - ( + ( + MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER () + ) * 0.001 + ) AS `left_exclusive`, + MIN(`int64_col`) OVER () + ( + 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + 0 AS `right_inclusive` + ) + WHEN `int64_col` <= MIN(`int64_col`) OVER () + ( + 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + THEN STRUCT( + ( + MIN(`int64_col`) OVER () + ( + 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + ) - 0 AS `left_exclusive`, + MIN(`int64_col`) OVER () + ( + 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + 0 AS `right_inclusive` + ) + WHEN `int64_col` IS NOT NULL + THEN STRUCT( + ( + MIN(`int64_col`) OVER () + ( + 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + ) - 0 AS `left_exclusive`, + MIN(`int64_col`) OVER () + ( + 3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + 0 AS `right_inclusive` + ) + END AS `bfcol_1`, + CASE + WHEN `int64_col` > 0 AND `int64_col` <= 1 + THEN STRUCT(0 AS left_exclusive, 1 AS right_inclusive) + WHEN `int64_col` > 1 AND `int64_col` <= 2 + THEN STRUCT(1 AS left_exclusive, 2 AS right_inclusive) + END AS `bfcol_2`, + CASE + WHEN `int64_col` < MIN(`int64_col`) OVER () + ( + 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + THEN 'a' + WHEN `int64_col` < MIN(`int64_col`) OVER () + ( + 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + THEN 'b' + WHEN `int64_col` IS NOT NULL + THEN 'c' + END AS `bfcol_3`, + CASE + WHEN `int64_col` > 0 AND `int64_col` <= 1 + THEN 0 + WHEN `int64_col` > 1 AND `int64_col` <= 2 + THEN 1 + END AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int_bins`, + `bfcol_2` AS `interval_bins`, + `bfcol_3` AS `int_bins_labels`, + `bfcol_4` AS `interval_bins_labels` +FROM `bfcte_1` \ No newline at end of file diff --git 
a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index a21c753896..92c13b0c99 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -174,6 +174,33 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql_window_partition, "window_partition_out.sql") +def test_cut(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "int64_col" + bf_df = scalar_types_df[[col_name]] + agg_ops_map = { + "int_bins": agg_exprs.UnaryAggregation( + agg_ops.CutOp(bins=3, right=True, labels=None), expression.deref(col_name) + ), + "interval_bins": agg_exprs.UnaryAggregation( + agg_ops.CutOp(bins=((0, 1), (1, 2)), right=True, labels=None), + expression.deref(col_name), + ), + "int_bins_labels": agg_exprs.UnaryAggregation( + agg_ops.CutOp(bins=3, labels=("a", "b", "c"), right=False), + expression.deref(col_name), + ), + "interval_bins_labels": agg_exprs.UnaryAggregation( + agg_ops.CutOp(bins=((0, 1), (1, 2)), labels=False, right=True), + expression.deref(col_name), + ), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) + + snapshot.assert_match(sql, "out.sql") + + def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] From 55fca3d98ead531d7d361e4788f2e91d0fe6bd9f Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 14 Nov 2025 00:11:28 +0000 Subject: [PATCH 2/3] make ibis runnable --- .../sqlglot/aggregations/unary_compiler.py | 10 +++--- .../test_cut/{out.sql => int_bins.sql} | 32 ++----------------- .../test_cut/int_bins_labels.sql | 24 ++++++++++++++ .../test_cut/interval_bins.sql | 18 +++++++++++ .../test_cut/interval_bins_labels.sql | 18 +++++++++++ .../aggregations/test_unary_compiler.py | 10 +++--- 6 files changed, 75 insertions(+), 37
deletions(-) rename tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/{out.sql => int_bins.sql} (63%) create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql create mode 100644 tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 4e12374bf3..c18f73017e 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -245,10 +245,11 @@ def _cut_ops_w_intervals( value = sge.Struct( expressions=[ sge.PropertyEQ( - this=sge.Identifier(this="left_exclusive"), expression=left + this=sge.Identifier(this="left_exclusive", quoted=True), + expression=left, ), sge.PropertyEQ( - this=sge.Identifier(this="right_inclusive"), + this=sge.Identifier(this="right_inclusive", quoted=True), expression=right, ), ] @@ -257,10 +258,11 @@ def _cut_ops_w_intervals( value = sge.Struct( expressions=[ sge.PropertyEQ( - this=sge.Identifier(this="left_inclusive"), expression=left + this=sge.Identifier(this="left_inclusive", quoted=True), + expression=left, ), sge.PropertyEQ( - this=sge.Identifier(this="right_exclusive"), + this=sge.Identifier(this="right_exclusive", quoted=True), expression=right, ), ] diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql similarity index 63% rename from tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql rename to 
tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql index 6eb91d356b..cca7929c22 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql @@ -4,6 +4,7 @@ WITH `bfcte_0` AS ( FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT + *, CASE WHEN `int64_col` <= MIN(`int64_col`) OVER () + ( 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) @@ -46,36 +47,9 @@ WITH `bfcte_0` AS ( 3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) ) + 0 AS `right_inclusive` ) - END AS `bfcol_1`, - CASE - WHEN `int64_col` > 0 AND `int64_col` <= 1 - THEN STRUCT(0 AS left_exclusive, 1 AS right_inclusive) - WHEN `int64_col` > 1 AND `int64_col` <= 2 - THEN STRUCT(1 AS left_exclusive, 2 AS right_inclusive) - END AS `bfcol_2`, - CASE - WHEN `int64_col` < MIN(`int64_col`) OVER () + ( - 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) - ) - THEN 'a' - WHEN `int64_col` < MIN(`int64_col`) OVER () + ( - 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) - ) - THEN 'b' - WHEN `int64_col` IS NOT NULL - THEN 'c' - END AS `bfcol_3`, - CASE - WHEN `int64_col` > 0 AND `int64_col` <= 1 - THEN 0 - WHEN `int64_col` > 1 AND `int64_col` <= 2 - THEN 1 - END AS `bfcol_4` + END OVER () AS `bfcol_1` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `int_bins`, - `bfcol_2` AS `interval_bins`, - `bfcol_3` AS `int_bins_labels`, - `bfcol_4` AS `interval_bins_labels` + `bfcol_1` AS `int_bins` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql new file mode 100644 index 0000000000..2a670dd872 --- /dev/null +++ 
b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql @@ -0,0 +1,24 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `int64_col` < MIN(`int64_col`) OVER () + ( + 1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + THEN 'a' + WHEN `int64_col` < MIN(`int64_col`) OVER () + ( + 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) + ) + THEN 'b' + WHEN `int64_col` IS NOT NULL + THEN 'c' + END OVER () AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int_bins_labels` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql new file mode 100644 index 0000000000..285533fb87 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `int64_col` > 0 AND `int64_col` <= 1 + THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`) + WHEN `int64_col` > 1 AND `int64_col` <= 2 + THEN STRUCT(1 AS `left_exclusive`, 2 AS `right_inclusive`) + END OVER () AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `interval_bins` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql new file mode 100644 index 0000000000..2f11a93f8d --- /dev/null +++ 
b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN `int64_col` > 0 AND `int64_col` <= 1 + THEN 0 + WHEN `int64_col` > 1 AND `int64_col` <= 2 + THEN 1 + END OVER () AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `interval_bins_labels` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 92c13b0c99..d44889527a 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -194,11 +194,13 @@ def test_cut(scalar_types_df: bpd.DataFrame, snapshot): expression.deref(col_name), ), } - sql = _apply_unary_agg_ops( - bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) - ) + window = window_spec.WindowSpec() - snapshot.assert_match(sql, "out.sql") + # Loop through the aggregation map items + for test_name, agg_expr in agg_ops_map.items(): + sql = _apply_unary_window_op(bf_df, agg_expr, window, test_name) + + snapshot.assert_match(sql, f"{test_name}.sql") def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot): From 92ad3d48e6ead456e5b38b080e6348e02b8dcdbd Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Mon, 17 Nov 2025 21:48:11 +0000 Subject: [PATCH 3/3] fix window errors --- .../sqlglot/aggregations/unary_compiler.py | 72 ++++++------------- .../test_unary_compiler/test_cut/int_bins.sql | 2 +- .../test_cut/int_bins_labels.sql | 2 +- .../test_cut/interval_bins.sql | 2 +- .../test_cut/interval_bins_labels.sql | 2 +- .../aggregations/test_unary_compiler.py | 2 +- 6 files changed, 28 insertions(+), 54 deletions(-) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py 
b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index c18f73017e..e2bd6b8382 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -121,7 +121,7 @@ def _( case_expr = _cut_ops_w_int_bins(op, column, op.bins, window) else: # Interpret as intervals case_expr = _cut_ops_w_intervals(op, column, op.bins, window) - return apply_window_if_present(case_expr, window) + return case_expr def _cut_ops_w_int_bins( @@ -165,31 +165,18 @@ def _cut_ops_w_int_bins( col_min + sge.convert(this_bin + 1) * bin_width + right_adj ) if op.right: - value = sge.Struct( - expressions=[ - sge.PropertyEQ( - this=sge.Identifier(this="left_exclusive", quoted=True), - expression=left, - ), - sge.PropertyEQ( - this=sge.Identifier(this="right_inclusive", quoted=True), - expression=right, - ), - ] - ) + left_identifier = sge.Identifier(this="left_exclusive", quoted=True) + right_identifier = sge.Identifier(this="right_inclusive", quoted=True) else: - value = sge.Struct( - expressions=[ - sge.PropertyEQ( - this=sge.Identifier(this="left_inclusive", quoted=True), - expression=left, - ), - sge.PropertyEQ( - this=sge.Identifier(this="right_exclusive", quoted=True), - expression=right, - ), - ] - ) + left_identifier = sge.Identifier(this="left_inclusive", quoted=True) + right_identifier = sge.Identifier(this="right_exclusive", quoted=True) + + value = sge.Struct( + expressions=[ + sge.PropertyEQ(this=left_identifier, expression=left), + sge.PropertyEQ(this=right_identifier, expression=right), + ] + ) condition: sge.Expression if this_bin == bins - 1: @@ -242,31 +229,18 @@ def _cut_ops_w_intervals( value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE) else: if op.right: - value = sge.Struct( - expressions=[ - sge.PropertyEQ( - this=sge.Identifier(this="left_exclusive", quoted=True), - expression=left, - ), - sge.PropertyEQ( - this=sge.Identifier(this="right_inclusive", quoted=True), - 
expression=right, - ), - ] - ) + left_identifier = sge.Identifier(this="left_exclusive", quoted=True) + right_identifier = sge.Identifier(this="right_inclusive", quoted=True) else: - value = sge.Struct( - expressions=[ - sge.PropertyEQ( - this=sge.Identifier(this="left_inclusive", quoted=True), - expression=left, - ), - sge.PropertyEQ( - this=sge.Identifier(this="right_exclusive", quoted=True), - expression=right, - ), - ] - ) + left_identifier = sge.Identifier(this="left_inclusive", quoted=True) + right_identifier = sge.Identifier(this="right_exclusive", quoted=True) + + value = sge.Struct( + expressions=[ + sge.PropertyEQ(this=left_identifier, expression=left), + sge.PropertyEQ(this=right_identifier, expression=right), + ] + ) case_expr = case_expr.when(condition, value) return case_expr diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql index cca7929c22..015ac32799 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql @@ -47,7 +47,7 @@ WITH `bfcte_0` AS ( 3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) ) + 0 AS `right_inclusive` ) - END OVER () AS `bfcol_1` + END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql index 2a670dd872..c98682f2b8 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql @@ -16,7 +16,7 @@ WITH `bfcte_0` AS ( THEN 'b' WHEN 
`int64_col` IS NOT NULL THEN 'c' - END OVER () AS `bfcol_1` + END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql index 285533fb87..a3e689b11e 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql @@ -10,7 +10,7 @@ WITH `bfcte_0` AS ( THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`) WHEN `int64_col` > 1 AND `int64_col` <= 2 THEN STRUCT(1 AS `left_exclusive`, 2 AS `right_inclusive`) - END OVER () AS `bfcol_1` + END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql index 2f11a93f8d..1a8a92e38e 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql @@ -10,7 +10,7 @@ WITH `bfcte_0` AS ( THEN 0 WHEN `int64_col` > 1 AND `int64_col` <= 2 THEN 1 - END OVER () AS `bfcol_1` + END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index d44889527a..ab9f7febbf 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -199,7 +199,7 @@ def test_cut(scalar_types_df: bpd.DataFrame, snapshot): # Loop through the aggregation map items for test_name, agg_expr in 
agg_ops_map.items(): sql = _apply_unary_window_op(bf_df, agg_expr, window, test_name) - + snapshot.assert_match(sql, f"{test_name}.sql")