diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index ccb2ffe401..a3a2ac36f5 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -153,6 +153,7 @@ def __init__( self._stats_cache[" ".join(self.index_columns)] = {} self._transpose_cache: Optional[Block] = transpose_cache + self._view_ref: Optional[bigquery.TableReference] = None @classmethod def from_local( @@ -2487,6 +2488,17 @@ def to_sql_query( idx_labels, ) + def to_view(self, include_index: bool) -> bigquery.TableReference: + """ + Creates a temporary BigQuery VIEW with the SQL corresponding to this block. + """ + if self._view_ref is not None: + return self._view_ref + + sql, _, _ = self.to_sql_query(include_index=include_index) + self._view_ref = self.session._create_temp_view(sql) + return self._view_ref + def cached(self, *, force: bool = False, session_aware: bool = False) -> None: """Write the block to a session table.""" # use a heuristic for whether something needs to be cached diff --git a/bigframes/core/pyformat.py b/bigframes/core/pyformat.py index 98f175d300..59ccdf1f5f 100644 --- a/bigframes/core/pyformat.py +++ b/bigframes/core/pyformat.py @@ -37,9 +37,13 @@ def _table_to_sql(table: _BQ_TABLE_TYPES) -> str: return f"`{table.project}`.`{table.dataset_id}`.`{table.table_id}`" -def _field_to_template_value(name: str, value: Any) -> str: +def _field_to_template_value( + name: str, + value: Any, +) -> str: """Convert value to something embeddable in a SQL string.""" import bigframes.core.sql # Avoid circular imports + import bigframes.dataframe # Avoid circular imports _validate_type(name, value) @@ -47,20 +51,27 @@ def _field_to_template_value(name: str, value: Any) -> str: if isinstance(value, table_types): return _table_to_sql(value) - # TODO(tswast): convert DataFrame objects to gbq tables or a literals subquery. + # TODO(tswast): convert pandas DataFrame objects to gbq tables or a literals subquery. + if isinstance(value, bigframes.dataframe.DataFrame): + return _table_to_sql(value._to_view()) + return bigframes.core.sql.simple_literal(value) def _validate_type(name: str, value: Any): """Raises TypeError if value is unsupported.""" import bigframes.core.sql # Avoid circular imports + import bigframes.dataframe # Avoid circular imports if value is None: return # None can't be used in isinstance, but is a valid literal. - supported_types = typing.get_args(_BQ_TABLE_TYPES) + typing.get_args( - bigframes.core.sql.SIMPLE_LITERAL_TYPES + supported_types = ( + typing.get_args(_BQ_TABLE_TYPES) + + typing.get_args(bigframes.core.sql.SIMPLE_LITERAL_TYPES) + + (bigframes.dataframe.DataFrame,) ) + if not isinstance(value, supported_types): raise TypeError( f"{name} has unsupported type: {type(value)}. " @@ -80,8 +91,6 @@ def pyformat( sql_template: str, *, pyformat_args: dict, - # TODO: add dry_run parameter to avoid expensive API calls in conversion - # TODO: and session to upload data / convert to table if necessary ) -> str: """Unsafe Python-style string formatting of SQL string. diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 8ed749138c..a98733b48a 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -394,6 +394,19 @@ def astype( return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast)) + def _should_sql_have_index(self) -> bool: + """Should the SQL we pass to BQML and other I/O include the index?""" + + return self._has_index and ( + self.index.name is not None or len(self.index.names) > 1 + ) + + def _to_view(self) -> bigquery.TableReference: + """Compiles this DataFrame's expression tree to SQL and saves it to a + (temporary) view. + """ + return self._block.to_view(include_index=self._should_sql_have_index()) + def _to_sql_query( self, include_index: bool, enable_cache: bool = True ) -> Tuple[str, list[str], list[blocks.Label]]: @@ -420,9 +433,7 @@ def sql(self) -> str: string representing the compiled SQL. """ try: - include_index = self._has_index and ( - self.index.name is not None or len(self.index.names) > 1 - ) + include_index = self._should_sql_have_index() sql, _, _ = self._to_sql_query(include_index=include_index) return sql except AttributeError as e: diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 81359ebb36..7630e71eaa 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -527,7 +527,6 @@ def _read_gbq_colab( query = bigframes.core.pyformat.pyformat( query, pyformat_args=pyformat_args, - # TODO: add dry_run parameter to avoid API calls for data in pyformat_args ) return self._loader.read_gbq_query( @@ -1938,6 +1937,10 @@ def _create_object_table(self, path: str, connection: str) -> str: return table + def _create_temp_view(self, sql: str) -> bigquery.TableReference: + """Create a random id Object Table from the input path and connection.""" + return self._anon_dataset_manager.create_temp_view(sql) + def from_glob_path( self, path: str, *, connection: Optional[str] = None, name: Optional[str] = None ) -> dataframe.DataFrame: diff --git a/bigframes/session/_io/bigquery/__init__.py b/bigframes/session/_io/bigquery/__init__.py index c08bb8d0dc..267111afe0 100644 --- a/bigframes/session/_io/bigquery/__init__.py +++ b/bigframes/session/_io/bigquery/__init__.py @@ -139,6 +139,28 @@ def create_temp_table( return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" +def create_temp_view( + bqclient: bigquery.Client, + table_ref: bigquery.TableReference, + *, + expiration: datetime.datetime, + sql: str, +) -> str: + """Create an empty table with an expiration in the desired session. + + The table will be deleted when the session is closed or the expiration + is reached. + """ + destination = bigquery.Table(table_ref) + destination.expires = expiration + destination.view_query = sql + + # Ok if already exists, since this will only happen from retries internal to this method + # as the requested table id has a random UUID4 component. + bqclient.create_table(destination, exists_ok=True) + return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" + + def set_table_expiration( bqclient: bigquery.Client, table_ref: bigquery.TableReference, diff --git a/bigframes/session/anonymous_dataset.py b/bigframes/session/anonymous_dataset.py index c8980e159b..bc785f693f 100644 --- a/bigframes/session/anonymous_dataset.py +++ b/bigframes/session/anonymous_dataset.py @@ -53,6 +53,12 @@ def __init__( def location(self): return self._location + def _default_expiration(self): + """When should the table expire automatically?""" + return ( + datetime.datetime.now(datetime.timezone.utc) + constants.DEFAULT_EXPIRATION + ) + def create_temp_table( self, schema: Sequence[bigquery.SchemaField], cluster_cols: Sequence[str] = [] ) -> bigquery.TableReference: @@ -60,9 +66,7 @@ def create_temp_table( Allocates and and creates a table in the anonymous dataset. The table will be cleaned up by clean_up_tables. """ - expiration = ( - datetime.datetime.now(datetime.timezone.utc) + constants.DEFAULT_EXPIRATION - ) + expiration = self._default_expiration() table = bf_io_bigquery.create_temp_table( self.bqclient, self.allocate_temp_table(), @@ -73,6 +77,20 @@ def create_temp_table( ) return bigquery.TableReference.from_string(table) + def create_temp_view(self, sql: str) -> bigquery.TableReference: + """ + Allocates and and creates a view in the anonymous dataset. + The view will be cleaned up by clean_up_tables. + """ + expiration = self._default_expiration() + table = bf_io_bigquery.create_temp_view( + self.bqclient, + self.allocate_temp_table(), + expiration=expiration, + sql=sql, + ) + return bigquery.TableReference.from_string(table) + def allocate_temp_table(self) -> bigquery.TableReference: """ Allocates a unique table id, but does not create the table. diff --git a/tests/system/small/session/test_read_gbq_colab.py b/tests/system/small/session/test_read_gbq_colab.py index 00ce0c722b..946faffab2 100644 --- a/tests/system/small/session/test_read_gbq_colab.py +++ b/tests/system/small/session/test_read_gbq_colab.py @@ -73,3 +73,31 @@ def test_read_gbq_colab_includes_formatted_scalars(session): } ), ) + + +def test_read_gbq_colab_includes_formatted_bigframes_dataframe( + session, scalars_df_index, scalars_pandas_df_index +): + pyformat_args = { + # Apply some operations to make sure the columns aren't renamed. + "some_dataframe": scalars_df_index[scalars_df_index["int64_col"] > 0].assign( + int64_col=scalars_df_index["int64_too"] + ), + # This is not a supported type, but ignored if not referenced. + "some_object": object(), + } + df = session._read_gbq_colab( + """ + SELECT int64_col, rowindex + FROM {some_dataframe} + ORDER BY rowindex ASC + """, + pyformat_args=pyformat_args, + ) + result = df.to_pandas() + expected = ( + scalars_pandas_df_index[scalars_pandas_df_index["int64_col"] > 0] + .assign(int64_col=scalars_pandas_df_index["int64_too"]) + .reset_index(drop=False)[["int64_col", "rowindex"]] + ) + pandas.testing.assert_frame_equal(result, expected) diff --git a/tests/unit/session/test_read_gbq_colab.py b/tests/unit/session/test_read_gbq_colab.py index 9afdba9eb3..cffc6b3af7 100644 --- a/tests/unit/session/test_read_gbq_colab.py +++ b/tests/unit/session/test_read_gbq_colab.py @@ -14,6 +14,10 @@ """Unit tests for read_gbq_colab helper functions.""" +import textwrap + +from google.cloud import bigquery + from bigframes.testing import mocks @@ -32,29 +36,39 @@ def test_read_gbq_colab_includes_label(): assert "session-read_gbq_colab" in label_values -def test_read_gbq_colab_includes_formatted_values_in_dry_run(): +def test_read_gbq_colab_includes_formatted_values_in_dry_run(monkeypatch): session = mocks.create_bigquery_session() + bf_df = mocks.create_dataframe(monkeypatch, session=session) + bf_df._to_view = lambda: bigquery.TableReference.from_string("my-project.my_dataset.some_view") # type: ignore pyformat_args = { "some_integer": 123, "some_string": "This could be dangerous, but we escape it", + "bf_df": bf_df, # This is not a supported type, but ignored if not referenced. "some_object": object(), } + _ = session._read_gbq_colab( - """ - SELECT {some_integer} as some_integer, - {some_string} as some_string, - '{{escaped}}' as escaped - """, + textwrap.dedent( + """ + SELECT {some_integer} as some_integer, + {some_string} as some_string, + '{{escaped}}' as escaped + FROM {bf_df} + """ + ), pyformat_args=pyformat_args, dry_run=True, ) - expected = """ + expected = textwrap.dedent( + """ SELECT 123 as some_integer, 'This could be dangerous, but we escape it' as some_string, '{escaped}' as escaped + FROM `my-project`.`my_dataset`.`some_view` """ + ) queries = session._queries # type: ignore configs = session._job_configs # type: ignore