Commit d8ce12f
fix: avoid table modification for to_gbq
1 parent d937be0 commit d8ce12f

File tree

4 files changed: +198 −51 lines changed
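In short: when the destination table already exists and its schema matches the DataFrame being written, `to_gbq` with `if_exists="append"` or `if_exists="replace"` now runs DML (an INSERT or a MERGE) against the table instead of a query job with a write disposition, so the table itself is never dropped, truncated, or recreated. A usage sketch of the affected API; the project, dataset, and table names are placeholders:

import bigframes.pandas as bpd

df = bpd.DataFrame({"int64_col": [1, 2], "string_col": ["a", "b"]})

# "fail" is the default: raises ValueError if the table already exists.
df.to_gbq("my-project.my_dataset.my_table")

# With a matching schema, these now compile to DML instead of a
# WRITE_APPEND / WRITE_TRUNCATE load-style disposition.
df.to_gbq("my-project.my_dataset.my_table", if_exists="append")
df.to_gbq("my-project.my_dataset.my_table", if_exists="replace")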

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 57 additions & 0 deletions
@@ -17,6 +17,7 @@
 import dataclasses
 import typing
 
+from google.cloud import bigquery
 import pyarrow as pa
 import sqlglot as sg
 import sqlglot.dialects.bigquery
@@ -104,6 +105,24 @@ def from_pyarrow(
         )
         return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)
 
+    @classmethod
+    def from_query_string(
+        cls,
+        query_string: str,
+    ) -> SQLGlotIR:
+        """Builds SQLGlot expression from a query string"""
+        uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
+        cte_name = sge.to_identifier(
+            next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
+        )
+        cte = sge.CTE(
+            this=query_string,
+            alias=cte_name,
+        )
+        select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
+        select_expr.set("with", sge.With(expressions=[cte]))
+        return cls(expr=select_expr, uid_gen=uid_gen)
+
     def select(
         self,
         selected_cols: tuple[tuple[str, sge.Expression], ...],
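`from_query_string` wraps an arbitrary SQL string in a named CTE and selects from it, so downstream steps can treat the query like any other relation. A standalone sketch of the same wrapping, assuming the UID generator yields `bfcte_0`:

import sqlglot.expressions as sge

# Mirrors from_query_string outside the class; "bfcte_0" stands in for
# whatever name the sequential UID generator actually yields.
cte_name = sge.to_identifier("bfcte_0", quoted=True)
cte = sge.CTE(this="SELECT 1 AS x", alias=cte_name)
select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
select_expr.set("with", sge.With(expressions=[cte]))
print(select_expr.sql(dialect="bigquery"))
# Roughly: WITH `bfcte_0` AS (SELECT 1 AS x) SELECT * FROM `bfcte_0`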
@@ -133,6 +152,36 @@ def project(
         select_expr = self.expr.select(*projected_cols_expr, append=True)
         return SQLGlotIR(expr=select_expr)
 
+    def insert(
+        self,
+        destination: bigquery.TableReference,
+    ) -> str:
+        return sge.insert(self.expr.subquery(), _table(destination)).sql(
+            dialect=self.dialect, pretty=self.pretty
+        )
+
+    def replace(
+        self,
+        destination: bigquery.TableReference,
+    ) -> str:
+        # Workaround for SQLGlot breaking change:
+        # https://github.com/tobymao/sqlglot/pull/4495
+        whens_expr = [
+            sge.When(matched=False, source=True, then=sge.Delete()),
+            sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))),
+        ]
+        whens_str = "\n".join(
+            when_expr.sql(dialect=self.dialect, pretty=self.pretty)
+            for when_expr in whens_expr
+        )
+
+        merge_str = sge.Merge(
+            this=_table(destination),
+            using=self.expr.subquery(),
+            on=_literal(False, dtypes.BOOL_DTYPE),
+        ).sql(dialect=self.dialect, pretty=self.pretty)
+        return f"{merge_str}\n{whens_str}"
+
     def _encapsulate_as_cte(
         self,
     ) -> sge.Select:
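`replace` renders the WHEN clauses separately and concatenates them onto the MERGE as text, working around the SQLGlot change linked above. A sketch of the statement shape it aims to produce; the table name and source subquery are placeholders:

# Because the ON condition is literally FALSE, no rows ever match: every
# existing target row is "not matched by source" and gets deleted, and every
# source row is "not matched" and gets inserted, i.e. an in-place replace.
expected_shape = """\
MERGE INTO `proj`.`dataset`.`tbl`
USING (WITH `bfcte_0` AS (<dataframe sql>) SELECT * FROM `bfcte_0`) ON FALSE
WHEN NOT MATCHED BY SOURCE THEN DELETE
WHEN NOT MATCHED THEN INSERT ROW
"""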
@@ -190,3 +239,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
 
 def _cast(arg: typing.Any, to: str) -> sge.Cast:
     return sge.Cast(this=arg, to=to)
+
+
+def _table(table: bigquery.TableReference) -> sge.Table:
+    return sge.Table(
+        this=sg.to_identifier(table.table_id, quoted=True),
+        db=sg.to_identifier(table.dataset_id, quoted=True),
+        catalog=sg.to_identifier(table.project, quoted=True),
+    )
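A quick check of the new `_table` helper's output, assuming it is in scope and using a hypothetical table reference:

from google.cloud import bigquery

ref = bigquery.TableReference.from_string("my-project.my_dataset.my_table")
print(_table(ref).sql(dialect="bigquery"))
# Roughly: `my-project`.`my_dataset`.`my_table`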

bigframes/session/bq_caching_executor.py

Lines changed: 55 additions & 10 deletions
@@ -29,9 +29,11 @@
 
 import bigframes.core
 from bigframes.core import compile, rewrite
+import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir
 import bigframes.core.guid
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as order
+import bigframes.core.schema as schemata
 import bigframes.core.tree_properties as tree_properties
 import bigframes.dtypes
 import bigframes.exceptions as bfe
@@ -206,17 +208,45 @@ def export_gbq(
         if bigframes.options.compute.enable_multi_query_execution:
             self._simplify_with_caching(array_value)
 
-        dispositions = {
-            "fail": bigquery.WriteDisposition.WRITE_EMPTY,
-            "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
-            "append": bigquery.WriteDisposition.WRITE_APPEND,
-        }
+        table_exists = True
+        try:
+            table = self.bqclient.get_table(destination)
+            if if_exists == "fail":
+                raise ValueError(f"Table already exists: {destination.__str__()}")
+        except google.api_core.exceptions.NotFound:
+            table_exists = False
+
+        if len(cluster_cols) != 0:
+            if table_exists and table.clustering_fields != cluster_cols:
+                raise ValueError(
+                    "Table clustering fields cannot be changed after the table has "
+                    f"been created. Existing clustering fields: {table.clustering_fields}"
+                )
+
         sql = self.to_sql(array_value, ordered=False)
-        job_config = bigquery.QueryJobConfig(
-            write_disposition=dispositions[if_exists],
-            destination=destination,
-            clustering_fields=cluster_cols if cluster_cols else None,
-        )
+        if table_exists and _if_schama_match(table.schema, array_value.schema):
+            # b/409086472: Uses DML for table appends and replacements to avoid
+            # BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
+            # https://cloud.google.com/bigquery/quotas#standard_tables
+            job_config = bigquery.QueryJobConfig()
+            ir = sqlglot_ir.SQLGlotIR.from_query_string(sql)
+            if if_exists == "append":
+                sql = ir.insert(destination)
+            else:  # for "replace"
+                assert if_exists == "replace"
+                sql = ir.replace(destination)
+        else:
+            dispositions = {
+                "fail": bigquery.WriteDisposition.WRITE_EMPTY,
+                "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
+                "append": bigquery.WriteDisposition.WRITE_APPEND,
+            }
+            job_config = bigquery.QueryJobConfig(
+                write_disposition=dispositions[if_exists],
+                destination=destination,
+                clustering_fields=cluster_cols if cluster_cols else None,
+            )
+
 
         # TODO(swast): plumb through the api_name of the user-facing api that
         # caused this query.
         _, query_job = self._run_execute_query(
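On the append path, the compiled dataframe SQL is wrapped via `from_query_string` and emitted as plain DML; a sketch of the resulting statement, with placeholder names:

# Rough shape of the SQL produced for if_exists="append" when schemas match.
# The destination table is written to directly and never truncated or
# recreated, which sidesteps the table-metadata-update rate limits.
appended_shape = """\
INSERT INTO `proj`.`dataset`.`tbl`
(WITH `bfcte_0` AS (<dataframe sql>) SELECT * FROM `bfcte_0`)
"""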
@@ -572,6 +602,21 @@ def _execute_plan(
         )
 
 
+def _if_schama_match(
+    table_schema: Tuple[bigquery.SchemaField, ...], schema: schemata.ArraySchema
+) -> bool:
+    if len(table_schema) != len(schema.items):
+        return False
+    for field in table_schema:
+        if field.name not in schema.names:
+            return False
+        if bigframes.dtypes.convert_schema_field(field)[1] != schema.get_type(
+            field.name
+        ):
+            return False
+    return True
+
+
 def _sanitize(
     schema: Tuple[bigquery.SchemaField, ...]
 ) -> Tuple[bigquery.SchemaField, ...]:
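`_if_schama_match` requires the same field count plus a name and dtype match for every field; any mismatch routes `export_gbq` back to the load-style write dispositions. A minimal standalone sketch of the same comparison, using plain name-to-dtype dicts instead of bigframes' internal schema objects:

# Same name+type check, sketched over dicts; the real helper compares
# bigquery.SchemaField entries against a bigframes ArraySchema.
def schemas_match(table_fields: dict[str, str], df_fields: dict[str, str]) -> bool:
    if len(table_fields) != len(df_fields):
        return False
    return all(
        name in df_fields and df_fields[name] == dtype
        for name, dtype in table_fields.items()
    )

assert schemas_match({"a": "Int64"}, {"a": "Int64"})
assert not schemas_match({"a": "Int64"}, {"a": "string"})
assert not schemas_match({"a": "Int64"}, {"a": "Int64", "b": "string"})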

tests/system/small/test_dataframe_io.py

Lines changed: 81 additions & 41 deletions
@@ -458,7 +458,7 @@ def test_to_csv_tabs(
     [True, False],
 )
 @pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
-def test_to_gbq_index(scalars_dfs, dataset_id, index):
+def test_to_gbq_w_index(scalars_dfs, dataset_id, index):
     """Test the `to_gbq` API with the `index` parameter."""
     scalars_df, scalars_pandas_df = scalars_dfs
     destination_table = f"{dataset_id}.test_index_df_to_gbq_{index}"
@@ -485,48 +485,67 @@ def test_to_gbq_index(scalars_dfs, dataset_id, index):
     pd.testing.assert_frame_equal(df_out, expected, check_index_type=False)
 
 
-@pytest.mark.parametrize(
-    ("if_exists", "expected_index"),
-    [
-        pytest.param("replace", 1),
-        pytest.param("append", 2),
-        pytest.param(
-            "fail",
-            0,
-            marks=pytest.mark.xfail(
-                raises=google.api_core.exceptions.Conflict,
-            ),
-        ),
-        pytest.param(
-            "unknown",
-            0,
-            marks=pytest.mark.xfail(
-                raises=ValueError,
-            ),
-        ),
-    ],
-)
-@pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
-def test_to_gbq_if_exists(
-    scalars_df_default_index,
-    scalars_pandas_df_default_index,
-    dataset_id,
-    if_exists,
-    expected_index,
-):
-    """Test the `to_gbq` API with the `if_exists` parameter."""
-    destination_table = f"{dataset_id}.test_to_gbq_if_exists_{if_exists}"
+def test_to_gbq_if_exists_is_fail(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_fails"
+    scalars_df.to_gbq(destination_table)
 
-    scalars_df_default_index.to_gbq(destination_table)
-    scalars_df_default_index.to_gbq(destination_table, if_exists=if_exists)
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
 
-    gcs_df = pd.read_gbq(destination_table)
-    assert len(gcs_df.index) == expected_index * len(
-        scalars_pandas_df_default_index.index
-    )
-    pd.testing.assert_index_equal(
-        gcs_df.columns, scalars_pandas_df_default_index.columns
-    )
+    # Test that the default value is "fail"
+    with pytest.raises(ValueError, match="Table already exists"):
+        scalars_df.to_gbq(destination_table)
+
+    with pytest.raises(ValueError, match="Table already exists"):
+        scalars_df.to_gbq(destination_table, if_exists="fail")
+
+
+def test_to_gbq_if_exists_is_replace(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_replace"
+    scalars_df.to_gbq(destination_table)
+
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When replacing a table with the same schema
+    scalars_df.to_gbq(destination_table, if_exists="replace")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When replacing a table with a different schema
+    partitial_scalars_df = scalars_df.drop(columns=["string_col"])
+    partitial_scalars_df.to_gbq(destination_table, if_exists="replace")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(partitial_scalars_df)
+    pd.testing.assert_index_equal(gcs_df.columns, partitial_scalars_df.columns)
+
+
+def test_to_gbq_if_exists_is_append(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_append"
+    scalars_df.to_gbq(destination_table)
+
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When appending to a table with the same schema
+    scalars_df.to_gbq(destination_table, if_exists="append")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == 2 * len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When appending to a table with a different schema
+    partitial_scalars_df = scalars_df.drop(columns=["string_col"])
+    partitial_scalars_df.to_gbq(destination_table, if_exists="append")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == 3 * len(partitial_scalars_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_df.columns)
 
 
 def test_to_gbq_w_duplicate_column_names(
@@ -773,6 +792,27 @@ def test_to_gbq_w_clustering_no_destination(
     assert table.expires is not None
 
 
+def test_to_gbq_w_clustering_existing_table(
+    scalars_df_default_index,
+    dataset_id,
+    bigquery_client,
+):
+    destination_table = f"{dataset_id}.test_to_gbq_w_clustering_existing_table"
+    scalars_df_default_index.to_gbq(destination_table)
+
+    table = bigquery_client.get_table(destination_table)
+    assert table.clustering_fields is None
+    assert table.expires is None
+
+    with pytest.raises(ValueError, match="Table clustering fields cannot be changed"):
+        clustering_columns = ["int64_col"]
+        scalars_df_default_index.to_gbq(
+            destination_table,
+            if_exists="replace",
+            clustering_columns=clustering_columns,
+        )
+
+
 def test_to_gbq_w_invalid_destination_table(scalars_df_index):
     with pytest.raises(ValueError):
         scalars_df_index.to_gbq("table_id")

tests/unit/test_dataframe_io.py

Lines changed: 5 additions & 0 deletions
@@ -49,3 +49,8 @@ def test_dataframe_to_pandas(mock_df, api_name, kwargs):
     mock_df.to_pandas.assert_called_once_with(
         allow_large_results=kwargs["allow_large_results"]
     )
+
+
+def test_to_gbq_if_exists_invalid(mock_df):
+    with pytest.raises(ValueError, match="Got invalid value 'invalid' for if_exists."):
+        mock_df.to_gbq("a.b.c", if_exists="invalid")
