Skip to content

Commit 0bf4a87

Browse files
committed
refactor: add agg_ops.QcutOp to the sqlglot compiler
1 parent f7fd2d2 commit 0bf4a87

File tree

5 files changed

+65
-8
lines changed

5 files changed

+65
-8
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,43 @@ def _(
400400
return apply_window_if_present(expr, window)
401401

402402

403+
@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
def _(
    op: agg_ops.QcutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``QcutOp`` into an expression yielding a 0-based bucket index.

    The bucket is derived from ``PERCENT_RANK()`` evaluated over ``window``,
    ordered ascending by ``column`` (the ordering is passed as an override and
    framing clauses are omitted).

    Args:
        op: The qcut operation. ``op.quantiles`` is either an ``int`` (number
            of equal-width rank buckets) or a sequence of quantile edges.
        column: The input column being bucketed.
        window: Window the rank is computed over, forwarded unchanged to
            ``apply_window_if_present``.

    Returns:
        An expression that is NULL for NULL inputs and otherwise the bucket
        index. For explicit edges, rows whose rank falls below the first edge
        or above the last edge are NULL as well.
    """
    rank_ordering = sge.Ordered(this=column.expr, desc=False)
    percent_ranks = apply_window_if_present(
        sge.func("PERCENT_RANK"),
        window,
        include_framing_clauses=False,
        order_by_override=[rank_ordering],
    )
    input_is_null = sge.Is(this=column.expr, expression=sge.Null())

    if isinstance(op.quantiles, int):
        scaled_rank = percent_ranks * sge.convert(op.quantiles)
        # Calculate the 0-based bucket index. A rank of exactly 0 would yield
        # CEIL(0) - 1 = -1, so the result is clamped to bucket 0.
        bucket_index = sge.func("CEIL", scaled_rank) - sge.convert(1)
        safe_bucket_index = sge.func("GREATEST", bucket_index, 0)

        return sge.If(
            this=input_is_null,
            true=sge.Null(),
            false=sge.Cast(this=safe_bucket_index, to="INT64"),
        )

    # Explicit quantile edges. NULL inputs still receive a PERCENT_RANK value,
    # so they must be excluded first; otherwise they would fall into a bucket,
    # unlike in the integer branch above.
    case = sge.Case().when(input_is_null, sge.Null())
    lowest_edge = sge.convert(op.quantiles[0])
    case = case.when(
        sge.LT(this=percent_ranks, expression=lowest_edge), sge.Null()
    )
    # Bucket ``n`` covers ranks up to and including edge ``n + 1``; the first
    # matching WHEN wins, so the edges are emitted in their given order.
    for bucket_n, upper_edge in enumerate(op.quantiles[1:]):
        case = case.when(
            sge.LTE(this=percent_ranks, expression=sge.convert(upper_edge)),
            sge.convert(bucket_n),
        )
    return case.else_(sge.Null())
438+
439+
403440
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
404441
def _(
405442
op: agg_ops.QuantileOp,

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
2828
include_framing_clauses: bool = True,
29+
order_by_override: typing.Optional[typing.List[sge.Ordered]] = None,
2930
) -> sge.Expression:
3031
if window is None:
3132
return value
@@ -44,7 +45,11 @@ def apply_window_if_present(
4445
else:
4546
order_by = get_window_order_by(window.ordering)
4647

47-
order = sge.Order(expressions=order_by) if order_by else None
48+
order = None
49+
if order_by_override is not None and len(order_by_override) > 0:
50+
order = sge.Order(expressions=order_by_override)
51+
elif order_by:
52+
order = sge.Order(expressions=order_by)
4853

4954
group_by = (
5055
[

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
637637

638638

639639
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
640-
sqlglot_type = sgt.from_bigframes_dtype(dtype)
640+
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
641+
if sqlglot_type is None:
642+
if value is not None:
643+
raise ValueError("Cannot infer SQLGlot type from None dtype.")
644+
return sge.Null()
645+
641646
if value is None:
642647
return _cast(sge.Null(), sqlglot_type)
643648
elif dtype == dtypes.BYTES_DTYPE:

bigframes/core/reshape/tile.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import pandas as pd
2323

2424
import bigframes
25-
import bigframes.constants
26-
import bigframes.core.expression as ex
2725
import bigframes.core.ordering as order
2826
import bigframes.core.utils as utils
2927
import bigframes.core.window_spec as window_specs
@@ -165,7 +163,6 @@ def qcut(
165163
f"Only duplicates='drop' is supported in BigQuery DataFrames so far. {constants.FEEDBACK_LINK}"
166164
)
167165
block = x._block
168-
label = block.col_id_to_label[x._value_column]
169166
block, nullity_id = block.apply_unary_op(x._value_column, ops.notnull_op)
170167
block, result = block.apply_window_op(
171168
x._value_column,
@@ -175,9 +172,6 @@ def qcut(
175172
ordering=(order.ascending_over(x._value_column),),
176173
),
177174
)
178-
block, result = block.project_expr(
179-
ops.where_op.as_expr(result, nullity_id, ex.const(None)), label=label
180-
)
181175
return bigframes.series.Series(block.select_column(result))
182176

183177

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,22 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
435435
snapshot.assert_match(sql_window, "window_out.sql")
436436

437437

438+
def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for ``qcut`` with an int and with explicit edges."""
    source_col = "int64_col"
    frame = scalar_types_df[[source_col]]

    # Integer form: four rank buckets.
    frame["qcut_w_int"] = bpd.qcut(
        frame[source_col], q=4, labels=False, duplicates="drop"
    )

    # Sequence form: explicit quantile edges.
    quantile_edges = (0, 0.25, 0.5, 0.75, 1)
    frame["qcut_w_list"] = bpd.qcut(
        scalar_types_df[source_col],
        q=quantile_edges,
        labels=False,
        duplicates="drop",
    )

    snapshot.assert_match(frame.sql, "out.sql")
452+
453+
438454
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
439455
col_name = "int64_col"
440456
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)