From d8ce12f0ce850f4495229509abf5d8b1d00c51f3 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 14 May 2025 23:45:22 +0000 Subject: [PATCH 1/2] fix: avoid table modification for to_gbq --- bigframes/core/compile/sqlglot/sqlglot_ir.py | 57 +++++++++ bigframes/session/bq_caching_executor.py | 65 ++++++++-- tests/system/small/test_dataframe_io.py | 122 ++++++++++++------- tests/unit/test_dataframe_io.py | 5 + 4 files changed, 198 insertions(+), 51 deletions(-) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index b23349bcbc..935ad393f8 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -17,6 +17,7 @@ import dataclasses import typing +from google.cloud import bigquery import pyarrow as pa import sqlglot as sg import sqlglot.dialects.bigquery @@ -104,6 +105,24 @@ def from_pyarrow( ) return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen) + @classmethod + def from_query_string( + cls, + query_string: str, + ) -> SQLGlotIR: + """Builds SQLGlot expression from a query string""" + uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator() + cte_name = sge.to_identifier( + next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted + ) + cte = sge.CTE( + this=query_string, + alias=cte_name, + ) + select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name)) + select_expr.set("with", sge.With(expressions=[cte])) + return cls(expr=select_expr, uid_gen=uid_gen) + def select( self, selected_cols: tuple[tuple[str, sge.Expression], ...], @@ -133,6 +152,36 @@ def project( select_expr = self.expr.select(*projected_cols_expr, append=True) return SQLGlotIR(expr=select_expr) + def insert( + self, + destination: bigquery.TableReference, + ) -> str: + return sge.insert(self.expr.subquery(), _table(destination)).sql( + dialect=self.dialect, pretty=self.pretty + ) + + def replace( + self, + destination: bigquery.TableReference, + ) -> str: + # Workaround for SQLGlot breaking change: + # https://github.com/tobymao/sqlglot/pull/4495 + whens_expr = [ + sge.When(matched=False, source=True, then=sge.Delete()), + sge.When(matched=False, then=sge.Insert(this=sge.Var(this="ROW"))), + ] + whens_str = "\n".join( + when_expr.sql(dialect=self.dialect, pretty=self.pretty) + for when_expr in whens_expr + ) + + merge_str = sge.Merge( + this=_table(destination), + using=self.expr.subquery(), + on=_literal(False, dtypes.BOOL_DTYPE), + ).sql(dialect=self.dialect, pretty=self.pretty) + return f"{merge_str}\n{whens_str}" + def _encapsulate_as_cte( self, ) -> sge.Select: @@ -190,3 +239,11 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: def _cast(arg: typing.Any, to: str) -> sge.Cast: return sge.Cast(this=arg, to=to) + + +def _table(table: bigquery.TableReference) -> sge.Table: + return sge.Table( + this=sg.to_identifier(table.table_id, quoted=True), + db=sg.to_identifier(table.dataset_id, quoted=True), + catalog=sg.to_identifier(table.project, quoted=True), + ) diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index 6614abfed2..b1512231c7 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -29,9 +29,11 @@ import bigframes.core from bigframes.core import compile, rewrite +import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir import bigframes.core.guid import bigframes.core.nodes as nodes import bigframes.core.ordering as order +import 
bigframes.core.schema as schemata
 import bigframes.core.tree_properties as tree_properties
 import bigframes.dtypes
 import bigframes.exceptions as bfe
@@ -206,17 +208,45 @@ def export_gbq(
         if bigframes.options.compute.enable_multi_query_execution:
             self._simplify_with_caching(array_value)
 
-        dispositions = {
-            "fail": bigquery.WriteDisposition.WRITE_EMPTY,
-            "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
-            "append": bigquery.WriteDisposition.WRITE_APPEND,
-        }
+        table_exists = True
+        try:
+            table = self.bqclient.get_table(destination)
+            if if_exists == "fail":
+                raise ValueError(f"Table already exists: {destination}")
+        except google.api_core.exceptions.NotFound:
+            table_exists = False
+
+        if len(cluster_cols) != 0:
+            if table_exists and table.clustering_fields != cluster_cols:
+                raise ValueError(
+                    "Table clustering fields cannot be changed after the table has "
+                    f"been created. Existing clustering fields: {table.clustering_fields}"
+                )
+
         sql = self.to_sql(array_value, ordered=False)
-        job_config = bigquery.QueryJobConfig(
-            write_disposition=dispositions[if_exists],
-            destination=destination,
-            clustering_fields=cluster_cols if cluster_cols else None,
-        )
+        if table_exists and _if_schama_match(table.schema, array_value.schema):
+            # b/409086472: Uses DML for table appends and replacements to avoid
+            # BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
+            # https://cloud.google.com/bigquery/quotas#standard_tables
+            job_config = bigquery.QueryJobConfig()
+            ir = sqlglot_ir.SQLGlotIR.from_query_string(sql)
+            if if_exists == "append":
+                sql = ir.insert(destination)
+            else:  # for "replace"
+                assert if_exists == "replace"
+                sql = ir.replace(destination)
+        else:
+            dispositions = {
+                "fail": bigquery.WriteDisposition.WRITE_EMPTY,
+                "replace": bigquery.WriteDisposition.WRITE_TRUNCATE,
+                "append": bigquery.WriteDisposition.WRITE_APPEND,
+            }
+            job_config = bigquery.QueryJobConfig(
+                write_disposition=dispositions[if_exists],
+                destination=destination,
+                clustering_fields=cluster_cols if cluster_cols else None,
+            )
+
         # TODO(swast): plumb through the api_name of the user-facing api that
         # caused this query.
         _, query_job = self._run_execute_query(
@@ -572,6 +602,21 @@ def _execute_plan(
     )
 
 
+def _if_schama_match(
+    table_schema: Tuple[bigquery.SchemaField, ...], schema: schemata.ArraySchema
+) -> bool:
+    if len(table_schema) != len(schema.items):
+        return False
+    for field in table_schema:
+        if field.name not in schema.names:
+            return False
+        if bigframes.dtypes.convert_schema_field(field)[1] != schema.get_type(
+            field.name
+        ):
+            return False
+    return True
+
+
 def _sanitize(
     schema: Tuple[bigquery.SchemaField, ...]
 ) -> Tuple[bigquery.SchemaField, ...]:
diff --git a/tests/system/small/test_dataframe_io.py b/tests/system/small/test_dataframe_io.py
index d24b592b0d..857bec67c0 100644
--- a/tests/system/small/test_dataframe_io.py
+++ b/tests/system/small/test_dataframe_io.py
@@ -458,7 +458,7 @@ def test_to_csv_tabs(
     [True, False],
 )
 @pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
-def test_to_gbq_index(scalars_dfs, dataset_id, index):
+def test_to_gbq_w_index(scalars_dfs, dataset_id, index):
     """Test the `to_gbq` API with the `index` parameter."""
     scalars_df, scalars_pandas_df = scalars_dfs
     destination_table = f"{dataset_id}.test_index_df_to_gbq_{index}"
@@ -485,48 +485,67 @@
     pd.testing.assert_frame_equal(df_out, expected, check_index_type=False)
 
 
-@pytest.mark.parametrize(
-    ("if_exists", "expected_index"),
-    [
-        pytest.param("replace", 1),
-        pytest.param("append", 2),
-        pytest.param(
-            "fail",
-            0,
-            marks=pytest.mark.xfail(
-                raises=google.api_core.exceptions.Conflict,
-            ),
-        ),
-        pytest.param(
-            "unknown",
-            0,
-            marks=pytest.mark.xfail(
-                raises=ValueError,
-            ),
-        ),
-    ],
-)
-@pytest.mark.skipif(pandas_gbq is None, reason="required by pd.read_gbq")
-def test_to_gbq_if_exists(
-    scalars_df_default_index,
-    scalars_pandas_df_default_index,
-    dataset_id,
-    if_exists,
-    expected_index,
-):
-    """Test the `to_gbq` API with the `if_exists` parameter."""
-    destination_table = f"{dataset_id}.test_to_gbq_if_exists_{if_exists}"
+def test_to_gbq_if_exists_is_fail(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_fail"
+    scalars_df.to_gbq(destination_table)
 
-    scalars_df_default_index.to_gbq(destination_table)
-    scalars_df_default_index.to_gbq(destination_table, if_exists=if_exists)
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
 
-    gcs_df = pd.read_gbq(destination_table)
-    assert len(gcs_df.index) == expected_index * len(
-        scalars_pandas_df_default_index.index
-    )
-    pd.testing.assert_index_equal(
-        gcs_df.columns, scalars_pandas_df_default_index.columns
-    )
+    # The default value of if_exists is "fail".
+    with pytest.raises(ValueError, match="Table already exists"):
+        scalars_df.to_gbq(destination_table)
+
+    with pytest.raises(ValueError, match="Table already exists"):
+        scalars_df.to_gbq(destination_table, if_exists="fail")
+
+
+def test_to_gbq_if_exists_is_replace(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_replace"
+    scalars_df.to_gbq(destination_table)
+
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When replacing a table with the same schema
+    scalars_df.to_gbq(destination_table, if_exists="replace")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When replacing a table with a different schema
+    partial_scalars_df = scalars_df.drop(columns=["string_col"])
+    partial_scalars_df.to_gbq(destination_table, if_exists="replace")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(partial_scalars_df)
+    pd.testing.assert_index_equal(gcs_df.columns, partial_scalars_df.columns)
+
+
+def test_to_gbq_if_exists_is_append(scalars_dfs, dataset_id):
+    scalars_df, scalars_pandas_df = scalars_dfs
+    destination_table = f"{dataset_id}.test_to_gbq_if_exists_is_append"
+    scalars_df.to_gbq(destination_table)
+
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When appending to a table with the same schema
+    scalars_df.to_gbq(destination_table, if_exists="append")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == 2 * len(scalars_pandas_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_pandas_df.columns)
+
+    # When appending to a table with a different schema
+    partial_scalars_df = scalars_df.drop(columns=["string_col"])
+    partial_scalars_df.to_gbq(destination_table, if_exists="append")
+    gcs_df = pd.read_gbq(destination_table, index_col="rowindex")
+    assert len(gcs_df) == 3 * len(partial_scalars_df)
+    pd.testing.assert_index_equal(gcs_df.columns, scalars_df.columns)
 
 
 def test_to_gbq_w_duplicate_column_names(
@@ -773,6 +792,27 @@ def test_to_gbq_w_clustering_no_destination(
     assert table.expires is not None
 
 
+def test_to_gbq_w_clustering_existing_table(
+    scalars_df_default_index,
+    dataset_id,
+    bigquery_client,
+):
+    destination_table = f"{dataset_id}.test_to_gbq_w_clustering_existing_table"
+    scalars_df_default_index.to_gbq(destination_table)
+
+    table = bigquery_client.get_table(destination_table)
+    assert table.clustering_fields is None
+    assert table.expires is None
+
+    with pytest.raises(ValueError, match="Table clustering fields cannot be changed"):
+        clustering_columns = ["int64_col"]
+        scalars_df_default_index.to_gbq(
+            destination_table,
+            if_exists="replace",
+            clustering_columns=clustering_columns,
+        )
+
+
 def test_to_gbq_w_invalid_destination_table(scalars_df_index):
     with pytest.raises(ValueError):
         scalars_df_index.to_gbq("table_id")
diff --git a/tests/unit/test_dataframe_io.py b/tests/unit/test_dataframe_io.py
index 7845a71134..f2c0241396 100644
--- a/tests/unit/test_dataframe_io.py
+++ b/tests/unit/test_dataframe_io.py
@@ -49,3 +49,8 @@ def test_dataframe_to_pandas(mock_df, api_name, kwargs):
     mock_df.to_pandas.assert_called_once_with(
         allow_large_results=kwargs["allow_large_results"]
     )
+
+
+def test_to_gbq_if_exists_invalid(mock_df):
+    with pytest.raises(ValueError, match="Got invalid value 'invalid' for if_exists."):
+        mock_df.to_gbq("a.b.c", if_exists="invalid")

From c1f59b2b42e5acd25392a33f828e205b81eafa2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20Swe=C3=B1a=20=28Swast=29?=
Date: Fri, 16 May 2025 11:08:36 -0500
Subject: [PATCH 2/2] Apply suggestions from code review

---
 bigframes/session/bq_caching_executor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py
index b1512231c7..118838c059 100644
--- a/bigframes/session/bq_caching_executor.py
+++ b/bigframes/session/bq_caching_executor.py
@@ -224,7 +224,7 @@ def export_gbq(
         )
 
         sql = self.to_sql(array_value, ordered=False)
-        if table_exists and _if_schama_match(table.schema, array_value.schema):
+        if table_exists and _if_schema_match(table.schema, array_value.schema):
             # b/409086472: Uses DML for table appends and replacements to avoid
             # BigQuery `RATE_LIMIT_EXCEEDED` errors, as per quota limits:
             # https://cloud.google.com/bigquery/quotas#standard_tables
@@ -602,7 +602,7 @@ def _execute_plan(
     )
 
 
-def _if_schama_match(
+def _if_schema_match(
     table_schema: Tuple[bigquery.SchemaField, ...], schema: schemata.ArraySchema
 ) -> bool:
     if len(table_schema) != len(schema.items):
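
Note on the generated DML (an illustrative sketch, not part of the committed
patch): the table reference and query string below are hypothetical, and the
SQL shown in comments is approximate; only SQLGlotIR.from_query_string,
SQLGlotIR.insert, and SQLGlotIR.replace come from the diff above.

    from google.cloud import bigquery

    from bigframes.core.compile.sqlglot.sqlglot_ir import SQLGlotIR

    # Hypothetical destination; export_gbq resolves the real one from to_gbq().
    dest = bigquery.TableReference.from_string("my-project.my_dataset.my_table")

    # Wrap an arbitrary SELECT so it can be spliced into DML statements.
    ir = SQLGlotIR.from_query_string("SELECT 1 AS `rowindex`, 'a' AS `string_col`")

    # if_exists="append" renders roughly:
    #   INSERT INTO `my-project`.`my_dataset`.`my_table`
    #   (WITH `bfcte_0` AS (SELECT 1 AS `rowindex`, 'a' AS `string_col`)
    #    SELECT * FROM `bfcte_0`)
    append_sql = ir.insert(dest)

    # if_exists="replace" renders a MERGE whose ON FALSE predicate never
    # matches, so every existing row is "not matched by source" (deleted) and
    # every incoming row is "not matched" (inserted), in one atomic statement
    # that leaves the table object itself untouched:
    #   MERGE INTO `my-project`.`my_dataset`.`my_table`
    #   USING (WITH ... SELECT ...) ON FALSE
    #   WHEN NOT MATCHED BY SOURCE THEN DELETE
    #   WHEN NOT MATCHED THEN INSERT ROW
    replace_sql = ir.replace(dest)

Because both statements are plain DML, they avoid the per-table metadata-update
rate limit that WRITE_TRUNCATE and WRITE_APPEND query jobs count against, which
appears to be the RATE_LIMIT_EXCEEDED quota cited in the b/409086472 comment.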
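
End-to-end behavior exercised by the new system tests (a sketch with
hypothetical dataset and table names, not a test from this patch):

    import bigframes.pandas as bpd

    df = bpd.DataFrame({"int64_col": [10, 20], "string_col": ["a", "b"]})

    df.to_gbq("my_dataset.my_table")                       # creates the table
    df.to_gbq("my_dataset.my_table", if_exists="append")   # schema matches: DML INSERT
    df.to_gbq("my_dataset.my_table", if_exists="replace")  # schema matches: DML MERGE

    # A narrower frame no longer matches the destination schema, so the write
    # falls back to a query job with a WRITE_TRUNCATE disposition.
    df.drop(columns=["string_col"]).to_gbq("my_dataset.my_table", if_exists="replace")

    # The default if_exists is "fail" once the table exists.
    df.to_gbq("my_dataset.my_table")  # raises ValueError: Table already exists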