diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
index e2bd6b8382..14da8dd555 100644
--- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
+++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
@@ -400,6 +400,43 @@ def _(
     return apply_window_if_present(expr, window)
 
 
+@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
+def _(
+    op: agg_ops.QcutOp,
+    column: typed_expr.TypedExpr,
+    window: typing.Optional[window_spec.WindowSpec] = None,
+) -> sge.Expression:
+    percent_ranks_order_by = sge.Ordered(this=column.expr, desc=False)
+    percent_ranks = apply_window_if_present(
+        sge.func("PERCENT_RANK"),
+        window,
+        include_framing_clauses=False,
+        order_by_override=[percent_ranks_order_by],
+    )
+    if isinstance(op.quantiles, int):
+        scaled_rank = percent_ranks * sge.convert(op.quantiles)
+        # Calculate the 0-based bucket index.
+        bucket_index = sge.func("CEIL", scaled_rank) - sge.convert(1)
+        safe_bucket_index = sge.func("GREATEST", bucket_index, 0)
+
+        return sge.If(
+            this=sge.Is(this=column.expr, expression=sge.Null()),
+            true=sge.Null(),
+            false=sge.Cast(this=safe_bucket_index, to="INT64"),
+        )
+    else:
+        case = sge.Case()
+        first_quantile = sge.convert(op.quantiles[0])
+        case = case.when(
+            sge.LT(this=percent_ranks, expression=first_quantile), sge.Null()
+        )
+        for bucket_n in range(len(op.quantiles) - 1):
+            quantile = sge.convert(op.quantiles[bucket_n + 1])
+            bucket = sge.convert(bucket_n)
+            case = case.when(sge.LTE(this=percent_ranks, expression=quantile), bucket)
+        return case.else_(sge.Null())
+
+
 @UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
 def _(
     op: agg_ops.QuantileOp,
diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py
index 099f5832da..b775d6666a 100644
--- a/bigframes/core/compile/sqlglot/aggregations/windows.py
+++ b/bigframes/core/compile/sqlglot/aggregations/windows.py
@@ -26,6 +26,7 @@ def apply_window_if_present(
     value: sge.Expression,
     window: typing.Optional[window_spec.WindowSpec] = None,
     include_framing_clauses: bool = True,
+    order_by_override: typing.Optional[typing.List[sge.Ordered]] = None,
 ) -> sge.Expression:
     if window is None:
         return value
@@ -44,7 +45,11 @@ def apply_window_if_present(
     else:
         order_by = get_window_order_by(window.ordering)
 
-    order = sge.Order(expressions=order_by) if order_by else None
+    order = None
+    if order_by_override is not None and len(order_by_override) > 0:
+        order = sge.Order(expressions=order_by_override)
+    elif order_by:
+        order = sge.Order(expressions=order_by)
 
     group_by = (
         [
diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index 3473968450..fd3bdd532f 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
 
 
 def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
-    sqlglot_type = sgt.from_bigframes_dtype(dtype)
+    sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
+    if sqlglot_type is None:
+        if value is not None:
+            raise ValueError("Cannot infer SQLGlot type from None dtype.")
+        return sge.Null()
+
     if value is None:
         return _cast(sge.Null(), sqlglot_type)
     elif dtype == dtypes.BYTES_DTYPE:
diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql
new file mode 100644
index 0000000000..1aa2e436ca
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql
@@ -0,0 +1,61 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `int64_col`,
+    `rowindex`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    NOT `int64_col` IS NULL AS `bfcol_4`
+  FROM `bfcte_0`
+), `bfcte_2` AS (
+  SELECT
+    *,
+    IF(
+      `int64_col` IS NULL,
+      NULL,
+      CAST(GREATEST(
+        CEIL(PERCENT_RANK() OVER (PARTITION BY `bfcol_4` ORDER BY `int64_col` ASC) * 4) - 1,
+        0
+      ) AS INT64)
+    ) AS `bfcol_5`
+  FROM `bfcte_1`
+), `bfcte_3` AS (
+  SELECT
+    *,
+    IF(`bfcol_4`, `bfcol_5`, NULL) AS `bfcol_6`
+  FROM `bfcte_2`
+), `bfcte_4` AS (
+  SELECT
+    *,
+    NOT `int64_col` IS NULL AS `bfcol_10`
+  FROM `bfcte_3`
+), `bfcte_5` AS (
+  SELECT
+    *,
+    CASE
+      WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) < 0
+      THEN NULL
+      WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.25
+      THEN 0
+      WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.5
+      THEN 1
+      WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.75
+      THEN 2
+      WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 1
+      THEN 3
+      ELSE NULL
+    END AS `bfcol_11`
+  FROM `bfcte_4`
+), `bfcte_6` AS (
+  SELECT
+    *,
+    IF(`bfcol_10`, `bfcol_11`, NULL) AS `bfcol_12`
+  FROM `bfcte_5`
+)
+SELECT
+  `rowindex`,
+  `int64_col`,
+  `bfcol_6` AS `qcut_w_int`,
+  `bfcol_12` AS `qcut_w_list`
+FROM `bfcte_6`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
index ab9f7febbf..184cb3925f 100644
--- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
+++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
@@ -435,6 +435,27 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql_window, "window_out.sql")
 
 
+def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
+    if sys.version_info < (3, 12):
+        pytest.skip(
+            "Skipping test due to inconsistent SQL formatting on Python < 3.12.",
+        )
+
+    col_name = "int64_col"
+    bf = scalar_types_df[[col_name]]
+    bf["qcut_w_int"] = bpd.qcut(bf[col_name], q=4, labels=False, duplicates="drop")
+
+    q_list = tuple([0, 0.25, 0.5, 0.75, 1])
+    bf["qcut_w_list"] = bpd.qcut(
+        scalar_types_df[col_name],
+        q=q_list,
+        labels=False,
+        duplicates="drop",
+    )
+
+    snapshot.assert_match(bf.sql, "out.sql")
+
+
 def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "int64_col"
     bf_df = scalar_types_df[[col_name]]