
Commit 545cdca

chelsea-lin and tswast authored

fix: reduce bigquery table modification via DML for to_gbq (#1737)

To avoid exceeding BigQuery's limit of 1,500 table modifications per table per day, to_gbq now prefers INSERT and MERGE DML statements. The DML path is taken when the target table already exists and its schema matches the data being written, and it supports both replacing and appending data. If the schemas differ, to_gbq falls back to its original write-disposition behavior.

Fixes internal issue 409086472

Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent d937be0 commit 545cdca
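
For orientation, here is a minimal sketch of the user-facing calls this change affects. The project, dataset, and table names are hypothetical; only the `to_gbq`/`if_exists` API itself comes from the commit.

```python
# Hypothetical illustration of the to_gbq modes touched by this commit;
# project/dataset/table names are made up.
import bigframes.pandas as bpd

df = bpd.read_gbq("my-project.my_dataset.source_table")

# First write: the destination does not exist yet, so the regular
# query-with-destination path creates it.
df.to_gbq("my-project.my_dataset.target_table")

# After this commit, when target_table exists with a matching schema,
# these calls compile to INSERT / MERGE DML instead of jobs that count
# against the 1,500 table-modifications-per-day quota.
df.to_gbq("my-project.my_dataset.target_table", if_exists="append")
df.to_gbq("my-project.my_dataset.target_table", if_exists="replace")
```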

File tree

4 files changed: +198 −51 lines

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 57 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 import dataclasses
 import typing
 
+from google.cloud import bigquery
 import pyarrow as pa
 import sqlglot as sg
 import sqlglot.dialects.bigquery
@@ -104,6 +105,24 @@ def from_pyarrow(
         )
         return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)
 
+    @classmethod
+    def from_query_string(
+        cls,
+        query_string: str,
+    ) -> SQLGlotIR:
+        """Builds SQLGlot expression from a query string"""
+        uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
+        cte_name = sge.to_identifier(
+            next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
+        )
+        cte = sge.CTE(
+            this=query_string,
+            alias=cte_name,
+        )
+        select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
+        select_expr.set("with", sge.With(expressions=[cte]))
+        return cls(expr=select_expr, uid_gen=uid_gen)
+
     def select(
         self,
         selected_cols: tuple[tuple[str, sge.Expression], ...],
@@ -133,6 +152,36 @@ def project(
         select_expr = self.expr.select(*projected_cols_expr, append=True)
         return SQLGlotIR(expr=select_expr)
 
+    def insert(
+        self,
+        destination: bigquery.TableReference,
+    ) -> str:
+        return sge.insert(self.expr.subquery(), _table(destination)).sql(
+            dialect=self.dialect, pretty=self.pretty
+        )
+
+    def replace(
+        self,
+        destination: bigquery.TableReference,
+    ) -> str:
+        # Workaround for SQLGlot breaking change:
+        # https://github.com/tobymao/sqlglot/pull/4495
+        whens_expr = [
+            sge.When(matched=False, source=True, then=sge.Delete()),
+            sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))),
+        ]
+        whens_str = "\n".join(
+            when_expr.sql(dialect=self.dialect, pretty=self.pretty)
+            for when_expr in whens_expr
+        )
+
+        merge_str = sge.Merge(
+            this=_table(destination),
+            using=self.expr.subquery(),
+            on=_literal(False, dtypes.BOOL_DTYPE),
+        ).sql(dialect=self.dialect, pretty=self.pretty)
+        return f"{merge_str}\n{whens_str}"
+
     def _encapsulate_as_cte(
         self,
     ) -> sge.Select:
@@ -190,3 +239,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
 
 def _cast(arg: typing.Any, to: str) -> sge.Cast:
     return sge.Cast(this=arg, to=to)
+
+
+def _table(table: bigquery.TableReference) -> sge.Table:
+    return sge.Table(
+        this=sg.to_identifier(table.table_id, quoted=True),
+        db=sg.to_identifier(table.dataset_id, quoted=True),
+        catalog=sg.to_identifier(table.project, quoted=True),
+    )
```
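
The `replace` path is the notable one: rather than truncating the table, it emits a MERGE whose ON condition is a constant FALSE, so every existing target row falls into WHEN NOT MATCHED BY SOURCE THEN DELETE and every source row into WHEN NOT MATCHED THEN INSERT ROW — an atomic truncate-and-insert in one DML statement. Below is a standalone sketch of the same construction using only sqlglot; the table and source names are hypothetical, and the exact rendered SQL (including whether a bare `Merge` renders without its WHEN clauses) depends on the pinned sqlglot version.

```python
import sqlglot as sg
import sqlglot.expressions as sge

# Hypothetical source query standing in for the compiled dataframe SQL.
source = sg.select("*").from_("staging_rows").subquery()

# ON FALSE guarantees no row ever "matches", so the WHEN clauses below
# delete all target rows and insert all source rows.
merge = sge.Merge(
    this=sg.to_identifier("target_table", quoted=True),
    using=source,
    on=sge.convert(False),
)

# Rendered separately and appended, mirroring the workaround for the
# sqlglot breaking change (tobymao/sqlglot#4495) noted in the diff.
whens = [
    sge.When(matched=False, source=True, then=sge.Delete()),
    sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))),
]
whens_sql = "\n".join(w.sql(dialect="bigquery") for w in whens)

print(merge.sql(dialect="bigquery") + "\n" + whens_sql)
# Roughly:
#   MERGE INTO `target_table` USING (SELECT * FROM staging_rows) ON FALSE
#   WHEN NOT MATCHED BY SOURCE THEN DELETE
#   WHEN NOT MATCHED THEN INSERT ROW
```

Because the replacement is a single DML statement, it is atomic and counts against DML quotas rather than the per-table load/copy/truncate modification limit, which is the point of the commit.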

bigframes/session/bq_caching_executor.py

Lines changed: 55 additions & 10 deletions
```diff
@@ -29,9 +29,11 @@
 
 import bigframes.core
 from bigframes.core import compile, rewrite
+import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir
 import bigframes.core.guid
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as order
+import bigframes.core.schema as schemata
 import bigframes.core.tree_properties as tree_properties
 import bigframes.dtypes
 import bigframes.exceptions as bfe
@@ -206,17 +208,45 @@ def export_gbq(
         if bigframes.options.compute.enable_multi_query_execution:
             self._simplify_with_caching(array_value)
 
-        dispositions = {
-            "fail": bigquery.WriteDisposition.WRITE_EMPTY,
-            "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
-            "append": bigquery.WriteDisposition.WRITE_APPEND,
-        }
+        table_exists = True
+        try:
+            table = self.bqclient.get_table(destination)
+            if if_exists == "fail":
+                raise ValueError(f"Table already exists: {destination.__str__()}")
+        except google.api_core.exceptions.NotFound:
+            table_exists = False
+
+        if len(cluster_cols) != 0:
+            if table_exists and table.clustering_fields != cluster_cols:
+                raise ValueError(
+                    "Table clustering fields cannot be changed after the table has "
+                    f"been created. Existing clustering fields: {table.clustering_fields}"
+                )
+
         sql = self.to_sql(array_value, ordered=False)
-        job_config = bigquery.QueryJobConfig(
-            write_disposition=dispositions[if_exists],
-            destination=destination,
-            clustering_fields=cluster_cols if cluster_cols else None,
-        )
+        if table_exists and _if_schema_match(table.schema, array_value.schema):
+            # b/409086472: Uses DML for table appends and replacements to avoid
+            # BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
+            # https://cloud.google.com/bigquery/quotas#standard_tables
+            job_config = bigquery.QueryJobConfig()
+            ir = sqlglot_ir.SQLGlotIR.from_query_string(sql)
+            if if_exists == "append":
+                sql = ir.insert(destination)
+            else:  # for "replace"
+                assert if_exists == "replace"
+                sql = ir.replace(destination)
+        else:
+            dispositions = {
+                "fail": bigquery.WriteDisposition.WRITE_EMPTY,
+                "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
+                "append": bigquery.WriteDisposition.WRITE_APPEND,
+            }
+            job_config = bigquery.QueryJobConfig(
+                write_disposition=dispositions[if_exists],
+                destination=destination,
+                clustering_fields=cluster_cols if cluster_cols else None,
+            )
+
         # TODO(swast): plumb through the api_name of the user-facing api that
         # caused this query.
         _, query_job = self._run_execute_query(
@@ -572,6 +602,21 @@ def _execute_plan(
         )
 
 
+def _if_schema_match(
+    table_schema: Tuple[bigquery.SchemaField, ...], schema: schemata.ArraySchema
+) -> bool:
+    if len(table_schema) != len(schema.items):
+        return False
+    for field in table_schema:
+        if field.name not in schema.names:
+            return False
+        if bigframes.dtypes.convert_schema_field(field)[1] != schema.get_type(
+            field.name
+        ):
+            return False
+    return True
+
+
 def _sanitize(
     schema: Tuple[bigquery.SchemaField, ...]
 ) -> Tuple[bigquery.SchemaField, ...]:
```
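
Whether the DML path is usable comes down to `_if_schema_match`: INSERT and MERGE cannot add, drop, or retype columns, so the destination schema must line up with the compiled result. Here is a simplified sketch of that comparison using only `google-cloud-bigquery` types; `schemas_match` and its name-to-type mapping are stand-ins for the real helper, which goes through bigframes' `ArraySchema` and `bigframes.dtypes.convert_schema_field`.

```python
from typing import Mapping, Tuple

from google.cloud import bigquery


def schemas_match(
    table_schema: Tuple[bigquery.SchemaField, ...],
    expected: Mapping[str, str],  # column name -> BigQuery type, e.g. {"x": "INTEGER"}
) -> bool:
    """Same shape as the new _if_schema_match: equal column counts, and every
    table column present in `expected` with the same type."""
    if len(table_schema) != len(expected):
        return False
    for field in table_schema:
        if field.name not in expected:
            return False
        if field.field_type != expected[field.name]:
            return False
    return True


# Example: schemas line up, so to_gbq would take the DML path.
fields = (
    bigquery.SchemaField("rowindex", "INTEGER"),
    bigquery.SchemaField("string_col", "STRING"),
)
assert schemas_match(fields, {"rowindex": "INTEGER", "string_col": "STRING"})

# A dropped column (as exercised in the tests below) fails the check,
# falling back to the load-job write-disposition path.
assert not schemas_match(fields, {"rowindex": "INTEGER"})
```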

tests/system/small/test_dataframe_io.py

Lines changed: 81 additions & 41 deletions
```diff
@@ -458,7 +458,7 @@ def test_to_csv_tabs(
     [True, False],
 )
 @pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
-def test_to_gbq_index(scalars_dfs, dataset_id, index):
+def test_to_gbq_w_index(scalars_dfs, dataset_id, index):
     """Test the `to_gbq` API with the `index` parameter."""
     scalars_df, scalars_pandas_df = scalars_dfs
     destination_table = f"{dataset_id}.test_index_df_to_gbq_{index}"
@@ -485,48 +485,67 @@ def test_to_gbq_index(scalars_dfs, dataset_id, index):
     pd.testing.assert_frame_equal(df_out, expected, check_index_type=False)
 
 
-@pytest.mark.parametrize(
-    ("if_exists", "expected_index"),
-    [
-        pytest.param("replace", 1),
-        pytest.param("append", 2),
-        pytest.param(
-            "fail",
-            0,
-            marks=pytest.mark.xfail(
-                raises=google.api_core.exceptions.Conflict,
-            ),
-        ),
-        pytest.param(
-            "unknown",
-            0,
-            marks=pytest.mark.xfail(
-                raises=ValueError,
-            ),
-        ),
-    ],
-)
-@pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
-def test_to_gbq_if_exists(
-    scalars_df_default_index,
-    scalars_pandas_df_default_index,
-    dataset_id,
-    if_exists,
-    expected_index,
-):
-    """Test the `to_gbq` API with the `if_exists` parameter."""
-    destination_table = f"{dataset_id}.test_to_gbq_if_exists_{if_exists}"
+def test_to_gbq_if_exists_is_fail(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_fails"
+    scalars_df.to_gbq(destination_table)
 
-    scalars_df_default_index.to_gbq(destination_table)
-    scalars_df_default_index.to_gbq(destination_table, if_exists=if_exists)
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
 
-    gcs_df = pd.read_gbq(destination_table)
-    assert len(gcs_df.index) == expected_index * len(
-        scalars_pandas_df_default_index.index
-    )
-    pd.testing.assert_index_equal(
-        gcs_df.columns, scalars_pandas_df_default_index.columns
-    )
+    # Test default value is "fails"
+    with pytest.raises(ValueError, match="Table already exists"):
+        scalars_df.to_gbq(destination_table)
+
+    with pytest.raises(ValueError, match="Table already exists"):
+        scalars_df.to_gbq(destination_table, if_exists="fail")
+
+
+def test_to_gbq_if_exists_is_replace(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_replace"
+    scalars_df.to_gbq(destination_table)
+
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When replacing a table with same schema
+    scalars_df.to_gbq(destination_table, if_exists="replace")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When replacing a table with different schema
+    partitial_scalars_df = scalars_df.drop(columns=["string_col"])
+    partitial_scalars_df.to_gbq(destination_table, if_exists="replace")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(partitial_scalars_df)
+    pd.testing.assert_index_equal(gcs_df.columns, partitial_scalars_df.columns)
+
+
+def test_to_gbq_if_exists_is_append(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_append"
+    scalars_df.to_gbq(destination_table)
+
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When appending to a table with same schema
+    scalars_df.to_gbq(destination_table, if_exists="append")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == 2 * len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When appending to a table with different schema
+    partitial_scalars_df = scalars_df.drop(columns=["string_col"])
+    partitial_scalars_df.to_gbq(destination_table, if_exists="append")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == 3 * len(partitial_scalars_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_df.columns)
 
 
 def test_to_gbq_w_duplicate_column_names(
@@ -773,6 +792,27 @@ def test_to_gbq_w_clustering_no_destination(
     assert table.expires is not None
 
 
+def test_to_gbq_w_clustering_existing_table(
+    scalars_df_default_index,
+    dataset_id,
+    bigquery_client,
+):
+    destination_table = f"{dataset_id}.test_to_gbq_w_clustering_existing_table"
+    scalars_df_default_index.to_gbq(destination_table)
+
+    table = bigquery_client.get_table(destination_table)
+    assert table.clustering_fields is None
+    assert table.expires is None
+
+    with pytest.raises(ValueError, match="Table clustering fields cannot be changed"):
+        clustering_columns = ["int64_col"]
+        scalars_df_default_index.to_gbq(
+            destination_table,
+            if_exists="replace",
+            clustering_columns=clustering_columns,
+        )
+
+
 def test_to_gbq_w_invalid_destination_table(scalars_df_index):
     with pytest.raises(ValueError):
         scalars_df_index.to_gbq("table_id")
```

tests/unit/test_dataframe_io.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -49,3 +49,8 @@ def test_dataframe_to_pandas(mock_df, api_name, kwargs):
     mock_df.to_pandas.assert_called_once_with(
         allow_large_results=kwargs["allow_large_results"]
     )
+
+
+def test_to_gbq_if_exists_invalid(mock_df):
+    with pytest.raises(ValueError, match="Got invalid value 'invalid' for if_exists."):
+        mock_df.to_gbq("a.b.c", if_exists="invalid")
```
