Skip to content

Commit e39dfe2

Browse files
authored
deps: fix missing CTEs in sqlglot v28 and relax version dependency (#2277)
1 parent 6e73d77 commit e39dfe2

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def from_query_string(
174174
alias=cte_name,
175175
)
176176
select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
177-
select_expr.set("with", sge.With(expressions=[cte]))
177+
select_expr = _set_query_ctes(select_expr, [cte])
178178
return cls(expr=select_expr, uid_gen=uid_gen)
179179

180180
@classmethod
@@ -197,7 +197,8 @@ def from_union(
197197
), f"All provided expressions must be of type sge.Select, but got {type(select)}"
198198

199199
select_expr = select.copy()
200-
existing_ctes = [*existing_ctes, *select_expr.args.pop("with", [])]
200+
select_expr, select_ctes = _pop_query_ctes(select_expr)
201+
existing_ctes = [*existing_ctes, *select_ctes]
201202

202203
new_cte_name = sge.to_identifier(
203204
next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
@@ -229,7 +230,7 @@ def from_union(
229230
),
230231
)
231232
final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery())
232-
final_select_expr.set("with", sge.With(expressions=existing_ctes))
233+
final_select_expr = _set_query_ctes(final_select_expr, existing_ctes)
233234
return cls(expr=final_select_expr, uid_gen=uid_gen)
234235

235236
def select(
@@ -336,8 +337,8 @@ def join(
336337
left_select = _select_to_cte(self.expr, left_cte_name)
337338
right_select = _select_to_cte(right.expr, right_cte_name)
338339

339-
left_ctes = left_select.args.pop("with", [])
340-
right_ctes = right_select.args.pop("with", [])
340+
left_select, left_ctes = _pop_query_ctes(left_select)
341+
right_select, right_ctes = _pop_query_ctes(right_select)
341342
merged_ctes = [*left_ctes, *right_ctes]
342343

343344
join_on = _and(
@@ -353,7 +354,7 @@ def join(
353354
.from_(sge.Table(this=left_cte_name))
354355
.join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
355356
)
356-
new_expr.set("with", sge.With(expressions=merged_ctes))
357+
new_expr = _set_query_ctes(new_expr, merged_ctes)
357358

358359
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
359360

@@ -373,8 +374,8 @@ def isin_join(
373374
# Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
374375
right_select = right.expr
375376

376-
left_ctes = left_select.args.pop("with", [])
377-
right_ctes = right_select.args.pop("with", [])
377+
left_select, left_ctes = _pop_query_ctes(left_select)
378+
right_select, right_ctes = _pop_query_ctes(right_select)
378379
merged_ctes = [*left_ctes, *right_ctes]
379380

380381
left_condition = typed_expr.TypedExpr(
@@ -415,7 +416,7 @@ def isin_join(
415416
.select(sge.Column(this=sge.Star(), table=left_cte_name), new_column)
416417
.from_(sge.Table(this=left_cte_name))
417418
)
418-
new_expr.set("with", sge.With(expressions=merged_ctes))
419+
new_expr = _set_query_ctes(new_expr, merged_ctes)
419420

420421
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
421422

@@ -625,14 +626,13 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
625626
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
626627
for the new query."""
627628
select_expr = expr.copy()
628-
existing_ctes = select_expr.args.pop("with", [])
629+
select_expr, existing_ctes = _pop_query_ctes(select_expr)
629630
new_cte = sge.CTE(
630631
this=select_expr,
631632
alias=cte_name,
632633
)
633-
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
634634
new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
635-
new_select_expr.set("with", new_with_clause)
635+
new_select_expr = _set_query_ctes(new_select_expr, [*existing_ctes, new_cte])
636636
return new_select_expr
637637

638638

@@ -788,3 +788,34 @@ def _join_condition_for_numeric(
788788
this=sge.EQ(this=left_2, expression=right_2),
789789
expression=sge.EQ(this=left_3, expression=right_3),
790790
)
791+
792+
793+
def _set_query_ctes(
794+
expr: sge.Select,
795+
ctes: list[sge.CTE],
796+
) -> sge.Select:
797+
"""Sets the CTEs of a given sge.Select expression."""
798+
new_expr = expr.copy()
799+
with_expr = sge.With(expressions=ctes) if len(ctes) > 0 else None
800+
801+
if "with" in new_expr.arg_types.keys():
802+
new_expr.set("with", with_expr)
803+
elif "with_" in new_expr.arg_types.keys():
804+
new_expr.set("with_", with_expr)
805+
else:
806+
raise ValueError("The expression does not support CTEs.")
807+
return new_expr
808+
809+
810+
def _pop_query_ctes(
811+
expr: sge.Select,
812+
) -> tuple[sge.Select, list[sge.CTE]]:
813+
"""Pops the CTEs of a given sge.Select expression."""
814+
if "with" in expr.arg_types.keys():
815+
expr_ctes = expr.args.pop("with", [])
816+
return expr, expr_ctes
817+
elif "with_" in expr.arg_types.keys():
818+
expr_ctes = expr.args.pop("with_", [])
819+
return expr, expr_ctes
820+
else:
821+
raise ValueError("The expression does not support CTEs.")

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
"requests >=2.27.1",
5656
"shapely >=1.8.5",
5757
# 25.20.0 introduces this fix https://github.com/TobikoData/sqlmesh/issues/3095 for rtrim/ltrim.
58-
"sqlglot >=25.20.0, <28.0.0",
58+
"sqlglot >=25.20.0",
5959
"tabulate >=0.9",
6060
"ipywidgets >=7.7.1",
6161
"humanize >=4.6.0",

0 commit comments

Comments
 (0)