|
17 | 17 | import dataclasses
|
18 | 18 | import typing
|
19 | 19 |
|
| 20 | +from google.cloud import bigquery |
20 | 21 | import pyarrow as pa
|
21 | 22 | import sqlglot as sg
|
22 | 23 | import sqlglot.dialects.bigquery
|
@@ -104,6 +105,24 @@ def from_pyarrow(
|
104 | 105 | )
|
105 | 106 | return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)
|
106 | 107 |
|
@classmethod
def from_query_string(
    cls,
    query_string: str,
) -> SQLGlotIR:
    """Wrap a raw SQL query string in a CTE and select everything from it.

    Args:
        query_string: SQL text placed verbatim as the body of the CTE; it is
            handed to SQLGlot as a plain string rather than a parsed tree.

    Returns:
        A new instance whose expression is ``SELECT *`` from the generated
        CTE, carrying the UID generator that produced the CTE name.
    """
    uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()

    # The first id drawn from the "bfcte_" stream becomes the CTE alias.
    cte_uid = next(uid_gen.get_uid_stream("bfcte_"))
    cte_name = sge.to_identifier(cte_uid, quoted=cls.quoted)

    with_clause = sge.With(
        expressions=[sge.CTE(this=query_string, alias=cte_name)],
    )

    star_select = sge.Select().select(sge.Star())
    select_expr = star_select.from_(sge.Table(this=cte_name))
    select_expr.set("with", with_clause)
    return cls(expr=select_expr, uid_gen=uid_gen)
| 125 | + |
107 | 126 | def select(
|
108 | 127 | self,
|
109 | 128 | selected_cols: tuple[tuple[str, sge.Expression], ...],
|
@@ -133,6 +152,36 @@ def project(
|
133 | 152 | select_expr = self.expr.select(*projected_cols_expr, append=True)
|
134 | 153 | return SQLGlotIR(expr=select_expr)
|
135 | 154 |
|
def insert(
    self,
    destination: bigquery.TableReference,
) -> str:
    """Render an INSERT statement that writes this expression into a table.

    Args:
        destination: Reference to the table that receives the rows.

    Returns:
        The SQL text, generated with this instance's ``dialect`` and
        ``pretty`` settings.
    """
    # The current expression is used as a subquery feeding the INSERT.
    source_subquery = self.expr.subquery()
    target_table = _table(destination)
    insert_expr = sge.insert(source_subquery, target_table)
    return insert_expr.sql(dialect=self.dialect, pretty=self.pretty)
| 162 | + |
def replace(
    self,
    destination: bigquery.TableReference,
) -> str:
    """Render a MERGE statement targeting ``destination``.

    The statement merges this expression (as a subquery) into the
    destination with the condition built by ``_literal(False, BOOL_DTYPE)``,
    followed by two WHEN clauses: delete when not matched by source, and
    ``INSERT ROW`` when not matched. Presumably this replaces the table's
    contents with the query result -- confirm against the callers.

    Args:
        destination: Reference to the table being merged into.

    Returns:
        The MERGE SQL text followed by the WHEN clauses, one per line.
    """
    # Workaround for SQLGlot breaking change:
    # https://github.com/tobymao/sqlglot/pull/4495
    # The WHEN clauses are rendered separately and appended as text instead
    # of being attached to the Merge node.
    delete_unmatched_by_source = sge.When(
        matched=False, source=True, then=sge.Delete()
    )
    insert_unmatched = sge.When(
        matched=False, then=sge.Insert(this=sge.Var(this="ROW"))
    )
    rendered_parts = []
    for clause in (delete_unmatched_by_source, insert_unmatched):
        rendered_parts.append(clause.sql(dialect=self.dialect, pretty=self.pretty))

    merge_expr = sge.Merge(
        this=_table(destination),
        using=self.expr.subquery(),
        on=_literal(False, dtypes.BOOL_DTYPE),
    )
    # The MERGE header goes first, then each WHEN clause on its own line.
    rendered_parts.insert(
        0, merge_expr.sql(dialect=self.dialect, pretty=self.pretty)
    )
    return "\n".join(rendered_parts)
| 184 | + |
136 | 185 | def _encapsulate_as_cte(
|
137 | 186 | self,
|
138 | 187 | ) -> sge.Select:
|
@@ -190,3 +239,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
|
190 | 239 |
|
def _cast(arg: typing.Any, to: str) -> sge.Cast:
    """Build a SQLGlot ``Cast`` node with ``arg`` as its operand and ``to`` as its target."""
    cast_node = sge.Cast(this=arg, to=to)
    return cast_node
|
| 242 | + |
| 243 | + |
def _table(table: bigquery.TableReference) -> sge.Table:
    """Convert a BigQuery table reference into a SQLGlot ``Table`` node.

    The table id, dataset id and project become the table name, db and
    catalog respectively, each as a quoted identifier.
    """
    name, db, catalog = (
        sg.to_identifier(part, quoted=True)
        for part in (table.table_id, table.dataset_id, table.project)
    )
    return sge.Table(this=name, db=db, catalog=catalog)
0 commit comments