diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py
index 9782ef11d4..7572a1e801 100644
--- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py
+++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+import sqlglot as sg
 import sqlglot.expressions as sge
 
 from bigframes import dtypes
@@ -80,6 +81,24 @@ def _(expr: TypedExpr) -> sge.Expression:
     return sge.BitwiseNot(this=sge.paren(expr.expr))
 
 
+@register_nary_op(ops.SqlScalarOp, pass_op=True)
+def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
+    """Compile a SqlScalarOp by filling its SQL template with the operands.
+
+    Each operand is rendered to BigQuery SQL and substituted positionally with
+    ``str.format``; the resulting text is parsed back into an expression.
+    """
+    # NOTE(review): operands are spliced in as plain SQL text with no added
+    # parentheses, so the template itself must parenthesize a placeholder
+    # wherever operator precedence matters.
+    return sg.parse_one(
+        op.sql_template.format(
+            *[operand.expr.sql(dialect="bigquery") for operand in operands]
+        ),
+        dialect="bigquery",
+    )
+
+
 @register_unary_op(ops.isnull_op)
 def _(expr: TypedExpr) -> sge.Expression:
     return sge.Is(this=expr.expr, expression=sge.Null())
diff --git a/bigframes/core/compile/sqlglot/expressions/struct_ops.py b/bigframes/core/compile/sqlglot/expressions/struct_ops.py
index ebd3a38397..b6ec101eb1 100644
--- a/bigframes/core/compile/sqlglot/expressions/struct_ops.py
+++ b/bigframes/core/compile/sqlglot/expressions/struct_ops.py
@@ -24,6 +24,7 @@
 from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
 import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
 
+register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
 register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
 
 
@@ -40,3 +41,27 @@ def _(expr: TypedExpr, op: ops.StructFieldOp) -> sge.Expression:
         this=sge.to_identifier(name, quoted=True),
         catalog=expr.expr,
     )
+
+
+@register_nary_op(ops.StructOp, pass_op=True)
+def _(*exprs: TypedExpr, op: ops.StructOp) -> sge.Struct:
+    """Compile a StructOp into a STRUCT whose fields are the operands.
+
+    Fields are named after ``op.column_names`` and paired with the operands
+    by position.
+
+    Raises:
+        ValueError: If the operand count differs from the number of names.
+    """
+    # zip() stops at the shorter input, so a mismatch would silently drop
+    # fields; fail loudly instead.
+    if len(exprs) != len(op.column_names):
+        raise ValueError(
+            f"StructOp expects {len(op.column_names)} operands, got {len(exprs)}."
+        )
+    return sge.Struct(
+        expressions=[
+            sge.PropertyEQ(this=sge.to_identifier(col), expression=expr.expr)
+            for col, expr in zip(op.column_names, exprs)
+        ]
+    )
diff --git
a/bigframes/testing/utils.py b/bigframes/testing/utils.py
index b4daab7aad..a0bfc9e648 100644
--- a/bigframes/testing/utils.py
+++ b/bigframes/testing/utils.py
@@ -475,13 +475,38 @@ def _apply_binary_op(
 ) -> str:
     """Applies a binary op to the given DataFrame and return the SQL representing
     the resulting DataFrame."""
+    return _apply_nary_op(obj, op, l_arg, r_arg)
+
+
+def _apply_nary_op(
+    obj: bpd.DataFrame,
+    op: Union[ops.BinaryOp, ops.NaryOp],
+    *args: Union[str, ex.Expression],
+) -> str:
+    """Applies an n-ary op to the given DataFrame and returns the SQL
+    representing the resulting DataFrame.
+
+    The result column is renamed to the first argument, which therefore has
+    to be a column name.
+
+    Raises:
+        TypeError: If no arguments are given or the first one is not a string.
+    """
+    if not args or not isinstance(args[0], str):
+        raise TypeError(
+            "_apply_nary_op needs a column name (str) as its first argument."
+        )
+    result_name = args[0]
+
     array_value = obj._block.expr
-    op_expr = op.as_expr(l_arg, r_arg)
+    op_expr = op.as_expr(*args)
     result, col_ids = array_value.compute_values([op_expr])
 
     # Rename columns for deterministic golden SQL results.
     assert len(col_ids) == 1
