134 changes: 134 additions & 0 deletions bigframes/core/compile/sqlglot/aggregations/unary_compiler.py
@@ -111,6 +111,140 @@ def _(
return apply_window_if_present(sge.func("COUNT", column.expr), window)


@UNARY_OP_REGISTRATION.register(agg_ops.CutOp)
def _(
op: agg_ops.CutOp,
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
if isinstance(op.bins, int):
case_expr = _cut_ops_w_int_bins(op, column, op.bins, window)
else: # Interpret as intervals
case_expr = _cut_ops_w_intervals(op, column, op.bins, window)
return case_expr


def _cut_ops_w_int_bins(
op: agg_ops.CutOp,
column: typed_expr.TypedExpr,
bins: int,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Builds a pandas.cut-style CASE expression over `bins` equal-width buckets."""
    case_expr = sge.Case()
col_min = apply_window_if_present(
sge.func("MIN", column.expr), window or window_spec.WindowSpec()
)
col_max = apply_window_if_present(
sge.func("MAX", column.expr), window or window_spec.WindowSpec()
)
    # Mirror pandas.cut: widen the range by 0.1% so the column min (right=True)
    # or max (right=False) still falls inside the outermost bin.
    adj: sge.Expression = sge.Sub(this=col_max, expression=col_min) * sge.convert(0.001)
bin_width: sge.Expression = sge.func(
"IEEE_DIVIDE",
sge.Sub(this=col_max, expression=col_min),
sge.convert(bins),
)

for this_bin in range(bins):
value: sge.Expression
if op.labels is False:
value = ir._literal(this_bin, dtypes.INT_DTYPE)
elif isinstance(op.labels, typing.Iterable):
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
else:
left_adj: sge.Expression = (
adj if this_bin == 0 and op.right else sge.convert(0)
)
right_adj: sge.Expression = (
adj if this_bin == bins - 1 and not op.right else sge.convert(0)
)

left: sge.Expression = (
col_min + sge.convert(this_bin) * bin_width - left_adj
)
right: sge.Expression = (
col_min + sge.convert(this_bin + 1) * bin_width + right_adj
)
if op.right:
left_identifier = sge.Identifier(this="left_exclusive", quoted=True)
right_identifier = sge.Identifier(this="right_inclusive", quoted=True)
else:
left_identifier = sge.Identifier(this="left_inclusive", quoted=True)
right_identifier = sge.Identifier(this="right_exclusive", quoted=True)

value = sge.Struct(
expressions=[
sge.PropertyEQ(this=left_identifier, expression=left),
sge.PropertyEQ(this=right_identifier, expression=right),
]
)

        condition: sge.Expression
        if this_bin == bins - 1:
            # The last bin is a catch-all for every remaining non-NULL value.
            condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null()))
else:
if op.right:
condition = sge.LTE(
this=column.expr,
expression=(col_min + sge.convert(this_bin + 1) * bin_width),
)
else:
condition = sge.LT(
this=column.expr,
expression=(col_min + sge.convert(this_bin + 1) * bin_width),
)
case_expr = case_expr.when(condition, value)
return case_expr


def _cut_ops_w_intervals(
op: agg_ops.CutOp,
column: typed_expr.TypedExpr,
bins: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Case:
    """Builds a CASE expression for pandas.cut with explicit interval bins."""
    case_expr = sge.Case()
for this_bin, interval in enumerate(bins):
left: sge.Expression = ir._literal(
interval[0], dtypes.infer_literal_type(interval[0])
)
right: sge.Expression = ir._literal(
interval[1], dtypes.infer_literal_type(interval[1])
)
condition: sge.Expression
if op.right:
condition = sge.And(
this=sge.GT(this=column.expr, expression=left),
expression=sge.LTE(this=column.expr, expression=right),
)
else:
condition = sge.And(
this=sge.GTE(this=column.expr, expression=left),
expression=sge.LT(this=column.expr, expression=right),
)

value: sge.Expression
if op.labels is False:
value = ir._literal(this_bin, dtypes.INT_DTYPE)
elif isinstance(op.labels, typing.Iterable):
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
else:
if op.right:
left_identifier = sge.Identifier(this="left_exclusive", quoted=True)
right_identifier = sge.Identifier(this="right_inclusive", quoted=True)
else:
left_identifier = sge.Identifier(this="left_inclusive", quoted=True)
right_identifier = sge.Identifier(this="right_exclusive", quoted=True)

value = sge.Struct(
expressions=[
sge.PropertyEQ(this=left_identifier, expression=left),
sge.PropertyEQ(this=right_identifier, expression=right),
]
)
case_expr = case_expr.when(condition, value)
return case_expr


@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
def _(
op: agg_ops.DateSeriesDiffOp,
@@ -0,0 +1,55 @@
WITH `bfcte_0` AS (
SELECT
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
THEN STRUCT(
(
MIN(`int64_col`) OVER () + (
0 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
) - (
(
MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER ()
) * 0.001
) AS `left_exclusive`,
MIN(`int64_col`) OVER () + (
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
) + 0 AS `right_inclusive`
)
WHEN `int64_col` <= MIN(`int64_col`) OVER () + (
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
THEN STRUCT(
(
MIN(`int64_col`) OVER () + (
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
) - 0 AS `left_exclusive`,
MIN(`int64_col`) OVER () + (
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
) + 0 AS `right_inclusive`
)
WHEN `int64_col` IS NOT NULL
THEN STRUCT(
(
MIN(`int64_col`) OVER () + (
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
) - 0 AS `left_exclusive`,
MIN(`int64_col`) OVER () + (
3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
) + 0 AS `right_inclusive`
)
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `int_bins`
FROM `bfcte_1`
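
The CASE/STRUCT shape in the snapshot above is assembled from sqlglot expression nodes and then rendered through the BigQuery dialect. Below is a minimal standalone sketch (not part of this PR) of that rendering step; the bounds are hypothetical constants rather than the windowed MIN/MAX expressions the compiler emits.

import sqlglot.expressions as sge

# Build one STRUCT value with named fields, as the compiler does for unlabeled bins.
bucket = sge.Struct(
    expressions=[
        sge.PropertyEQ(
            this=sge.Identifier(this="left_exclusive", quoted=True),
            expression=sge.convert(0),
        ),
        sge.PropertyEQ(
            this=sge.Identifier(this="right_inclusive", quoted=True),
            expression=sge.convert(1),
        ),
    ]
)
# Attach it to a single WHEN branch and render BigQuery SQL.
case_expr = sge.Case().when(
    sge.LTE(this=sge.column("int64_col", quoted=True), expression=sge.convert(1)),
    bucket,
)
print(case_expr.sql(dialect="bigquery"))
# Roughly: CASE WHEN `int64_col` <= 1 THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`) END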
@@ -0,0 +1,24 @@
WITH `bfcte_0` AS (
SELECT
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
1 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
THEN 'a'
WHEN `int64_col` < MIN(`int64_col`) OVER () + (
2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
)
THEN 'b'
WHEN `int64_col` IS NOT NULL
THEN 'c'
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `int_bins_labels`
FROM `bfcte_1`
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
SELECT
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN `int64_col` > 0 AND `int64_col` <= 1
THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`)
WHEN `int64_col` > 1 AND `int64_col` <= 2
THEN STRUCT(1 AS `left_exclusive`, 2 AS `right_inclusive`)
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `interval_bins`
FROM `bfcte_1`
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
SELECT
`int64_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
CASE
WHEN `int64_col` > 0 AND `int64_col` <= 1
THEN 0
WHEN `int64_col` > 1 AND `int64_col` <= 2
THEN 1
END AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `interval_bins_labels`
FROM `bfcte_1`
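
Both interval snapshots mirror pandas.cut with explicit interval bins: values outside every interval hit no WHEN branch, and the missing ELSE yields NULL. A minimal sketch (not part of this PR) of the corresponding pandas behavior, with hypothetical sample data:

import pandas as pd

s = pd.Series([0, 1, 2, 3])
bins = pd.IntervalIndex.from_tuples([(0, 1), (1, 2)])  # closed on the right
print(pd.cut(s, bins=bins))
# 0 and 3 fall outside (0, 1] and (1, 2], so they map to NaN -- the same way the
# CASE expressions above return NULL for out-of-range rows.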
@@ -174,6 +174,35 @@ def test_count(scalar_types_df: bpd.DataFrame, snapshot):
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")


def test_cut(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
agg_ops_map = {
"int_bins": agg_exprs.UnaryAggregation(
agg_ops.CutOp(bins=3, right=True, labels=None), expression.deref(col_name)
),
"interval_bins": agg_exprs.UnaryAggregation(
agg_ops.CutOp(bins=((0, 1), (1, 2)), right=True, labels=None),
expression.deref(col_name),
),
"int_bins_labels": agg_exprs.UnaryAggregation(
agg_ops.CutOp(bins=3, labels=("a", "b", "c"), right=False),
expression.deref(col_name),
),
"interval_bins_labels": agg_exprs.UnaryAggregation(
agg_ops.CutOp(bins=((0, 1), (1, 2)), labels=False, right=True),
expression.deref(col_name),
),
}
window = window_spec.WindowSpec()

    # Compile each aggregation and compare it against its SQL snapshot.
    for test_name, agg_expr in agg_ops_map.items():
        sql = _apply_unary_window_op(bf_df, agg_expr, window, test_name)
        snapshot.assert_match(sql, f"{test_name}.sql")


def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
col_name = "int64_col"
bf_df = scalar_types_df[[col_name]]
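
For context, a hedged end-to-end sketch (not part of this PR) of the user-facing call this compilation path serves, assuming bigframes.pandas.cut mirrors pandas.cut for integer bins; the table and column names come from the test fixtures above.

import bigframes.pandas as bpd

df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")
# With the SQLGlot compiler backend, this should produce CASE/STRUCT SQL like the
# int_bins snapshot above.
buckets = bpd.cut(df["int64_col"], bins=3)
print(buckets.to_pandas())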