@@ -174,7 +174,7 @@ def from_query_string(
174174 alias = cte_name ,
175175 )
176176 select_expr = sge .Select ().select (sge .Star ()).from_ (sge .Table (this = cte_name ))
177- select_expr . set ( "with" , sge . With ( expressions = [cte ]) )
177+ select_expr = _set_query_ctes ( select_expr , [cte ])
178178 return cls (expr = select_expr , uid_gen = uid_gen )
179179
180180 @classmethod
@@ -197,7 +197,8 @@ def from_union(
197197 ), f"All provided expressions must be of type sge.Select, but got { type (select )} "
198198
199199 select_expr = select .copy ()
200- existing_ctes = [* existing_ctes , * select_expr .args .pop ("with" , [])]
200+ select_expr , select_ctes = _pop_query_ctes (select_expr )
201+ existing_ctes = [* existing_ctes , * select_ctes ]
201202
202203 new_cte_name = sge .to_identifier (
203204 next (uid_gen .get_uid_stream ("bfcte_" )), quoted = cls .quoted
@@ -229,7 +230,7 @@ def from_union(
229230 ),
230231 )
231232 final_select_expr = sge .Select ().select (sge .Star ()).from_ (union_expr .subquery ())
232- final_select_expr . set ( "with" , sge . With ( expressions = existing_ctes ) )
233+ final_select_expr = _set_query_ctes ( final_select_expr , existing_ctes )
233234 return cls (expr = final_select_expr , uid_gen = uid_gen )
234235
235236 def select (
@@ -336,8 +337,8 @@ def join(
336337 left_select = _select_to_cte (self .expr , left_cte_name )
337338 right_select = _select_to_cte (right .expr , right_cte_name )
338339
339- left_ctes = left_select . args . pop ( "with" , [] )
340- right_ctes = right_select . args . pop ( "with" , [] )
340+ left_select , left_ctes = _pop_query_ctes ( left_select )
341+ right_select , right_ctes = _pop_query_ctes ( right_select )
341342 merged_ctes = [* left_ctes , * right_ctes ]
342343
343344 join_on = _and (
@@ -353,7 +354,7 @@ def join(
353354 .from_ (sge .Table (this = left_cte_name ))
354355 .join (sge .Table (this = right_cte_name ), on = join_on , join_type = join_type_str )
355356 )
356- new_expr . set ( "with" , sge . With ( expressions = merged_ctes ) )
357+ new_expr = _set_query_ctes ( new_expr , merged_ctes )
357358
358359 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
359360
@@ -373,8 +374,8 @@ def isin_join(
373374 # Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
374375 right_select = right .expr
375376
376- left_ctes = left_select . args . pop ( "with" , [] )
377- right_ctes = right_select . args . pop ( "with" , [] )
377+ left_select , left_ctes = _pop_query_ctes ( left_select )
378+ right_select , right_ctes = _pop_query_ctes ( right_select )
378379 merged_ctes = [* left_ctes , * right_ctes ]
379380
380381 left_condition = typed_expr .TypedExpr (
@@ -415,7 +416,7 @@ def isin_join(
415416 .select (sge .Column (this = sge .Star (), table = left_cte_name ), new_column )
416417 .from_ (sge .Table (this = left_cte_name ))
417418 )
418- new_expr . set ( "with" , sge . With ( expressions = merged_ctes ) )
419+ new_expr = _set_query_ctes ( new_expr , merged_ctes )
419420
420421 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
421422
@@ -625,14 +626,13 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
625626 into a new CTE and then generates a 'SELECT * FROM new_cte_name'
626627 for the new query."""
627628 select_expr = expr .copy ()
628- existing_ctes = select_expr . args . pop ( "with" , [] )
629+ select_expr , existing_ctes = _pop_query_ctes ( select_expr )
629630 new_cte = sge .CTE (
630631 this = select_expr ,
631632 alias = cte_name ,
632633 )
633- new_with_clause = sge .With (expressions = [* existing_ctes , new_cte ])
634634 new_select_expr = sge .Select ().select (sge .Star ()).from_ (sge .Table (this = cte_name ))
635- new_select_expr . set ( "with" , new_with_clause )
635+ new_select_expr = _set_query_ctes ( new_select_expr , [ * existing_ctes , new_cte ] )
636636 return new_select_expr
637637
638638
@@ -788,3 +788,34 @@ def _join_condition_for_numeric(
788788 this = sge .EQ (this = left_2 , expression = right_2 ),
789789 expression = sge .EQ (this = left_3 , expression = right_3 ),
790790 )
791+
792+
793+ def _set_query_ctes (
794+ expr : sge .Select ,
795+ ctes : list [sge .CTE ],
796+ ) -> sge .Select :
797+ """Sets the CTEs of a given sge.Select expression."""
798+ new_expr = expr .copy ()
799+ with_expr = sge .With (expressions = ctes ) if len (ctes ) > 0 else None
800+
801+ if "with" in new_expr .arg_types .keys ():
802+ new_expr .set ("with" , with_expr )
803+ elif "with_" in new_expr .arg_types .keys ():
804+ new_expr .set ("with_" , with_expr )
805+ else :
806+ raise ValueError ("The expression does not support CTEs." )
807+ return new_expr
808+
809+
810+ def _pop_query_ctes (
811+ expr : sge .Select ,
812+ ) -> tuple [sge .Select , list [sge .CTE ]]:
813+ """Pops the CTEs of a given sge.Select expression."""
814+ if "with" in expr .arg_types .keys ():
815+ expr_ctes = expr .args .pop ("with" , [])
816+ return expr , expr_ctes
817+ elif "with_" in expr .arg_types .keys ():
818+ expr_ctes = expr .args .pop ("with_" , [])
819+ return expr , expr_ctes
820+ else :
821+ raise ValueError ("The expression does not support CTEs." )
0 commit comments