Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,43 @@ def _(
return apply_window_if_present(expr, window)


@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
def _(
op: agg_ops.QcutOp,
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
percent_ranks_order_by = sge.Ordered(this=column.expr, desc=False)
percent_ranks = apply_window_if_present(
sge.func("PERCENT_RANK"),
window,
include_framing_clauses=False,
order_by_override=[percent_ranks_order_by],
)
if isinstance(op.quantiles, int):
scaled_rank = percent_ranks * sge.convert(op.quantiles)
# Calculate the 0-based bucket index.
bucket_index = sge.func("CEIL", scaled_rank) - sge.convert(1)
safe_bucket_index = sge.func("GREATEST", bucket_index, 0)

return sge.If(
this=sge.Is(this=column.expr, expression=sge.Null()),
true=sge.Null(),
false=sge.Cast(this=safe_bucket_index, to="INT64"),
)
else:
case = sge.Case()
first_quantile = sge.convert(op.quantiles[0])
case = case.when(
sge.LT(this=percent_ranks, expression=first_quantile), sge.Null()
)
for bucket_n in range(len(op.quantiles) - 1):
quantile = sge.convert(op.quantiles[bucket_n + 1])
bucket = sge.convert(bucket_n)
case = case.when(sge.LTE(this=percent_ranks, expression=quantile), bucket)
return case.else_(sge.Null())


@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
def _(
op: agg_ops.QuantileOp,
Expand Down
7 changes: 6 additions & 1 deletion bigframes/core/compile/sqlglot/aggregations/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def apply_window_if_present(
value: sge.Expression,
window: typing.Optional[window_spec.WindowSpec] = None,
include_framing_clauses: bool = True,
order_by_override: typing.Optional[typing.List[sge.Ordered]] = None,
) -> sge.Expression:
if window is None:
return value
Expand All @@ -44,7 +45,11 @@ def apply_window_if_present(
else:
order_by = get_window_order_by(window.ordering)

order = sge.Order(expressions=order_by) if order_by else None
order = None
if order_by_override is not None and len(order_by_override) > 0:
order = sge.Order(expressions=order_by_override)
elif order_by:
order = sge.Order(expressions=order_by)

group_by = (
[
Expand Down
7 changes: 6 additions & 1 deletion bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:


def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
sqlglot_type = sgt.from_bigframes_dtype(dtype)
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
if sqlglot_type is None:
if value is not None:
raise ValueError("Cannot infer SQLGlot type from None dtype.")
return sge.Null()

if value is None:
return _cast(sge.Null(), sqlglot_type)
elif dtype == dtypes.BYTES_DTYPE:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
WITH `bfcte_0` AS (
SELECT
`int64_col`,
`rowindex`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
NOT `int64_col` IS NULL AS `bfcol_4`
FROM `bfcte_0`
), `bfcte_2` AS (
SELECT
*,
IF(
`int64_col` IS NULL,
NULL,
CAST(GREATEST(
CEIL(PERCENT_RANK() OVER (PARTITION BY `bfcol_4` ORDER BY `int64_col` ASC) * 4) - 1,
0
) AS INT64)
) AS `bfcol_5`
FROM `bfcte_1`
), `bfcte_3` AS (
SELECT
*,
IF(`bfcol_4`, `bfcol_5`, NULL) AS `bfcol_6`
FROM `bfcte_2`
), `bfcte_4` AS (
SELECT
*,
NOT `int64_col` IS NULL AS `bfcol_10`
FROM `bfcte_3`
), `bfcte_5` AS (
SELECT
*,
CASE
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) < 0
THEN NULL
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.25
THEN 0
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.5
THEN 1
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.75
THEN 2
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 1
THEN 3
ELSE NULL
END AS `bfcol_11`
FROM `bfcte_4`
), `bfcte_6` AS (
SELECT
*,
IF(`bfcol_10`, `bfcol_11`, NULL) AS `bfcol_12`
FROM `bfcte_5`
)
SELECT
`rowindex`,
`int64_col`,
`bfcol_6` AS `qcut_w_int`,
`bfcol_12` AS `qcut_w_list`
FROM `bfcte_6`
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,27 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql_window, "window_out.sql")


def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
if sys.version_info < (3, 12):
pytest.skip(
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
)

col_name = "int64_col"
bf = scalar_types_df[[col_name]]
bf["qcut_w_int"] = bpd.qcut(bf[col_name], q=4, labels=False, duplicates="drop")

q_list = tuple([0, 0.25, 0.5, 0.75, 1])
bf["qcut_w_list"] = bpd.qcut(
scalar_types_df[col_name],
q=q_list,
labels=False,
duplicates="drop",
)

snapshot.assert_match(bf.sql, "out.sql")


def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
Expand Down