Skip to content

Commit 0b14b17

Browse files
authored
refactor: fix some string ops in the sqlglot compiler (part 2) (#2332)
This change aims to fix some string-related tests failing in #2248. Fixes internal issue 417774347🦕
1 parent 4d5de14 commit 0b14b17

File tree

8 files changed

+128
-56
lines changed

8 files changed

+128
-56
lines changed

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@
3030

3131
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
3232
def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
33+
if expr.dtype == dtypes.STRING_DTYPE:
34+
return _string_index(expr, op)
35+
3336
return sge.Bracket(
3437
this=expr.expr,
35-
expressions=[sge.Literal.number(op.index)],
38+
expressions=[sge.convert(op.index)],
3639
safe=True,
3740
offset=False,
3841
)
@@ -115,3 +118,16 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
115118
if typed_expr.dtype == dtypes.BOOL_DTYPE:
116119
return sge.Cast(this=typed_expr.expr, to="INT64")
117120
return typed_expr.expr
121+
122+
123+
def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
124+
sub_str = sge.Substring(
125+
this=expr.expr,
126+
start=sge.convert(op.index + 1),
127+
length=sge.convert(1),
128+
)
129+
return sge.If(
130+
this=sge.NEQ(this=sub_str, expression=sge.convert("")),
131+
true=sub_str,
132+
false=sge.Null(),
133+
)

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 101 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import functools
18+
import typing
1819

1920
import sqlglot.expressions as sge
2021

@@ -29,7 +30,7 @@
2930

3031
@register_unary_op(ops.capitalize_op)
3132
def _(expr: TypedExpr) -> sge.Expression:
32-
return sge.Initcap(this=expr.expr)
33+
return sge.Initcap(this=expr.expr, expression=sge.convert(""))
3334

3435

3536
@register_unary_op(ops.StrContainsOp, pass_op=True)
@@ -44,9 +45,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression:
4445

4546
@register_unary_op(ops.StrExtractOp, pass_op=True)
4647
def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression:
47-
return sge.RegexpExtract(
48-
this=expr.expr, expression=sge.convert(op.pat), group=sge.convert(op.n)
49-
)
48+
# Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
49+
# capturing group.
50+
pat_expr = sge.convert(op.pat)
51+
if op.n != 0:
52+
pat_expr = sge.func("CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*"))
53+
else:
54+
pat_expr = sge.func("CONCAT", sge.convert(".*?("), pat_expr, sge.convert(").*"))
55+
56+
rex_replace = sge.func("REGEXP_REPLACE", expr.expr, pat_expr, sge.convert(r"\1"))
57+
rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat))
58+
return sge.If(this=rex_contains, true=rex_replace, false=sge.null())
5059

5160

5261
@register_unary_op(ops.StrFindOp, pass_op=True)
@@ -75,47 +84,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression:
7584

7685
@register_unary_op(ops.StrLstripOp, pass_op=True)
7786
def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression:
78-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="LEFT")
87+
return sge.func("LTRIM", expr.expr, sge.convert(op.to_strip))
88+
89+
90+
@register_unary_op(ops.StrRstripOp, pass_op=True)
91+
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
92+
return sge.func("RTRIM", expr.expr, sge.convert(op.to_strip))
7993

8094

8195
@register_unary_op(ops.StrPadOp, pass_op=True)
8296
def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression:
83-
pad_length = sge.func(
84-
"GREATEST", sge.Length(this=expr.expr), sge.convert(op.length)
85-
)
97+
expr_length = sge.Length(this=expr.expr)
98+
fillchar = sge.convert(op.fillchar)
99+
pad_length = sge.func("GREATEST", expr_length, sge.convert(op.length))
100+
86101
if op.side == "left":
87-
return sge.func(
88-
"LPAD",
89-
expr.expr,
90-
pad_length,
91-
sge.convert(op.fillchar),
92-
)
102+
return sge.func("LPAD", expr.expr, pad_length, fillchar)
93103
elif op.side == "right":
94-
return sge.func(
95-
"RPAD",
96-
expr.expr,
97-
pad_length,
98-
sge.convert(op.fillchar),
99-
)
104+
return sge.func("RPAD", expr.expr, pad_length, fillchar)
100105
else: # side == both
101-
lpad_amount = sge.Cast(
102-
this=sge.func(
103-
"SAFE_DIVIDE",
104-
sge.Sub(this=pad_length, expression=sge.Length(this=expr.expr)),
105-
sge.convert(2),
106-
),
107-
to="INT64",
108-
) + sge.Length(this=expr.expr)
106+
lpad_amount = (
107+
sge.Cast(
108+
this=sge.Floor(
109+
this=sge.func(
110+
"SAFE_DIVIDE",
111+
sge.Sub(this=pad_length, expression=expr_length),
112+
sge.convert(2),
113+
)
114+
),
115+
to="INT64",
116+
)
117+
+ expr_length
118+
)
109119
return sge.func(
110120
"RPAD",
111-
sge.func(
112-
"LPAD",
113-
expr.expr,
114-
lpad_amount,
115-
sge.convert(op.fillchar),
116-
),
121+
sge.func("LPAD", expr.expr, lpad_amount, fillchar),
117122
pad_length,
118-
sge.convert(op.fillchar),
123+
fillchar,
119124
)
120125

