1515from __future__ import annotations
1616
1717import functools
18+ import typing
1819
1920import sqlglot .expressions as sge
2021
2930
3031@register_unary_op (ops .capitalize_op )
3132def _ (expr : TypedExpr ) -> sge .Expression :
32- return sge .Initcap (this = expr .expr )
33+ return sge .Initcap (this = expr .expr , expression = sge . convert ( "" ) )
3334
3435
3536@register_unary_op (ops .StrContainsOp , pass_op = True )
@@ -44,9 +45,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression:
4445
4546@register_unary_op (ops .StrExtractOp , pass_op = True )
4647def _ (expr : TypedExpr , op : ops .StrExtractOp ) -> sge .Expression :
47- return sge .RegexpExtract (
48- this = expr .expr , expression = sge .convert (op .pat ), group = sge .convert (op .n )
49- )
48+ # Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
49+ # capturing group.
50+ pat_expr = sge .convert (op .pat )
51+ if op .n != 0 :
52+ pat_expr = sge .func ("CONCAT" , sge .convert (".*?" ), pat_expr , sge .convert (".*" ))
53+ else :
54+ pat_expr = sge .func ("CONCAT" , sge .convert (".*?(" ), pat_expr , sge .convert (").*" ))
55+
56+ rex_replace = sge .func ("REGEXP_REPLACE" , expr .expr , pat_expr , sge .convert (r"\1" ))
57+ rex_contains = sge .func ("REGEXP_CONTAINS" , expr .expr , sge .convert (op .pat ))
58+ return sge .If (this = rex_contains , true = rex_replace , false = sge .null ())
5059
5160
5261@register_unary_op (ops .StrFindOp , pass_op = True )
@@ -75,47 +84,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression:
7584
7685@register_unary_op (ops .StrLstripOp , pass_op = True )
7786def _ (expr : TypedExpr , op : ops .StrLstripOp ) -> sge .Expression :
78- return sge .Trim (this = expr .expr , expression = sge .convert (op .to_strip ), side = "LEFT" )
87+ return sge .func ("LTRIM" , expr .expr , sge .convert (op .to_strip ))
88+
89+
90+ @register_unary_op (ops .StrRstripOp , pass_op = True )
91+ def _ (expr : TypedExpr , op : ops .StrRstripOp ) -> sge .Expression :
92+ return sge .func ("RTRIM" , expr .expr , sge .convert (op .to_strip ))
7993
8094
8195@register_unary_op (ops .StrPadOp , pass_op = True )
8296def _ (expr : TypedExpr , op : ops .StrPadOp ) -> sge .Expression :
83- pad_length = sge .func (
84- "GREATEST" , sge .Length (this = expr .expr ), sge .convert (op .length )
85- )
97+ expr_length = sge .Length (this = expr .expr )
98+ fillchar = sge .convert (op .fillchar )
99+ pad_length = sge .func ("GREATEST" , expr_length , sge .convert (op .length ))
100+
86101 if op .side == "left" :
87- return sge .func (
88- "LPAD" ,
89- expr .expr ,
90- pad_length ,
91- sge .convert (op .fillchar ),
92- )
102+ return sge .func ("LPAD" , expr .expr , pad_length , fillchar )
93103 elif op .side == "right" :
94- return sge .func (
95- "RPAD" ,
96- expr .expr ,
97- pad_length ,
98- sge .convert (op .fillchar ),
99- )
104+ return sge .func ("RPAD" , expr .expr , pad_length , fillchar )
100105 else : # side == both
101- lpad_amount = sge .Cast (
102- this = sge .func (
103- "SAFE_DIVIDE" ,
104- sge .Sub (this = pad_length , expression = sge .Length (this = expr .expr )),
105- sge .convert (2 ),
106- ),
107- to = "INT64" ,
108- ) + sge .Length (this = expr .expr )
106+ lpad_amount = (
107+ sge .Cast (
108+ this = sge .Floor (
109+ this = sge .func (
110+ "SAFE_DIVIDE" ,
111+ sge .Sub (this = pad_length , expression = expr_length ),
112+ sge .convert (2 ),
113+ )
114+ ),
115+ to = "INT64" ,
116+ )
117+ + expr_length
118+ )
109119 return sge .func (
110120 "RPAD" ,
111- sge .func (
112- "LPAD" ,
113- expr .expr ,
114- lpad_amount ,
115- sge .convert (op .fillchar ),
116- ),
121+ sge .func ("LPAD" , expr .expr , lpad_amount , fillchar ),
117122 pad_length ,
118- sge . convert ( op . fillchar ) ,
123+ fillchar ,
119124 )
120125
121126
@@ -224,11 +229,6 @@ def _(expr: TypedExpr) -> sge.Expression:
224229 return sge .func ("REVERSE" , expr .expr )
225230
226231
227- @register_unary_op (ops .StrRstripOp , pass_op = True )
228- def _ (expr : TypedExpr , op : ops .StrRstripOp ) -> sge .Expression :
229- return sge .Trim (this = expr .expr , expression = sge .convert (op .to_strip ), side = "RIGHT" )
230-
231-
232232@register_unary_op (ops .StartsWithOp , pass_op = True )
233233def _ (expr : TypedExpr , op : ops .StartsWithOp ) -> sge .Expression :
234234 if not op .pat :
@@ -253,26 +253,78 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253253
254254@register_unary_op (ops .StrGetOp , pass_op = True )
255255def _ (expr : TypedExpr , op : ops .StrGetOp ) -> sge .Expression :
256- return sge .Substring (
256+ sub_str = sge .Substring (
257257 this = expr .expr ,
258258 start = sge .convert (op .i + 1 ),
259259 length = sge .convert (1 ),
260260 )
261261
262+ return sge .If (
263+ this = sge .NEQ (this = sub_str , expression = sge .convert ("" )),
264+ true = sub_str ,
265+ false = sge .Null (),
266+ )
267+
262268
263269@register_unary_op (ops .StrSliceOp , pass_op = True )
264270def _ (expr : TypedExpr , op : ops .StrSliceOp ) -> sge .Expression :
265- start = op .start + 1 if op .start is not None else None
266- if op .end is None :
267- length = None
268- elif op .start is None :
269- length = op .end
271+ column_length = sge .Length (this = expr .expr )
272+ if op .start is None :
273+ start = 0
270274 else :
271- length = op .end - op .start
275+ start = op .start
276+
277+ start_expr = sge .convert (start ) if start < 0 else sge .convert (start + 1 )
278+ length_expr : typing .Optional [sge .Expression ]
279+ if op .end is None :
280+ length_expr = None
281+ elif op .end < 0 :
282+ if start < 0 :
283+ start_expr = sge .Greatest (
284+ expressions = [
285+ sge .convert (1 ),
286+ column_length + sge .convert (start + 1 ),
287+ ]
288+ )
289+ length_expr = sge .Greatest (
290+ expressions = [
291+ sge .convert (0 ),
292+ column_length + sge .convert (op .end ),
293+ ]
294+ ) - sge .Greatest (
295+ expressions = [
296+ sge .convert (0 ),
297+ column_length + sge .convert (start ),
298+ ]
299+ )
300+ else :
301+ length_expr = sge .Greatest (
302+ expressions = [
303+ sge .convert (0 ),
304+ column_length + sge .convert (op .end - start ),
305+ ]
306+ )
307+ else : # op.end >= 0
308+ if start < 0 :
309+ start_expr = sge .Greatest (
310+ expressions = [
311+ sge .convert (1 ),
312+ column_length + sge .convert (start + 1 ),
313+ ]
314+ )
315+ length_expr = sge .convert (op .end ) - sge .Greatest (
316+ expressions = [
317+ sge .convert (0 ),
318+ column_length + sge .convert (start ),
319+ ]
320+ )
321+ else :
322+ length_expr = sge .convert (op .end - start )
323+
272324 return sge .Substring (
273325 this = expr .expr ,
274- start = sge . convert ( start ) if start is not None else None ,
275- length = sge . convert ( length ) if length is not None else None ,
326+ start = start_expr ,
327+ length = length_expr ,
276328 )
277329
278330
0 commit comments