diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index a587b8af31..6426b7b22b 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -22,7 +22,6 @@
 from __future__ import annotations
 
 import ast
-import copy
 import dataclasses
 import datetime
 import functools
@@ -30,17 +29,7 @@
 import random
 import textwrap
 import typing
-from typing import (
-    Any,
-    Iterable,
-    List,
-    Literal,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
 import warnings
 
 import bigframes_vendored.constants as constants
@@ -69,6 +58,7 @@
 import bigframes.exceptions as bfe
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
+from bigframes.session import dry_runs
 from bigframes.session import executor as executors
 
 # Type constraint for wherever column labels are used
@@ -822,59 +812,18 @@ def _compute_dry_run(
         if sampling.enable_downsampling:
             raise NotImplementedError("Dry run with sampling is not supported")
 
-        index: List[Any] = []
-        values: List[Any] = []
-
-        index.append("columnCount")
-        values.append(len(self.value_columns))
-        index.append("columnDtypes")
-        values.append(
-            {
-                col: self.expr.get_column_type(self.resolve_label_exact_or_error(col))
-                for col in self.column_labels
-            }
-        )
-
-        index.append("indexLevel")
-        values.append(self.index.nlevels)
-        index.append("indexDtypes")
-        values.append(self.index.dtypes)
-
         expr = self._apply_value_keys_to_expr(value_keys=value_keys)
         query_job = self.session._executor.dry_run(expr, ordered)
-        job_api_repr = copy.deepcopy(query_job._properties)
-
-        job_ref = job_api_repr["jobReference"]
-        for key, val in job_ref.items():
-            index.append(key)
-            values.append(val)
-
-        index.append("jobType")
-        values.append(job_api_repr["configuration"]["jobType"])
-
-        query_config = job_api_repr["configuration"]["query"]
-        for key in ("destinationTable", "useLegacySql"):
-            index.append(key)
-            values.append(query_config.get(key))
-
-        query_stats = job_api_repr["statistics"]["query"]
-        for key in (
-            "referencedTables",
-            "totalBytesProcessed",
-            "cacheHit",
-            "statementType",
-        ):
-            index.append(key)
-            values.append(query_stats.get(key))
-
-        index.append("creationTime")
-        values.append(
-            pd.Timestamp(
-                job_api_repr["statistics"]["creationTime"], unit="ms", tz="UTC"
-            )
-        )
+
+        column_dtypes = {
+            col: self.expr.get_column_type(self.resolve_label_exact_or_error(col))
+            for col in self.column_labels
+        }
 
-        return pd.Series(values, index=index), query_job
+        dry_run_stats = dry_runs.get_query_stats_with_dtypes(
+            query_job, column_dtypes, self.index.dtypes
+        )
+        return dry_run_stats, query_job
 
     def _apply_value_keys_to_expr(self, value_keys: Optional[Iterable[str]] = None):
         expr = self._expr
diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py
index a119ff67b0..ecf8a59bb7 100644
--- a/bigframes/pandas/io/api.py
+++ b/bigframes/pandas/io/api.py
@@ -25,6 +25,7 @@
     Literal,
     MutableSequence,
     Optional,
+    overload,
     Sequence,
     Tuple,
     Union,
@@ -155,6 +156,38 @@ def read_json(
 read_json.__doc__ = inspect.getdoc(bigframes.session.Session.read_json)
 
 
+@overload
+def read_gbq(  # type: ignore[overload-overlap]
+    query_or_table: str,
+    *,
+    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+    columns: Iterable[str] = ...,
+    configuration: Optional[Dict] = ...,
+    max_results: Optional[int] = ...,
+    filters: vendored_pandas_gbq.FiltersType = ...,
+    use_cache: Optional[bool] = ...,
+    col_order: Iterable[str] = ...,
+    dry_run: Literal[False] = ...,
+) -> bigframes.dataframe.DataFrame:
+    ...
+
+
+@overload
+def read_gbq(
+    query_or_table: str,
+    *,
+    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+    columns: Iterable[str] = ...,
+    configuration: Optional[Dict] = ...,
+    max_results: Optional[int] = ...,
+    filters: vendored_pandas_gbq.FiltersType = ...,
+    use_cache: Optional[bool] = ...,
+    col_order: Iterable[str] = ...,
+    dry_run: Literal[True] = ...,
+) -> pandas.Series:
+    ...
+
+
 def read_gbq(
     query_or_table: str,
     *,
@@ -165,7 +198,8 @@ def read_gbq(
     filters: vendored_pandas_gbq.FiltersType = (),
     use_cache: Optional[bool] = None,
     col_order: Iterable[str] = (),
-) -> bigframes.dataframe.DataFrame:
+    dry_run: bool = False,
+) -> bigframes.dataframe.DataFrame | pandas.Series:
     _set_default_session_location_if_possible(query_or_table)
     return global_session.with_default_session(
         bigframes.session.Session.read_gbq,
@@ -177,6 +211,7 @@ def read_gbq(
         filters=filters,
         use_cache=use_cache,
         col_order=col_order,
+        dry_run=dry_run,
     )
 
 
@@ -208,6 +243,38 @@ def read_gbq_object_table(
     )
 
 
+@overload
+def read_gbq_query(  # type: ignore[overload-overlap]
+    query: str,
+    *,
+    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+    columns: Iterable[str] = ...,
+    configuration: Optional[Dict] = ...,
+    max_results: Optional[int] = ...,
+    use_cache: Optional[bool] = ...,
+    col_order: Iterable[str] = ...,
+    filters: vendored_pandas_gbq.FiltersType = ...,
+    dry_run: Literal[False] = ...,
+) -> bigframes.dataframe.DataFrame:
+    ...
+
+
+@overload
+def read_gbq_query(
+    query: str,
+    *,
+    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+    columns: Iterable[str] = ...,
+    configuration: Optional[Dict] = ...,
+    max_results: Optional[int] = ...,
+    use_cache: Optional[bool] = ...,
+    col_order: Iterable[str] = ...,
+    filters: vendored_pandas_gbq.FiltersType = ...,
+    dry_run: Literal[True] = ...,
+) -> pandas.Series:
+    ...
+
+
 def read_gbq_query(
     query: str,
     *,
@@ -218,7 +285,8 @@ def read_gbq_query(
     use_cache: Optional[bool] = None,
     col_order: Iterable[str] = (),
    filters: vendored_pandas_gbq.FiltersType = (),
-) -> bigframes.dataframe.DataFrame:
+    dry_run: bool = False,
+) -> bigframes.dataframe.DataFrame | pandas.Series:
     _set_default_session_location_if_possible(query)
     return global_session.with_default_session(
         bigframes.session.Session.read_gbq_query,
@@ -230,12 +298,43 @@ def read_gbq_query(
         use_cache=use_cache,
         col_order=col_order,
         filters=filters,
+        dry_run=dry_run,
     )
 
 
 read_gbq_query.__doc__ = inspect.getdoc(bigframes.session.Session.read_gbq_query)
 
 
+@overload
+def read_gbq_table(  # type: ignore[overload-overlap]
+    query: str,
+    *,
+    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+    columns: Iterable[str] = ...,
+    max_results: Optional[int] = ...,
+    filters: vendored_pandas_gbq.FiltersType = ...,
+    use_cache: bool = ...,
+    col_order: Iterable[str] = ...,
+    dry_run: Literal[False] = ...,
+) -> bigframes.dataframe.DataFrame:
+    ...
+
+
+@overload
+def read_gbq_table(
+    query: str,
+    *,
+    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+    columns: Iterable[str] = ...,
+    max_results: Optional[int] = ...,
+    filters: vendored_pandas_gbq.FiltersType = ...,
+    use_cache: bool = ...,
+    col_order: Iterable[str] = ...,
+    dry_run: Literal[True] = ...,
+) -> pandas.Series:
+    ...
+
+
 def read_gbq_table(
     query: str,
     *,
@@ -245,7 +344,8 @@ def read_gbq_table(
     filters: vendored_pandas_gbq.FiltersType = (),
     use_cache: bool = True,
     col_order: Iterable[str] = (),
-) -> bigframes.dataframe.DataFrame:
+    dry_run: bool = False,
+) -> bigframes.dataframe.DataFrame | pandas.Series:
     _set_default_session_location_if_possible(query)
     return global_session.with_default_session(
         bigframes.session.Session.read_gbq_table,
@@ -256,6 +356,7 @@ def read_gbq_table(
         filters=filters,
         use_cache=use_cache,
         col_order=col_order,
+        dry_run=dry_run,
     )
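Reviewer note: the paired `@overload` stubs added throughout this change exist purely for static typing. `dry_run` defaults to `False`, and the `Literal[False]`/`Literal[True]` split lets a type checker narrow the return type from the literal passed at the call site; the `# type: ignore[overload-overlap]` comments silence mypy's complaint that a call omitting `dry_run` matches both stubs with incompatible return types. A minimal sketch of what checkers now infer (the table path is hypothetical):

```python
import bigframes.pandas as bpd

# Matches the Literal[False] overload, so `df` is typed as a
# (lazy) bigframes.dataframe.DataFrame.
df = bpd.read_gbq("my-project.my_dataset.my_table")

# Matches the Literal[True] overload, so `stats` is typed as a plain
# pandas.Series of dry-run statistics; no query is actually executed.
stats = bpd.read_gbq("my-project.my_dataset.my_table", dry_run=True)
```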
diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py
index 81866e0d32..998e6e57bc 100644
--- a/bigframes/session/__init__.py
+++ b/bigframes/session/__init__.py
@@ -31,6 +31,7 @@
     Literal,
     MutableSequence,
     Optional,
+    overload,
     Sequence,
     Tuple,
     Union,
@@ -382,6 +383,38 @@ def close(self):
             self.bqclient, self.cloudfunctionsclient, self.session_id
         )
 
+    @overload
+    def read_gbq(  # type: ignore[overload-overlap]
+        self,
+        query_or_table: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        configuration: Optional[Dict] = ...,
+        max_results: Optional[int] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        use_cache: Optional[bool] = ...,
+        col_order: Iterable[str] = ...,
+        dry_run: Literal[False] = ...,
+    ) -> dataframe.DataFrame:
+        ...
+
+    @overload
+    def read_gbq(
+        self,
+        query_or_table: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        configuration: Optional[Dict] = ...,
+        max_results: Optional[int] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        use_cache: Optional[bool] = ...,
+        col_order: Iterable[str] = ...,
+        dry_run: Literal[True] = ...,
+    ) -> pandas.Series:
+        ...
+
     def read_gbq(
         self,
         query_or_table: str,
@@ -393,8 +426,9 @@ def read_gbq(
         filters: third_party_pandas_gbq.FiltersType = (),
         use_cache: Optional[bool] = None,
         col_order: Iterable[str] = (),
+        dry_run: bool = False,
         # Add a verify index argument that fails if the index is not unique.
-    ) -> dataframe.DataFrame:
+    ) -> dataframe.DataFrame | pandas.Series:
         # TODO(b/281571214): Generate prompt to show the progress of read_gbq.
         if columns and col_order:
             raise ValueError(
@@ -404,7 +438,7 @@ def read_gbq(
             columns = col_order
 
         if bf_io_bigquery.is_query(query_or_table):
-            return self._loader.read_gbq_query(
+            return self._loader.read_gbq_query(  # type: ignore # for dry_run overload
                 query_or_table,
                 index_col=index_col,
                 columns=columns,
@@ -413,6 +447,7 @@ def read_gbq(
                 api_name="read_gbq",
                 use_cache=use_cache,
                 filters=filters,
+                dry_run=dry_run,
             )
         else:
             if configuration is not None:
@@ -422,7 +457,7 @@ def read_gbq(
                     "'configuration' or use a query."
                 )
 
-            return self._loader.read_gbq_table(
+            return self._loader.read_gbq_table(  # type: ignore # for dry_run overload
                 query_or_table,
                 index_col=index_col,
                 columns=columns,
@@ -430,6 +465,7 @@ def read_gbq(
                 api_name="read_gbq",
                 use_cache=use_cache if use_cache is not None else True,
                 filters=filters,
+                dry_run=dry_run,
             )
 
     def _register_object(
@@ -440,6 +476,38 @@ def _register_object(
     ):
         self._objects.append(weakref.ref(object))
 
+    @overload
+    def read_gbq_query(  # type: ignore[overload-overlap]
+        self,
+        query: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        configuration: Optional[Dict] = ...,
+        max_results: Optional[int] = ...,
+        use_cache: Optional[bool] = ...,
+        col_order: Iterable[str] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        dry_run: Literal[False] = ...,
+    ) -> dataframe.DataFrame:
+        ...
+
+    @overload
+    def read_gbq_query(
+        self,
+        query: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        configuration: Optional[Dict] = ...,
+        max_results: Optional[int] = ...,
+        use_cache: Optional[bool] = ...,
+        col_order: Iterable[str] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        dry_run: Literal[True] = ...,
+    ) -> pandas.Series:
+        ...
+
     def read_gbq_query(
         self,
         query: str,
@@ -451,7 +519,8 @@ def read_gbq_query(
         use_cache: Optional[bool] = None,
         col_order: Iterable[str] = (),
         filters: third_party_pandas_gbq.FiltersType = (),
-    ) -> dataframe.DataFrame:
+        dry_run: bool = False,
+    ) -> dataframe.DataFrame | pandas.Series:
         """Turn a SQL query into a DataFrame.
 
         Note: Because the results are written to a temporary table, ordering by
@@ -517,7 +586,7 @@ def read_gbq_query(
         elif col_order:
             columns = col_order
 
-        return self._loader.read_gbq_query(
+        return self._loader.read_gbq_query(  # type: ignore # for dry_run overload
             query=query,
             index_col=index_col,
             columns=columns,
@@ -526,8 +595,39 @@ def read_gbq_query(
             api_name="read_gbq_query",
             use_cache=use_cache,
             filters=filters,
+            dry_run=dry_run,
         )
 
+    @overload
+    def read_gbq_table(  # type: ignore[overload-overlap]
+        self,
+        query: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        max_results: Optional[int] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        use_cache: bool = ...,
+        col_order: Iterable[str] = ...,
+        dry_run: Literal[False] = ...,
+    ) -> dataframe.DataFrame:
+        ...
+
+    @overload
+    def read_gbq_table(
+        self,
+        query: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        max_results: Optional[int] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        use_cache: bool = ...,
+        col_order: Iterable[str] = ...,
+        dry_run: Literal[True] = ...,
+    ) -> pandas.Series:
+        ...
+
     def read_gbq_table(
         self,
         query: str,
@@ -538,7 +638,8 @@ def read_gbq_table(
         filters: third_party_pandas_gbq.FiltersType = (),
         use_cache: bool = True,
         col_order: Iterable[str] = (),
-    ) -> dataframe.DataFrame:
+        dry_run: bool = False,
+    ) -> dataframe.DataFrame | pandas.Series:
         """Turn a BigQuery table into a DataFrame.
 
         **Examples:**
@@ -569,7 +670,7 @@ def read_gbq_table(
         elif col_order:
             columns = col_order
 
-        return self._loader.read_gbq_table(  # type: ignore # for dry_run overload
+        return self._loader.read_gbq_table(  # type: ignore # for dry_run overload
             table_id=query,
             index_col=index_col,
             columns=columns,
@@ -577,6 +678,7 @@ def read_gbq_table(
             api_name="read_gbq_table",
             use_cache=use_cache,
             filters=filters,
+            dry_run=dry_run,
         )
 
     def read_gbq_table_streaming(
diff --git a/bigframes/session/dry_runs.py b/bigframes/session/dry_runs.py
new file mode 100644
index 0000000000..4d5b41345e
--- /dev/null
+++ b/bigframes/session/dry_runs.py
@@ -0,0 +1,134 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import copy
+from typing import Any, Dict, List, Sequence
+
+from google.cloud import bigquery
+import pandas
+
+from bigframes import dtypes
+
+
+def get_table_stats(table: bigquery.Table) -> pandas.Series:
+    values: List[Any] = []
+    index: List[Any] = []
+
+    # Indicate that no query is executed.
+    index.append("isQuery")
+    values.append(False)
+
+    # Populate column and index types.
+    col_dtypes = dtypes.bf_type_from_type_kind(table.schema)
+    index.append("columnCount")
+    values.append(len(col_dtypes))
+    index.append("columnDtypes")
+    values.append(col_dtypes)
+
+    for key in ("numBytes", "numRows", "location", "type"):
+        index.append(key)
+        values.append(table._properties[key])
+
+    index.append("creationTime")
+    values.append(table.created)
+
+    index.append("lastModifiedTime")
+    values.append(table.modified)
+
+    return pandas.Series(values, index=index)
+
+
+def get_query_stats_with_inferred_dtypes(
+    query_job: bigquery.QueryJob,
+    value_cols: Sequence[str],
+    index_cols: Sequence[str],
+) -> pandas.Series:
+    if query_job.schema is None:
+        # If the schema is not available, don't bother inferring dtypes.
+        return get_query_stats(query_job)
+
+    col_dtypes = dtypes.bf_type_from_type_kind(query_job.schema)
+
+    if value_cols:
+        value_col_dtypes = {
+            col: col_dtypes[col] for col in value_cols if col in col_dtypes
+        }
+    else:
+        # Use every column that is not mentioned as an index column.
+        value_col_dtypes = {
+            col: dtype
+            for col, dtype in col_dtypes.items()
+            if col not in set(index_cols)
+        }
+
+    index_dtypes = [col_dtypes[col] for col in index_cols]
+
+    return get_query_stats_with_dtypes(query_job, value_col_dtypes, index_dtypes)
+
+
+def get_query_stats_with_dtypes(
+    query_job: bigquery.QueryJob,
+    column_dtypes: Dict[str, dtypes.Dtype],
+    index_dtypes: Sequence[dtypes.Dtype],
+) -> pandas.Series:
+    index = ["columnCount", "columnDtypes", "indexLevel", "indexDtypes"]
+    values = [len(column_dtypes), column_dtypes, len(index_dtypes), index_dtypes]
+
+    s = pandas.Series(values, index=index)
+
+    return pandas.concat([s, get_query_stats(query_job)])
+
+
+def get_query_stats(
+    query_job: bigquery.QueryJob,
+) -> pandas.Series:
+    """Returns important stats from the query job as a pandas Series."""
+
+    index = []
+    values = []
+
+    job_api_repr = copy.deepcopy(query_job._properties)
+
+    job_ref = job_api_repr["jobReference"]
+    for key, val in job_ref.items():
+        index.append(key)
+        values.append(val)
+
+    index.append("jobType")
+    values.append(job_api_repr["configuration"]["jobType"])
+
+    query_config = job_api_repr["configuration"]["query"]
+    for key in ("destinationTable", "useLegacySql"):
+        index.append(key)
+        values.append(query_config.get(key))
+
+    query_stats = job_api_repr["statistics"]["query"]
+    for key in (
+        "referencedTables",
+        "totalBytesProcessed",
+        "cacheHit",
+        "statementType",
+    ):
+        index.append(key)
+        values.append(query_stats.get(key))
+
+    index.append("creationTime")
+    values.append(
+        pandas.Timestamp(
+            job_api_repr["statistics"]["creationTime"], unit="ms", tz="UTC"
+        )
+    )
+
+    return pandas.Series(values, index=index)
diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py
index 2d554737ee..f748f0fd76 100644
--- a/bigframes/session/loader.py
+++ b/bigframes/session/loader.py
@@ -30,6 +30,7 @@
     List,
     Literal,
     Optional,
+    overload,
     Sequence,
     Tuple,
 )
@@ -49,6 +50,7 @@
 import bigframes.core.schema as schemata
 import bigframes.dtypes
 import bigframes.formatting_helpers as formatting_helpers
+from bigframes.session import dry_runs
 import bigframes.session._io.bigquery as bf_io_bigquery
 import bigframes.session._io.bigquery.read_gbq_table as bf_read_gbq_table
 import bigframes.session.metrics
@@ -353,6 +355,48 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
         else:
             job.result()
 
+    @overload
+    def read_gbq_table(  # type: ignore[overload-overlap]
+        self,
+        table_id: str,
+        *,
+        index_col: Iterable[str]
+        | str
+        | Iterable[int]
+        | int
+        | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        names: Optional[Iterable[str]] = ...,
+        max_results: Optional[int] = ...,
+        api_name: str = ...,
+        use_cache: bool = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        enable_snapshot: bool = ...,
+        dry_run: Literal[False] = ...,
+    ) -> dataframe.DataFrame:
+        ...
+
+    @overload
+    def read_gbq_table(
+        self,
+        table_id: str,
+        *,
+        index_col: Iterable[str]
+        | str
+        | Iterable[int]
+        | int
+        | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        names: Optional[Iterable[str]] = ...,
+        max_results: Optional[int] = ...,
+        api_name: str = ...,
+        use_cache: bool = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        enable_snapshot: bool = ...,
+        dry_run: Literal[True] = ...,
+    ) -> pandas.Series:
+        ...
+
     def read_gbq_table(
         self,
         table_id: str,
@@ -369,7 +413,8 @@ def read_gbq_table(
         use_cache: bool = True,
         filters: third_party_pandas_gbq.FiltersType = (),
         enable_snapshot: bool = True,
-    ) -> dataframe.DataFrame:
+        dry_run: bool = False,
+    ) -> dataframe.DataFrame | pandas.Series:
         import bigframes._tools.strings
         import bigframes.dataframe as dataframe
 
@@ -495,14 +540,18 @@ def read_gbq_table(
                 time_travel_timestamp=None,
             )
 
-            return self.read_gbq_query(
+            return self.read_gbq_query(  # type: ignore # for dry_run overload
                 query,
                 index_col=index_cols,
                 columns=columns,
                 api_name=api_name,
                 use_cache=use_cache,
+                dry_run=dry_run,
             )
 
+        if dry_run:
+            return dry_runs.get_table_stats(table)
+
         # -----------------------------------------
         # Validate table access and features
         # -----------------------------------------
@@ -653,6 +702,38 @@ def load_file(
         table_id = f"{table.project}.{table.dataset_id}.{table.table_id}"
         return table_id
 
+    @overload
+    def read_gbq_query(  # type: ignore[overload-overlap]
+        self,
+        query: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        configuration: Optional[Dict] = ...,
+        max_results: Optional[int] = ...,
+        api_name: str = ...,
+        use_cache: Optional[bool] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        dry_run: Literal[False] = ...,
+    ) -> dataframe.DataFrame:
+        ...
+
+    @overload
+    def read_gbq_query(
+        self,
+        query: str,
+        *,
+        index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = ...,
+        columns: Iterable[str] = ...,
+        configuration: Optional[Dict] = ...,
+        max_results: Optional[int] = ...,
+        api_name: str = ...,
+        use_cache: Optional[bool] = ...,
+        filters: third_party_pandas_gbq.FiltersType = ...,
+        dry_run: Literal[True] = ...,
+    ) -> pandas.Series:
+        ...
+
     def read_gbq_query(
         self,
         query: str,
@@ -664,7 +745,8 @@ def read_gbq_query(
         api_name: str = "read_gbq_query",
         use_cache: Optional[bool] = None,
         filters: third_party_pandas_gbq.FiltersType = (),
-    ) -> dataframe.DataFrame:
+        dry_run: bool = False,
+    ) -> dataframe.DataFrame | pandas.Series:
         import bigframes.dataframe as dataframe
 
         configuration = _transform_read_gbq_configuration(configuration)
@@ -710,6 +792,17 @@ def read_gbq_query(
             time_travel_timestamp=None,
         )
 
+        if dry_run:
+            job_config = typing.cast(
+                bigquery.QueryJobConfig,
+                bigquery.QueryJobConfig.from_api_repr(configuration),
+            )
+            job_config.dry_run = True
+            query_job = self._bqclient.query(query, job_config=job_config)
+            return dry_runs.get_query_stats_with_inferred_dtypes(
+                query_job, list(columns), index_cols
+            )
+
         # No cluster candidates as user query might not be clusterable (eg because of ORDER BY clause)
         destination, query_job = self._query_to_destination(
             query,
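For context on the `if dry_run:` branch added to the loader's `read_gbq_query` above: it reuses the caller's job `configuration`, flips on the dry-run flag, and hands the returned job to `dry_runs.get_query_stats_with_inferred_dtypes`. A standalone sketch of the same google-cloud-bigquery pattern (table path hypothetical):

```python
from google.cloud import bigquery

client = bigquery.Client()

# A dry-run job is validated and planned server-side but never executed,
# so it incurs no cost and completes immediately.
job_config = bigquery.QueryJobConfig(dry_run=True)
job = client.query(
    "SELECT * FROM `my-project.my_dataset.my_table`", job_config=job_config
)

print(job.total_bytes_processed)  # Basis for the totalBytesProcessed stat.
print(job.schema)  # Basis for the inferred BigFrames dtypes.
```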
diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py
index ced01c940f..ad01a95509 100644
--- a/tests/system/small/test_session.py
+++ b/tests/system/small/test_session.py
@@ -1831,3 +1831,100 @@ def test_read_gbq_duplicate_columns_xfail(
         index_col=index_col,
         columns=columns,
     )
+
+
+def test_read_gbq_with_table_ref_dry_run(scalars_table_id, session):
+    result = session.read_gbq(scalars_table_id, dry_run=True)
+
+    assert isinstance(result, pd.Series)
+    _assert_table_dry_run_stats_are_valid(result)
+
+
+def test_read_gbq_with_query_dry_run(scalars_table_id, session):
+    query = f"SELECT * FROM {scalars_table_id} LIMIT 10;"
+    result = session.read_gbq(query, dry_run=True)
+
+    assert isinstance(result, pd.Series)
+    _assert_query_dry_run_stats_are_valid(result)
+
+
+def test_read_gbq_dry_run_with_column_and_index(scalars_table_id, session):
+    query = f"SELECT * FROM {scalars_table_id} LIMIT 10;"
+    result = session.read_gbq(
+        query, dry_run=True, columns=["int64_col", "float64_col"], index_col="int64_too"
+    )
+
+    assert isinstance(result, pd.Series)
+    _assert_query_dry_run_stats_are_valid(result)
+    assert result["columnCount"] == 2
+    assert result["columnDtypes"] == {
+        "int64_col": pd.Int64Dtype(),
+        "float64_col": pd.Float64Dtype(),
+    }
+    assert result["indexLevel"] == 1
+    assert result["indexDtypes"] == [pd.Int64Dtype()]
+
+
+def test_read_gbq_table_dry_run(scalars_table_id, session):
+    result = session.read_gbq_table(scalars_table_id, dry_run=True)
+
+    assert isinstance(result, pd.Series)
+    _assert_table_dry_run_stats_are_valid(result)
+
+
+def test_read_gbq_table_dry_run_with_max_results(scalars_table_id, session):
+    result = session.read_gbq_table(scalars_table_id, dry_run=True, max_results=100)
+
+    assert isinstance(result, pd.Series)
+    _assert_query_dry_run_stats_are_valid(result)
+
+
+def test_read_gbq_query_dry_run(scalars_table_id, session):
+    query = f"SELECT * FROM {scalars_table_id} LIMIT 10;"
+    result = session.read_gbq_query(query, dry_run=True)
+
+    assert isinstance(result, pd.Series)
+    _assert_query_dry_run_stats_are_valid(result)
+
+
+def _assert_query_dry_run_stats_are_valid(result: pd.Series):
+    expected_index = pd.Index(
+        [
+            "columnCount",
+            "columnDtypes",
+            "indexLevel",
+            "indexDtypes",
+            "projectId",
+            "location",
+            "jobType",
+            "destinationTable",
+            "useLegacySql",
+            "referencedTables",
+            "totalBytesProcessed",
+            "cacheHit",
+            "statementType",
+            "creationTime",
+        ]
+    )
+
+    pd.testing.assert_index_equal(result.index, expected_index)
+    assert result["columnCount"] + result["indexLevel"] > 0
+
+
+def _assert_table_dry_run_stats_are_valid(result: pd.Series):
+    expected_index = pd.Index(
+        [
+            "isQuery",
+            "columnCount",
+            "columnDtypes",
+            "numBytes",
+            "numRows",
+            "location",
+            "type",
+            "creationTime",
+            "lastModifiedTime",
+        ]
+    )
+
+    pd.testing.assert_index_equal(result.index, expected_index)
+    assert result["columnCount"] == len(result["columnDtypes"])
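Taken together, the system tests above pin down the two dry-run result shapes. An end-to-end sketch of the behavior they verify (table path hypothetical):

```python
import bigframes.pandas as bpd

table = "my-project.my_dataset.my_table"

# Table path: no job is issued at all, so the Series carries table
# metadata (isQuery is False, plus numRows, numBytes, and so on).
table_stats = bpd.read_gbq_table(table, dry_run=True)
assert not table_stats["isQuery"]
print(table_stats["numRows"], table_stats["numBytes"])

# Query path: a dry-run job is issued, so the Series carries job
# statistics plus the inferred column and index dtypes.
query_stats = bpd.read_gbq_query(f"SELECT * FROM `{table}`", dry_run=True)
print(query_stats["totalBytesProcessed"], query_stats["columnDtypes"])
```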