121126

@@ -224,11 +229,6 @@ def _(expr: TypedExpr) -> sge.Expression:
224229
return sge.func("REVERSE", expr.expr)
225230

226231

227-
@register_unary_op(ops.StrRstripOp, pass_op=True)
228-
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
229-
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="RIGHT")
230-
231-
232232
@register_unary_op(ops.StartsWithOp, pass_op=True)
233233
def _(expr: TypedExpr, op: ops.StartsWithOp) -> sge.Expression:
234234
if not op.pat:
@@ -253,26 +253,78 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253253

254254
@register_unary_op(ops.StrGetOp, pass_op=True)
255255
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
256-
return sge.Substring(
256+
sub_str = sge.Substring(
257257
this=expr.expr,
258258
start=sge.convert(op.i + 1),
259259
length=sge.convert(1),
260260
)
261261

262+
return sge.If(
263+
this=sge.NEQ(this=sub_str, expression=sge.convert("")),
264+
true=sub_str,
265+
false=sge.Null(),
266+
)
267+
262268

263269
@register_unary_op(ops.StrSliceOp, pass_op=True)
264270
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
265-
start = op.start + 1 if op.start is not None else None
266-
if op.end is None:
267-
length = None
268-
elif op.start is None:
269-
length = op.end
271+
column_length = sge.Length(this=expr.expr)
272+
if op.start is None:
273+
start = 0
270274
else:
271-
length = op.end - op.start
275+
start = op.start
276+
277+
start_expr = sge.convert(start) if start < 0 else sge.convert(start + 1)
278+
length_expr: typing.Optional[sge.Expression]
279+
if op.end is None:
280+
length_expr = None
281+
elif op.end < 0:
282+
if start < 0:
283+
start_expr = sge.Greatest(
284+
expressions=[
285+
sge.convert(1),
286+
column_length + sge.convert(start + 1),
287+
]
288+
)
289+
length_expr = sge.Greatest(
290+
expressions=[
291+
sge.convert(0),
292+
column_length + sge.convert(op.end),
293+
]
294+
) - sge.Greatest(
295+
expressions=[
296+
sge.convert(0),
297+
column_length + sge.convert(start),
298+
]
299+
)
300+
else:
301+
length_expr = sge.Greatest(
302+
expressions=[
303+
sge.convert(0),
304+
column_length + sge.convert(op.end - start),
305+
]
306+
)
307+
else: # op.end >= 0
308+
if start < 0:
309+
start_expr = sge.Greatest(
310+
expressions=[
311+
sge.convert(1),
312+
column_length + sge.convert(start + 1),
313+
]
314+
)
315+
length_expr = sge.convert(op.end) - sge.Greatest(
316+
expressions=[
317+
sge.convert(0),
318+
column_length + sge.convert(start),
319+
]
320+
)
321+
else:
322+
length_expr = sge.convert(op.end - start)
323+
272324
return sge.Substring(
273325
this=expr.expr,
274-
start=sge.convert(start) if start is not None else None,
275-
length=sge.convert(length) if length is not None else None,
326+
start=start_expr,
327+
length=length_expr,
276328
)
277329

278330

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_capitalize/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
INITCAP(`string_col`) AS `bfcol_1`
8+
INITCAP(`string_col`, '') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_lstrip/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
TRIM(`string_col`, ' ') AS `bfcol_1`
8+
LTRIM(`string_col`, ' ') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_rstrip/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
TRIM(`string_col`, ' ') AS `bfcol_1`
8+
RTRIM(`string_col`, ' ') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
REGEXP_EXTRACT(`string_col`, '([a-z]*)') AS `bfcol_1`
8+
IF(
9+
REGEXP_CONTAINS(`string_col`, '([a-z]*)'),
10+
REGEXP_REPLACE(`string_col`, CONCAT('.*?', '([a-z]*)', '.*'), '\\1'),
11+
NULL
12+
) AS `bfcol_1`
913
FROM `bfcte_0`
1014
)
1115
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_get/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
SUBSTRING(`string_col`, 2, 1) AS `bfcol_1`
8+
IF(SUBSTRING(`string_col`, 2, 1) <> '', SUBSTRING(`string_col`, 2, 1), NULL) AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_pad/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
1010
RPAD(
1111
LPAD(
1212
`string_col`,
13-
CAST(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2) AS INT64) + LENGTH(`string_col`),
13+
CAST(FLOOR(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2)) AS INT64) + LENGTH(`string_col`),
1414
'-'
1515
),
1616
GREATEST(LENGTH(`string_col`), 10),

0 commit comments

Comments
 (0)