Skip to content

Commit ed79146

Browse files
authored
Chore: Migrate unsafe_pow_op operator to SQLGlot (#2281)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes b/447388852 🦕
1 parent 3275761 commit ed79146

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,14 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
443443
)
444444

445445

446+
@register_binary_op(ops.unsafe_pow_op)
447+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
448+
"""For internal use only - where domain and overflow checks are not needed."""
449+
left_expr = _coerce_bool_to_int(left)
450+
right_expr = _coerce_bool_to_int(right)
451+
return sge.Pow(this=left_expr, expression=right_expr)
452+
453+
446454
@register_unary_op(numeric_ops.isnan_op)
447455
def isnan(arg: TypedExpr) -> sge.Expression:
448456
return sge.IsNan(this=arg.expr)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`,
4+
`float64_col`,
5+
`int64_col`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bool_col` AS `bfcol_3`,
11+
`int64_col` AS `bfcol_4`,
12+
`float64_col` AS `bfcol_5`,
13+
(
14+
`int64_col` >= 0
15+
) AND (
16+
`int64_col` <= 10
17+
) AS `bfcol_6`
18+
FROM `bfcte_0`
19+
), `bfcte_2` AS (
20+
SELECT
21+
*
22+
FROM `bfcte_1`
23+
WHERE
24+
`bfcol_6`
25+
), `bfcte_3` AS (
26+
SELECT
27+
*,
28+
POWER(`bfcol_4`, `bfcol_4`) AS `bfcol_14`,
29+
POWER(`bfcol_4`, `bfcol_5`) AS `bfcol_15`,
30+
POWER(`bfcol_5`, `bfcol_4`) AS `bfcol_16`,
31+
POWER(`bfcol_5`, `bfcol_5`) AS `bfcol_17`,
32+
POWER(`bfcol_4`, CAST(`bfcol_3` AS INT64)) AS `bfcol_18`,
33+
POWER(CAST(`bfcol_3` AS INT64), `bfcol_4`) AS `bfcol_19`
34+
FROM `bfcte_2`
35+
)
36+
SELECT
37+
`bfcol_14` AS `int_pow_int`,
38+
`bfcol_15` AS `int_pow_float`,
39+
`bfcol_16` AS `float_pow_int`,
40+
`bfcol_17` AS `float_pow_float`,
41+
`bfcol_18` AS `int_pow_bool`,
42+
`bfcol_19` AS `bool_pow_int`
43+
FROM `bfcte_3`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,36 @@ def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame):
438438

439439
with pytest.raises(TypeError):
440440
utils._apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col")
441+
442+
443+
def test_unsafe_pow_op(scalar_types_df: bpd.DataFrame, snapshot):
444+
# Choose certain row so the sql execution won't fail even with unsafe_pow_op.
445+
bf_df = scalar_types_df[
446+
(scalar_types_df["int64_col"] >= 0) & (scalar_types_df["int64_col"] <= 10)
447+
]
448+
bf_df = bf_df[["int64_col", "float64_col", "bool_col"]]
449+
450+
int64_col_id = bf_df["int64_col"]._value_column
451+
float64_col_id = bf_df["float64_col"]._value_column
452+
bool_col_id = bf_df["bool_col"]._value_column
453+
454+
sql = utils._apply_ops_to_sql(
455+
bf_df,
456+
[
457+
ops.unsafe_pow_op.as_expr(int64_col_id, int64_col_id),
458+
ops.unsafe_pow_op.as_expr(int64_col_id, float64_col_id),
459+
ops.unsafe_pow_op.as_expr(float64_col_id, int64_col_id),
460+
ops.unsafe_pow_op.as_expr(float64_col_id, float64_col_id),
461+
ops.unsafe_pow_op.as_expr(int64_col_id, bool_col_id),
462+
ops.unsafe_pow_op.as_expr(bool_col_id, int64_col_id),
463+
],
464+
[
465+
"int_pow_int",
466+
"int_pow_float",
467+
"float_pow_int",
468+
"float_pow_float",
469+
"int_pow_bool",
470+
"bool_pow_int",
471+
],
472+
)
473+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)