-    result = result.rename_columns({col_ids[0]: l_arg}).select_columns([l_arg])
+    result = result.rename_columns({col_ids[0]: result_name}).select_columns(
+        [result_name]
+    )
 
     sql = result.session._executor.to_sql(result, enable_cache=False)
     return sql
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_sql_scalar_op/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_sql_scalar_op/out.sql
new file mode 100644
index 0000000000..a79e006885
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_sql_scalar_op/out.sql
@@ -0,0 +1,14 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_0`,
+    `bytes_col` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    CAST(`bfcol_0` AS INT64) + BYTE_LENGTH(`bfcol_1`) AS `bfcol_2`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_2` AS `bool_col`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_struct_ops/test_struct_op/out.sql
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_struct_ops/test_struct_op/out.sql
new file mode 100644
index 0000000000..f7f741a523
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_struct_ops/test_struct_op/out.sql
@@ -0,0 +1,21 @@
+WITH `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_0`,
+    `int64_col` AS `bfcol_1`,
+    `float64_col` AS `bfcol_2`,
+    `string_col` AS `bfcol_3`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_1` AS (
+  SELECT
+    *,
+    STRUCT(
+      `bfcol_0` AS bool_col,
+      `bfcol_1` AS int64_col,
+      `bfcol_2` AS float64_col,
+      `bfcol_3` AS string_col
+    ) AS `bfcol_4`
+  FROM `bfcte_0`
+)
+SELECT
+  `bfcol_4` AS `result_col`
+FROM `bfcte_1`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py
index b7abc63213..075416d664 100644
--- a/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py
+++ b/tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py
@@ -261,6 +261,15 @@ def test_notnull(scalar_types_df: bpd.DataFrame, snapshot):
     snapshot.assert_match(sql, "out.sql")
 
 
+def test_sql_scalar_op(scalar_types_df: bpd.DataFrame, snapshot):
+    """Golden-SQL check for a SqlScalarOp template with two operands."""
+    operand_cols = ["bool_col", "bytes_col"]
+    op = ops.SqlScalarOp(dtypes.INT_DTYPE, "CAST({0} AS INT64) + BYTE_LENGTH({1})")
+
+    sql = utils._apply_nary_op(scalar_types_df[operand_cols], op, *operand_cols)
+    snapshot.assert_match(sql, "out.sql")
+
+
 def test_map(scalar_types_df: bpd.DataFrame, snapshot):
     col_name = "string_col"
     bf_df = scalar_types_df[[col_name]]
diff --git a/tests/unit/core/compile/sqlglot/expressions/test_struct_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_struct_ops.py
index 19156ead99..7e67e44cd3 100644
--- a/tests/unit/core/compile/sqlglot/expressions/test_struct_ops.py
+++ b/tests/unit/core/compile/sqlglot/expressions/test_struct_ops.py
@@ -12,15 +12,39 @@
 # See the License for the specific language governing
permissions and
 # limitations under the License.
 
+import typing
+
 import pytest
 
 from bigframes import operations as ops
+from bigframes.core import expression as ex
 import bigframes.pandas as bpd
 from bigframes.testing import utils
 
 pytest.importorskip("pytest_snapshot")
 
 
+def _apply_nary_op(
+    obj: bpd.DataFrame,
+    op: ops.NaryOp,
+    *args: typing.Union[str, ex.Expression],
+) -> str:
+    """Applies an n-ary op to the given DataFrame and returns the SQL representing
+    the resulting DataFrame, with the result column renamed to "result_col"."""
+    array_value = obj._block.expr
+    op_expr = op.as_expr(*args)
+    result, col_ids = array_value.compute_values([op_expr])
+
+    # Rename columns for deterministic golden SQL results.
+    assert len(col_ids) == 1
+    result = result.rename_columns({col_ids[0]: "result_col"}).select_columns(
+        ["result_col"]
+    )
+
+    sql = result.session._executor.to_sql(result, enable_cache=False)
+    return sql
+
+
 def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
     col_name = "people"
     bf_df = nested_structs_types_df[[col_name]]
@@ -34,3 +58,11 @@ def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
     sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
 
     snapshot.assert_match(sql, "out.sql")
+
+
+def test_struct_op(scalar_types_df: bpd.DataFrame, snapshot):
+    bf_df = scalar_types_df[["bool_col", "int64_col", "float64_col", "string_col"]]
+    op = ops.StructOp(column_names=tuple(bf_df.columns.tolist()))
+    sql = _apply_nary_op(bf_df, op, *bf_df.columns.tolist())
+
+    snapshot.assert_match(sql, "out.sql")