diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 7ff09ab3f6..e44a1b5c1d 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -107,6 +107,7 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression: sge.If(this=sge.convert(key), true=sge.convert(value)) for key, value in op.mappings ], + default=expr.expr, ) diff --git a/bigframes/core/compile/sqlglot/expressions/string_ops.py b/bigframes/core/compile/sqlglot/expressions/string_ops.py index bdc4808302..3e19a2fe33 100644 --- a/bigframes/core/compile/sqlglot/expressions/string_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/string_ops.py @@ -18,6 +18,7 @@ import sqlglot.expressions as sge +from bigframes import dtypes from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler @@ -195,6 +196,9 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.len_op) def _(expr: TypedExpr) -> sge.Expression: + if dtypes.is_array_like(expr.dtype): + return sge.func("ARRAY_LENGTH", expr.expr) + return sge.Length(this=expr.expr) @@ -239,7 +243,7 @@ def to_startswith(pat: str) -> sge.Expression: @register_unary_op(ops.StrStripOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrStripOp) -> sge.Expression: - return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr) + return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip)) @register_unary_op(ops.StringSplitOp, pass_op=True) @@ -284,27 +288,29 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_unary_op(ops.ZfillOp, pass_op=True) def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression: + length_expr = sge.Greatest( + expressions=[sge.Length(this=expr.expr), sge.convert(op.width)] + ) return sge.Case( ifs=[ sge.If( - this=sge.EQ( - this=sge.Substring( - this=expr.expr, start=sge.convert(1), length=sge.convert(1) - ), - expression=sge.convert("-"), + this=sge.func( + "STARTS_WITH", + expr.expr, + sge.convert("-"), ), true=sge.Concat( expressions=[ sge.convert("-"), sge.func( "LPAD", - sge.Substring(this=expr.expr, start=sge.convert(1)), - sge.convert(op.width - 1), + sge.Substring(this=expr.expr, start=sge.convert(2)), + length_expr - 1, sge.convert("0"), ), ] ), ) ], - default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")), + default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")), ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql index 52a3174cf9..22628c6a4b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - CASE `string_col` WHEN 'value1' THEN 'mapped1' END AS `bfcol_1` + CASE `string_col` WHEN 'value1' THEN 'mapped1' ELSE `string_col` END AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_len_w_array/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_len_w_array/out.sql new file mode 100644 index 0000000000..609c4131e6 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_len_w_array/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int_list_col` + FROM `bigframes-dev`.`sqlglot_test`.`repeated_types` +), `bfcte_1` AS ( + SELECT + *, + ARRAY_LENGTH(`int_list_col`) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `int_list_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_strip/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_strip/out.sql index 771bb9c49f..ebe4c39bbf 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_strip/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_strip/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - TRIM(' ', `string_col`) AS `bfcol_1` + TRIM(`string_col`, ' ') AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_zfill/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_zfill/out.sql index 97651ece49..79c4f695aa 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_zfill/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_zfill/out.sql @@ -6,9 +6,9 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN SUBSTRING(`string_col`, 1, 1) = '-' - THEN CONCAT('-', LPAD(SUBSTRING(`string_col`, 1), 9, '0')) - ELSE LPAD(`string_col`, 10, '0') + WHEN STARTS_WITH(`string_col`, '-') + THEN CONCAT('-', LPAD(SUBSTRING(`string_col`, 2), GREATEST(LENGTH(`string_col`), 10) - 1, '0')) + ELSE LPAD(`string_col`, GREATEST(LENGTH(`string_col`), 10), '0') END AS `bfcol_1` FROM `bfcte_0` ) diff --git a/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py index b20c038ed0..d1856b259d 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py @@ -120,6 +120,14 @@ def test_len(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_len_w_array(repeated_types_df: bpd.DataFrame, snapshot): + col_name = "int_list_col" + bf_df = repeated_types_df[[col_name]] + sql = utils._apply_ops_to_sql(bf_df, [ops.len_op.as_expr(col_name)], [col_name]) + + snapshot.assert_match(sql, "out.sql") + + def test_lower(scalar_types_df: bpd.DataFrame, snapshot): col_name = "string_col" bf_df = scalar_types_df[[col_name]]