diff --git a/google/cloud/__init__.py b/google/cloud/__init__.py index 8fcc60e2b..8e60d8439 100644 --- a/google/cloud/__init__.py +++ b/google/cloud/__init__.py @@ -21,4 +21,4 @@ except ImportError: import pkgutil - __path__ = pkgutil.extend_path(__path__, __name__) + __path__ = pkgutil.extend_path(__path__, __name__) # type: ignore diff --git a/google/cloud/bigquery/_helpers.py b/google/cloud/bigquery/_helpers.py index e95d38545..e2ca7fa07 100644 --- a/google/cloud/bigquery/_helpers.py +++ b/google/cloud/bigquery/_helpers.py @@ -22,7 +22,7 @@ from typing import Any, Optional, Union from dateutil import relativedelta -from google.cloud._helpers import UTC +from google.cloud._helpers import UTC # type: ignore from google.cloud._helpers import _date_from_iso8601_date from google.cloud._helpers import _datetime_from_microseconds from google.cloud._helpers import _RFC3339_MICROS @@ -126,7 +126,7 @@ def __init__(self): def installed_version(self) -> packaging.version.Version: """Return the parsed version of pyarrow.""" if self._installed_version is None: - import pyarrow + import pyarrow # type: ignore self._installed_version = packaging.version.parse( # Use 0.0.0, since it is earlier than any released version. diff --git a/google/cloud/bigquery/_http.py b/google/cloud/bigquery/_http.py index 81e7922e6..f7207f32e 100644 --- a/google/cloud/bigquery/_http.py +++ b/google/cloud/bigquery/_http.py @@ -17,7 +17,7 @@ import os import pkg_resources -from google.cloud import _http # pytype: disable=import-error +from google.cloud import _http # type: ignore # pytype: disable=import-error from google.cloud.bigquery import __version__ diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py index 0cb851469..de6356c2a 100644 --- a/google/cloud/bigquery/_pandas_helpers.py +++ b/google/cloud/bigquery/_pandas_helpers.py @@ -21,7 +21,7 @@ import warnings try: - import pandas + import pandas # type: ignore except ImportError: # pragma: NO COVER pandas = None else: @@ -29,7 +29,7 @@ try: # _BaseGeometry is used to detect shapely objects in `bq_to_arrow_array` - from shapely.geometry.base import BaseGeometry as _BaseGeometry + from shapely.geometry.base import BaseGeometry as _BaseGeometry # type: ignore except ImportError: # pragma: NO COVER # No shapely, use NoneType for _BaseGeometry as a placeholder. _BaseGeometry = type(None) @@ -43,7 +43,7 @@ def _to_wkb(): # - Avoid extra work done by `shapely.wkb.dumps` that we don't need. # - Caches the WKBWriter (and write method lookup :) ) # - Avoids adding WKBWriter, lgeos, and notnull to the module namespace.
- from shapely.geos import WKBWriter, lgeos + from shapely.geos import WKBWriter, lgeos # type: ignore write = WKBWriter(lgeos).write notnull = pandas.notnull @@ -574,7 +574,7 @@ def dataframe_to_parquet( """ pyarrow = _helpers.PYARROW_VERSIONS.try_import(raise_if_error=True) - import pyarrow.parquet + import pyarrow.parquet # type: ignore kwargs = ( {"use_compliant_nested_type": parquet_use_compliant_nested_type} diff --git a/google/cloud/bigquery/_tqdm_helpers.py b/google/cloud/bigquery/_tqdm_helpers.py index 99e720e2b..632f70f87 100644 --- a/google/cloud/bigquery/_tqdm_helpers.py +++ b/google/cloud/bigquery/_tqdm_helpers.py @@ -21,7 +21,7 @@ import warnings try: - import tqdm + import tqdm # type: ignore except ImportError: # pragma: NO COVER tqdm = None diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 4bdd43e8f..3e641e195 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -28,12 +28,23 @@ import math import os import tempfile -from typing import Any, BinaryIO, Dict, Iterable, Optional, Sequence, Tuple, Union +import typing +from typing import ( + Any, + BinaryIO, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) import uuid import warnings from google import resumable_media # type: ignore -from google.resumable_media.requests import MultipartUpload +from google.resumable_media.requests import MultipartUpload # type: ignore from google.resumable_media.requests import ResumableUpload import google.api_core.client_options @@ -41,16 +52,16 @@ from google.api_core.iam import Policy from google.api_core import page_iterator from google.api_core import retry as retries -import google.cloud._helpers +import google.cloud._helpers # type: ignore from google.cloud import exceptions # pytype: disable=import-error -from google.cloud.client import ClientWithProject # pytype: disable=import-error +from google.cloud.client import ClientWithProject # type: ignore # pytype: disable=import-error try: from google.cloud.bigquery_storage_v1.services.big_query_read.client import ( DEFAULT_CLIENT_INFO as DEFAULT_BQSTORAGE_CLIENT_INFO, ) except ImportError: - DEFAULT_BQSTORAGE_CLIENT_INFO = None + DEFAULT_BQSTORAGE_CLIENT_INFO = None # type: ignore from google.cloud.bigquery._helpers import _del_sub_prop from google.cloud.bigquery._helpers import _get_sub_prop @@ -100,6 +111,11 @@ pyarrow = _helpers.PYARROW_VERSIONS.try_import() +TimeoutType = Union[float, None] + +if typing.TYPE_CHECKING: # pragma: NO COVER + # os.PathLike is only subscriptable in Python 3.9+, thus shielding with a condition. + PathType = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] _DEFAULT_CHUNKSIZE = 100 * 1024 * 1024 # 100 MB _MAX_MULTIPART_SIZE = 5 * 1024 * 1024 @@ -248,7 +264,7 @@ def get_service_account_email( self, project: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> str: """Get the email address of the project's BigQuery service account @@ -295,7 +311,7 @@ def list_projects( max_results: int = None, page_token: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, page_size: int = None, ) -> page_iterator.Iterator: """List projects for the project associated with this client. 
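A note on the recurring `timeout` change in the hunks above and below: `None` is a legal timeout value for these methods (it means no deadline), so annotating the parameter as a bare `float` makes both the `DEFAULT_TIMEOUT` default and explicit `timeout=None` calls type errors under mypy. The patch defines the `TimeoutType = Union[float, None]` alias once and reuses it in every signature. A minimal sketch of the failure mode and the fix, illustrative only and not part of the patch:

from typing import Union

TimeoutType = Union[float, None]  # the alias the patch adds, i.e. Optional[float]

def api_call_before(timeout: float = None) -> None:  # mypy: incompatible default
    """Runs fine, but the annotation forbids the None the method actually accepts."""

def api_call_after(timeout: TimeoutType = None) -> None:
    """Both float and None now type-check."""

api_call_after(timeout=2.5)
api_call_after(timeout=None)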
@@ -361,7 +377,7 @@ def list_datasets( max_results: int = None, page_token: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, page_size: int = None, ) -> page_iterator.Iterator: """List datasets for the project associated with this client. @@ -400,7 +416,7 @@ def list_datasets( Iterator of :class:`~google.cloud.bigquery.dataset.DatasetListItem`. associated with the project. """ - extra_params = {} + extra_params: Dict[str, Any] = {} if project is None: project = self.project if include_all: @@ -526,12 +542,12 @@ def _ensure_bqstorage_client( bqstorage_client = bigquery_storage.BigQueryReadClient( credentials=self._credentials, client_options=client_options, - client_info=client_info, + client_info=client_info, # type: ignore # (None is also accepted) ) return bqstorage_client - def _dataset_from_arg(self, dataset): + def _dataset_from_arg(self, dataset) -> Union[Dataset, DatasetReference]: if isinstance(dataset, str): dataset = DatasetReference.from_string( dataset, default_project=self.project @@ -552,7 +568,7 @@ def create_dataset( dataset: Union[str, Dataset, DatasetReference, DatasetListItem], exists_ok: bool = False, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Dataset: """API call: create the dataset via a POST request. @@ -627,7 +643,7 @@ def create_routine( routine: Routine, exists_ok: bool = False, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Routine: """[Beta] Create a routine via a POST request. @@ -682,7 +698,7 @@ def create_table( table: Union[str, Table, TableReference, TableListItem], exists_ok: bool = False, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Table: """API call: create a table via a PUT request @@ -765,7 +781,7 @@ def get_dataset( self, dataset_ref: Union[DatasetReference, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Dataset: """Fetch the dataset referenced by ``dataset_ref`` @@ -809,7 +825,7 @@ def get_iam_policy( table: Union[Table, TableReference, TableListItem, str], requested_policy_version: int = 1, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Policy: table = _table_arg_to_table_ref(table, default_project=self.project) @@ -838,7 +854,7 @@ def set_iam_policy( policy: Policy, updateMask: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Policy: table = _table_arg_to_table_ref(table, default_project=self.project) @@ -870,7 +886,7 @@ def test_iam_permissions( table: Union[Table, TableReference, TableListItem, str], permissions: Sequence[str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Dict[str, Any]: table = _table_arg_to_table_ref(table, default_project=self.project) @@ -894,7 +910,7 @@ def get_model( self, model_ref: Union[ModelReference, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Model: """[Beta] Fetch the model referenced by ``model_ref``. 
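The `extra_params: Dict[str, Any] = {}` annotation above addresses another common mypy pattern: the checker types a dict from its first contents, so a mapping that starts homogeneous and later receives values of other types gets rejected. A hedged sketch with a hypothetical helper:

from typing import Any, Dict, Optional

def build_list_params(filter_expr: Optional[str], max_results: Optional[int]) -> Dict[str, Any]:
    params: Dict[str, Any] = {}  # without this annotation, mypy infers the
    if filter_expr is not None:  # value type from the first insertion below
        params["filter"] = filter_expr
    if max_results is not None:
        params["maxResults"] = max_results  # a second, incompatible value type
    return params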
@@ -937,7 +953,7 @@ def get_routine( self, routine_ref: Union[Routine, RoutineReference, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Routine: """[Beta] Get the routine referenced by ``routine_ref``. @@ -981,7 +997,7 @@ def get_table( self, table: Union[Table, TableReference, TableListItem, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Table: """Fetch the table referenced by ``table``. @@ -1024,7 +1040,7 @@ def update_dataset( dataset: Dataset, fields: Sequence[str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Dataset: """Change some fields of a dataset. @@ -1071,7 +1087,7 @@ def update_dataset( """ partial = dataset._build_resource(fields) if dataset.etag is not None: - headers = {"If-Match": dataset.etag} + headers: Optional[Dict[str, str]] = {"If-Match": dataset.etag} else: headers = None path = dataset.path @@ -1094,7 +1110,7 @@ def update_model( model: Model, fields: Sequence[str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Model: """[Beta] Change some fields of a model. @@ -1135,7 +1151,7 @@ def update_model( """ partial = model._build_resource(fields) if model.etag: - headers = {"If-Match": model.etag} + headers: Optional[Dict[str, str]] = {"If-Match": model.etag} else: headers = None path = model.path @@ -1158,7 +1174,7 @@ def update_routine( routine: Routine, fields: Sequence[str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Routine: """[Beta] Change some fields of a routine. @@ -1205,7 +1221,7 @@ def update_routine( """ partial = routine._build_resource(fields) if routine.etag: - headers = {"If-Match": routine.etag} + headers: Optional[Dict[str, str]] = {"If-Match": routine.etag} else: headers = None @@ -1232,7 +1248,7 @@ def update_table( table: Table, fields: Sequence[str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Table: """Change some fields of a table. @@ -1273,7 +1289,7 @@ def update_table( """ partial = table._build_resource(fields) if table.etag is not None: - headers = {"If-Match": table.etag} + headers: Optional[Dict[str, str]] = {"If-Match": table.etag} else: headers = None @@ -1298,7 +1314,7 @@ def list_models( max_results: int = None, page_token: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, page_size: int = None, ) -> page_iterator.Iterator: """[Beta] List models in the dataset. @@ -1366,7 +1382,7 @@ def api_request(*args, **kwargs): max_results=max_results, page_size=page_size, ) - result.dataset = dataset + result.dataset = dataset # type: ignore return result def list_routines( @@ -1375,7 +1391,7 @@ def list_routines( max_results: int = None, page_token: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, page_size: int = None, ) -> page_iterator.Iterator: """[Beta] List routines in the dataset. 
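The four `headers: Optional[Dict[str, str]]` edits in this file are the same idea applied across branches: mypy fixes a variable's type at its first binding, so `headers = {...}` in the `if` arm pins it to `Dict[str, str]` and the later `headers = None` no longer type-checks. Annotating the first binding as Optional covers both arms. Illustrative sketch, not library code:

from typing import Dict, Optional

def etag_headers(etag: Optional[str]) -> Optional[Dict[str, str]]:
    if etag is not None:
        headers: Optional[Dict[str, str]] = {"If-Match": etag}
    else:
        headers = None  # legal only because the annotation above admits None
    return headers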
@@ -1443,7 +1459,7 @@ def api_request(*args, **kwargs): max_results=max_results, page_size=page_size, ) - result.dataset = dataset + result.dataset = dataset # type: ignore return result def list_tables( @@ -1452,7 +1468,7 @@ def list_tables( max_results: int = None, page_token: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, page_size: int = None, ) -> page_iterator.Iterator: """List tables in the dataset. @@ -1519,7 +1535,7 @@ def api_request(*args, **kwargs): max_results=max_results, page_size=page_size, ) - result.dataset = dataset + result.dataset = dataset # type: ignore return result def delete_dataset( @@ -1527,7 +1543,7 @@ def delete_dataset( dataset: Union[Dataset, DatasetReference, DatasetListItem, str], delete_contents: bool = False, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, not_found_ok: bool = False, ) -> None: """Delete a dataset. @@ -1586,7 +1602,7 @@ def delete_model( self, model: Union[Model, ModelReference, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, not_found_ok: bool = False, ) -> None: """[Beta] Delete a model @@ -1640,7 +1656,7 @@ def delete_job_metadata( project: Optional[str] = None, location: Optional[str] = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, not_found_ok: bool = False, ): """[Beta] Delete job metadata from job history. @@ -1703,7 +1719,7 @@ def delete_routine( self, routine: Union[Routine, RoutineReference, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, not_found_ok: bool = False, ) -> None: """[Beta] Delete a routine. @@ -1757,7 +1773,7 @@ def delete_table( self, table: Union[Table, TableReference, TableListItem, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, not_found_ok: bool = False, ) -> None: """Delete a table @@ -1811,7 +1827,7 @@ def _get_query_results( project: str = None, timeout_ms: int = None, location: str = None, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> _QueryResults: """Get the query results object for a query job. @@ -1836,7 +1852,7 @@ def _get_query_results( A new ``_QueryResults`` instance. """ - extra_params = {"maxResults": 0} + extra_params: Dict[str, Any] = {"maxResults": 0} if timeout is not None: timeout = max(timeout, _MIN_GET_QUERY_RESULTS_TIMEOUT) @@ -1870,20 +1886,18 @@ def _get_query_results( ) return _QueryResults.from_api_repr(resource) - def job_from_resource(self, resource: dict) -> job.UnknownJob: + def job_from_resource( + self, resource: dict + ) -> Union[ + job.CopyJob, job.ExtractJob, job.LoadJob, job.QueryJob, job.UnknownJob, + ]: """Detect correct job type from resource and instantiate. Args: resource (Dict): one job resource from API response Returns: - Union[ \ - google.cloud.bigquery.job.LoadJob, \ - google.cloud.bigquery.job.CopyJob, \ - google.cloud.bigquery.job.ExtractJob, \ - google.cloud.bigquery.job.QueryJob \ - ]: - The job instance, constructed via the resource. + The job instance, constructed via the resource. 
""" config = resource.get("configuration", {}) if "load" in config: @@ -1900,7 +1914,7 @@ def create_job( self, job_config: dict, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob]: """Create a new job. Args: @@ -1933,7 +1947,7 @@ def create_job( return self.load_table_from_uri( source_uris, destination, - job_config=load_job_config, + job_config=typing.cast(LoadJobConfig, load_job_config), retry=retry, timeout=timeout, ) @@ -1953,7 +1967,7 @@ def create_job( return self.copy_table( sources, destination, - job_config=copy_job_config, + job_config=typing.cast(CopyJobConfig, copy_job_config), retry=retry, timeout=timeout, ) @@ -1973,7 +1987,7 @@ def create_job( return self.extract_table( source, destination_uris, - job_config=extract_job_config, + job_config=typing.cast(ExtractJobConfig, extract_job_config), retry=retry, timeout=timeout, source_type=source_type, @@ -1986,32 +2000,30 @@ def create_job( ) query = _get_sub_prop(copy_config, ["query", "query"]) return self.query( - query, job_config=query_job_config, retry=retry, timeout=timeout + query, + job_config=typing.cast(QueryJobConfig, query_job_config), + retry=retry, + timeout=timeout, ) else: raise TypeError("Invalid job configuration received.") def get_job( self, - job_id: str, + job_id: Union[str, job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob], project: str = None, location: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, - ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob]: + timeout: TimeoutType = DEFAULT_TIMEOUT, + ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob, job.UnknownJob]: """Fetch a job for the project associated with this client. See https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/get Args: - job_id (Union[ \ - str, \ - google.cloud.bigquery.job.LoadJob, \ - google.cloud.bigquery.job.CopyJob, \ - google.cloud.bigquery.job.ExtractJob, \ - google.cloud.bigquery.job.QueryJob \ - ]): Job identifier. + job_id: + Job identifier. Keyword Arguments: project (Optional[str]): @@ -2026,13 +2038,7 @@ def get_job( before using ``retry``. Returns: - Union[ \ - google.cloud.bigquery.job.LoadJob, \ - google.cloud.bigquery.job.CopyJob, \ - google.cloud.bigquery.job.ExtractJob, \ - google.cloud.bigquery.job.QueryJob \ - ]: - Job instance, based on the resource returned by the API. + Job instance, based on the resource returned by the API. """ extra_params = {"projection": "full"} @@ -2071,7 +2077,7 @@ def cancel_job( project: str = None, location: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob]: """Attempt to cancel a job from a job ID. 
@@ -2137,7 +2143,11 @@ def cancel_job( timeout=timeout, ) - return self.job_from_resource(resource["job"]) + job_instance = self.job_from_resource(resource["job"]) # never an UnknownJob + + return typing.cast( + Union[job.LoadJob, job.CopyJob, job.ExtractJob, job.QueryJob], job_instance, + ) def list_jobs( self, @@ -2148,7 +2158,7 @@ def list_jobs( all_users: bool = None, state_filter: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, min_creation_time: datetime.datetime = None, max_creation_time: datetime.datetime = None, page_size: int = None, @@ -2263,9 +2273,9 @@ def load_table_from_uri( project: str = None, job_config: LoadJobConfig = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> job.LoadJob: - """Starts a job for loading data into a table from CloudStorage. + """Starts a job for loading data into a table from Cloud Storage. See https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#jobconfigurationload @@ -2348,7 +2358,7 @@ def load_table_from_file( location: str = None, project: str = None, job_config: LoadJobConfig = None, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> job.LoadJob: """Upload the contents of this table from a file-like object. @@ -2439,7 +2449,7 @@ def load_table_from_file( except resumable_media.InvalidResponse as exc: raise exceptions.from_http_response(exc.response) - return self.job_from_resource(response.json()) + return typing.cast(LoadJob, self.job_from_resource(response.json())) def load_table_from_dataframe( self, @@ -2452,7 +2462,7 @@ def load_table_from_dataframe( project: str = None, job_config: LoadJobConfig = None, parquet_compression: str = "snappy", - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> job.LoadJob: """Upload the contents of a table from a pandas DataFrame. @@ -2592,7 +2602,7 @@ def load_table_from_dataframe( try: table = self.get_table(destination) except core_exceptions.NotFound: - table = None + pass else: columns_and_indexes = frozenset( name @@ -2707,7 +2717,7 @@ def load_table_from_json( location: str = None, project: str = None, job_config: LoadJobConfig = None, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> job.LoadJob: """Upload the contents of a table from a JSON string or dict. @@ -2995,7 +3005,7 @@ def copy_table( project: str = None, job_config: CopyJobConfig = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> job.CopyJob: """Copy one or more tables to another table. @@ -3101,7 +3111,7 @@ def extract_table( project: str = None, job_config: ExtractJobConfig = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, source_type: str = "Table", ) -> job.ExtractJob: """Start a job to extract a table into Cloud Storage files. @@ -3200,7 +3210,7 @@ def query( location: str = None, project: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, job_retry: retries.Retry = DEFAULT_JOB_RETRY, ) -> job.QueryJob: """Run a SQL query. 
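The `table = None` to `pass` change in `load_table_from_dataframe` above is again the first-binding rule: `table` is bound from `self.get_table(...)` as a `Table`, so rebinding it to `None` in the `except` arm fails the check. Only the `else` arm reads the variable, so doing nothing is equivalent. A runnable toy version, all names being stand-ins:

class NotFound(Exception):
    """Stand-in for google.api_core.exceptions.NotFound."""

class Table:
    schema = ["col_a", "col_b"]

def get_table(name: str) -> Table:  # hypothetical stand-in
    if name != "known_table":
        raise NotFound(name)
    return Table()

def describe(name: str) -> None:
    try:
        table = get_table(name)
    except NotFound:
        pass  # `table = None` here would clash with the inferred `Table` type
    else:
        print(table.schema)  # only reached when the lookup succeeded

describe("known_table")
describe("missing_table")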
@@ -3357,7 +3367,7 @@ def insert_rows( table: Union[Table, TableReference, str], rows: Union[Iterable[Tuple], Iterable[Dict]], selected_fields: Sequence[SchemaField] = None, - **kwargs: dict, + **kwargs, ) -> Sequence[dict]: """Insert rows into a table via the streaming API. @@ -3482,7 +3492,7 @@ def insert_rows_json( ignore_unknown_values: bool = None, template_suffix: str = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Sequence[dict]: """Insert rows into a table without applying local type conversions. @@ -3550,8 +3560,8 @@ def insert_rows_json( # insert_rows_json doesn't need the table schema. It's not doing any # type conversions. table = _table_arg_to_table_ref(table, default_project=self.project) - rows_info = [] - data = {"rows": rows_info} + rows_info: List[Any] = [] + data: Dict[str, Any] = {"rows": rows_info} if row_ids is None: warnings.warn( @@ -3569,7 +3579,7 @@ def insert_rows_json( raise TypeError(msg) for i, row in enumerate(json_rows): - info = {"json": row} + info: Dict[str, Any] = {"json": row} if row_ids is AutoRowIDs.GENERATE_UUID: info["insertId"] = str(uuid.uuid4()) @@ -3618,7 +3628,7 @@ def list_partitions( self, table: Union[Table, TableReference, TableListItem, str], retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> Sequence[str]: """List the partitions in a table. @@ -3669,7 +3679,7 @@ def list_rows( start_index: int = None, page_size: int = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> RowIterator: """List the rows of the table. @@ -3745,7 +3755,7 @@ def list_rows( table = self.get_table(table.reference, retry=retry, timeout=timeout) schema = table.schema - params = {} + params: Dict[str, Any] = {} if selected_fields is not None: params["selectedFields"] = ",".join(field.name for field in selected_fields) if start_index is not None: @@ -3781,7 +3791,7 @@ def _list_rows_from_query_results( start_index: int = None, page_size: int = None, retry: retries.Retry = DEFAULT_RETRY, - timeout: float = DEFAULT_TIMEOUT, + timeout: TimeoutType = DEFAULT_TIMEOUT, ) -> RowIterator: """List the rows of a completed query. See @@ -3826,7 +3836,7 @@ def _list_rows_from_query_results( Iterator of row data :class:`~google.cloud.bigquery.table.Row`-s. """ - params = { + params: Dict[str, Any] = { "fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS, "location": location, } @@ -3867,7 +3877,7 @@ def _schema_to_json_file_object(self, schema_list, file_obj): """ json.dump(schema_list, file_obj, indent=2, sort_keys=True) - def schema_from_json(self, file_or_path: Union[str, BinaryIO]): + def schema_from_json(self, file_or_path: "PathType"): """Takes a file object or file path that contains json that describes a table schema. @@ -3881,7 +3891,7 @@ def schema_from_json(self, file_or_path: Union[str, BinaryIO]): return self._schema_from_json_file_object(file_obj) def schema_to_json( - self, schema_list: Sequence[SchemaField], destination: Union[str, BinaryIO] + self, schema_list: Sequence[SchemaField], destination: "PathType" ): """Takes a list of schema field objects. @@ -4023,13 +4033,12 @@ def _extract_job_reference(job, project=None, location=None): return (project, location, job_id) -def _make_job_id(job_id, prefix=None): +def _make_job_id(job_id: Optional[str], prefix: Optional[str] = None) -> str: """Construct an ID for a new job. 
Args: - job_id (Optional[str]): the user-provided job ID. - - prefix (Optional[str]): the user-provided prefix for a job ID. + job_id: the user-provided job ID. + prefix: the user-provided prefix for a job ID. Returns: str: A job ID diff --git a/google/cloud/bigquery/dataset.py b/google/cloud/bigquery/dataset.py index 21e56f305..ff015d605 100644 --- a/google/cloud/bigquery/dataset.py +++ b/google/cloud/bigquery/dataset.py @@ -18,7 +18,7 @@ import copy -import google.cloud._helpers +import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers from google.cloud.bigquery.model import ModelReference diff --git a/google/cloud/bigquery/dbapi/_helpers.py b/google/cloud/bigquery/dbapi/_helpers.py index 72e711bcf..e5c7ef7ec 100644 --- a/google/cloud/bigquery/dbapi/_helpers.py +++ b/google/cloud/bigquery/dbapi/_helpers.py @@ -161,7 +161,7 @@ def _parse_struct_fields( yield m.group(1, 2) -SCALAR, ARRAY, STRUCT = "sar" +SCALAR, ARRAY, STRUCT = ("s", "a", "r") def _parse_type( @@ -226,19 +226,19 @@ def complex_query_parameter_type(name: typing.Optional[str], type_: str, base: s type_type, sub_type = _parse_type(type_, name, base) if type_type == SCALAR: - type_ = sub_type + result_type = sub_type elif type_type == ARRAY: - type_ = query.ArrayQueryParameterType(sub_type, name=name) + result_type = query.ArrayQueryParameterType(sub_type, name=name) elif type_type == STRUCT: fields = [ complex_query_parameter_type(field_name, field_type, base) for field_name, field_type in sub_type ] - type_ = query.StructQueryParameterType(*fields, name=name) + result_type = query.StructQueryParameterType(*fields, name=name) else: # pragma: NO COVER raise AssertionError("Bad type_type", type_type) # Can't happen :) - return type_ + return result_type def complex_query_parameter( @@ -256,6 +256,12 @@ def complex_query_parameter( struct>> """ + param: typing.Union[ + query.ScalarQueryParameter, + query.ArrayQueryParameter, + query.StructQueryParameter, + ] + base = base or type_ type_type, sub_type = _parse_type(type_, name, base) diff --git a/google/cloud/bigquery/dbapi/cursor.py b/google/cloud/bigquery/dbapi/cursor.py index b1239ff57..03f3b72ca 100644 --- a/google/cloud/bigquery/dbapi/cursor.py +++ b/google/cloud/bigquery/dbapi/cursor.py @@ -31,7 +31,7 @@ from google.cloud.bigquery import job from google.cloud.bigquery.dbapi import _helpers from google.cloud.bigquery.dbapi import exceptions -import google.cloud.exceptions +import google.cloud.exceptions # type: ignore _LOGGER = logging.getLogger(__name__) diff --git a/google/cloud/bigquery/external_config.py b/google/cloud/bigquery/external_config.py index 5f284c639..e6f6a97c3 100644 --- a/google/cloud/bigquery/external_config.py +++ b/google/cloud/bigquery/external_config.py @@ -556,6 +556,10 @@ def from_api_repr(cls, resource: dict) -> "GoogleSheetsOptions": ParquetOptions, ) +OptionsType = Union[ + AvroOptions, BigtableOptions, CSVOptions, GoogleSheetsOptions, ParquetOptions, +] + class HivePartitioningOptions(object): """[Beta] Options that configure hive partitioning. 
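Two smaller fixes in `dbapi/_helpers.py` above follow the same discipline. Unpacking the literal `"sar"` works at runtime, but mypy cannot prove an arbitrary `str` yields exactly three items, hence the explicit tuple; and `complex_query_parameter_type` now accumulates its result in a fresh `result_type` variable instead of rebinding the `str` parameter `type_` with parameter-type objects. The unpacking point in isolation:

SCALAR, ARRAY, STRUCT = "sar"  # runs, but the checker cannot verify the arity of a str
SCALAR, ARRAY, STRUCT = ("s", "a", "r")  # a 3-tuple's arity is statically known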
@@ -664,13 +668,15 @@ def source_format(self): return self._properties["sourceFormat"] @property - def options(self) -> Optional[Union[_OPTION_CLASSES]]: + def options(self) -> Optional[OptionsType]: """Source-specific options.""" for optcls in _OPTION_CLASSES: - if self.source_format == optcls._SOURCE_FORMAT: - options = optcls() - self._properties.setdefault(optcls._RESOURCE_NAME, {}) - options._properties = self._properties[optcls._RESOURCE_NAME] + # The code below is too much magic for mypy to handle. + if self.source_format == optcls._SOURCE_FORMAT: # type: ignore + options: OptionsType = optcls() # type: ignore + options._properties = self._properties.setdefault( + optcls._RESOURCE_NAME, {} # type: ignore + ) return options # No matching source format found. @@ -799,6 +805,13 @@ def schema(self): prop = self._properties.get("schema", {}) return [SchemaField.from_api_repr(field) for field in prop.get("fields", [])] + @schema.setter + def schema(self, value): + prop = value + if value is not None: + prop = {"fields": [field.to_api_repr() for field in value]} + self._properties["schema"] = prop + @property def connection_id(self): """Optional[str]: [Experimental] ID of a BigQuery Connection API @@ -816,13 +829,6 @@ def connection_id(self): def connection_id(self, value): self._properties["connectionId"] = value - @schema.setter - def schema(self, value): - prop = value - if value is not None: - prop = {"fields": [field.to_api_repr() for field in value]} - self._properties["schema"] = prop - @property def avro_options(self) -> Optional[AvroOptions]: """Additional properties to set if ``sourceFormat`` is set to AVRO. diff --git a/google/cloud/bigquery/job/base.py b/google/cloud/bigquery/job/base.py index 88d6bec14..97acab5d2 100644 --- a/google/cloud/bigquery/job/base.py +++ b/google/cloud/bigquery/job/base.py @@ -696,7 +696,7 @@ def done( self.reload(retry=retry, timeout=timeout) return self.state == _DONE_STATE - def result( + def result( # type: ignore # (signature complaint) self, retry: "retries.Retry" = DEFAULT_RETRY, timeout: float = None ) -> "_AsyncJob": """Start the job and wait for it to complete and get the result. @@ -921,7 +921,7 @@ def from_api_repr(cls, resource: dict) -> "_JobConfig": # cls is one of the job config subclasses that provides the job_type argument to # this base class on instantiation, thus missing-parameter warning is a false # positive here. - job_config = cls() # pytype: disable=missing-parameter + job_config = cls() # type: ignore # pytype: disable=missing-parameter job_config._properties = resource return job_config diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py index 942c85fc3..36e388238 100644 --- a/google/cloud/bigquery/job/query.py +++ b/google/cloud/bigquery/job/query.py @@ -56,9 +56,9 @@ if typing.TYPE_CHECKING: # pragma: NO COVER # Assumption: type checks are only used by library developers and CI environments # that have all optional dependencies installed, thus no conditional imports. 
- import pandas - import geopandas - import pyarrow + import pandas # type: ignore + import geopandas # type: ignore + import pyarrow # type: ignore from google.api_core import retry as retries from google.cloud import bigquery_storage from google.cloud.bigquery.client import Client @@ -144,7 +144,7 @@ def from_api_repr(cls, stats: Dict[str, str]) -> "DmlStats": args = ( int(stats.get(api_field, default_val)) - for api_field, default_val in zip(api_fields, cls.__new__.__defaults__) + for api_field, default_val in zip(api_fields, cls.__new__.__defaults__) # type: ignore ) return cls(*args) @@ -161,7 +161,7 @@ def __init__( statement_byte_budget: Optional[int] = None, key_result_statement: Optional[KeyResultStatementKind] = None, ): - self._properties = {} + self._properties: Dict[str, Any] = {} self.statement_timeout_ms = statement_timeout_ms self.statement_byte_budget = statement_byte_budget self.key_result_statement = key_result_statement @@ -193,9 +193,8 @@ def statement_timeout_ms(self) -> Union[int, None]: @statement_timeout_ms.setter def statement_timeout_ms(self, value: Union[int, None]): - if value is not None: - value = str(value) - self._properties["statementTimeoutMs"] = value + new_value = None if value is None else str(value) + self._properties["statementTimeoutMs"] = new_value @property def statement_byte_budget(self) -> Union[int, None]: @@ -207,9 +206,8 @@ def statement_byte_budget(self) -> Union[int, None]: @statement_byte_budget.setter def statement_byte_budget(self, value: Union[int, None]): - if value is not None: - value = str(value) - self._properties["statementByteBudget"] = value + new_value = None if value is None else str(value) + self._properties["statementByteBudget"] = new_value @property def key_result_statement(self) -> Union[KeyResultStatementKind, None]: @@ -666,9 +664,8 @@ def script_options(self) -> ScriptOptions: @script_options.setter def script_options(self, value: Union[ScriptOptions, None]): - if value is not None: - value = value.to_api_repr() - self._set_sub_prop("scriptOptions", value) + new_value = None if value is None else value.to_api_repr() + self._set_sub_prop("scriptOptions", new_value) def to_api_repr(self) -> dict: """Build an API representation of the query job config. 
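The `ScriptOptions` setter rewrites above (`statement_timeout_ms`, `statement_byte_budget`, `script_options`) avoid giving the parameter `value` two types over its lifetime, which mypy rejects; the converted form goes into a new name instead. Sketch of the pattern:

from typing import Optional, Union

def to_ms_string(value: Union[int, None]) -> Optional[str]:
    # Rebinding via `value = str(value)` would retype an int parameter as str;
    # a fresh name keeps each variable's type stable.
    new_value = None if value is None else str(value)
    return new_value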
@@ -1330,7 +1327,7 @@ def _done_or_raise(self, retry=DEFAULT_RETRY, timeout=None): except exceptions.GoogleAPIError as exc: self.set_exception(exc) - def result( + def result( # type: ignore # (complaints about the overloaded signature) self, page_size: int = None, max_results: int = None, @@ -1400,7 +1397,7 @@ def result( retry_do_query = getattr(self, "_retry_do_query", None) if retry_do_query is not None: if job_retry is DEFAULT_JOB_RETRY: - job_retry = self._job_retry + job_retry = self._job_retry # type: ignore else: if job_retry is not None and job_retry is not DEFAULT_JOB_RETRY: raise TypeError( @@ -1451,7 +1448,7 @@ def do_get_result(): except exceptions.GoogleAPICallError as exc: exc.message += self._format_for_exception(self.query, self.job_id) - exc.query_job = self + exc.query_job = self # type: ignore raise except requests.exceptions.Timeout as exc: raise concurrent.futures.TimeoutError from exc diff --git a/google/cloud/bigquery/magics/line_arg_parser/lexer.py b/google/cloud/bigquery/magics/line_arg_parser/lexer.py index cd809c389..71b287d01 100644 --- a/google/cloud/bigquery/magics/line_arg_parser/lexer.py +++ b/google/cloud/bigquery/magics/line_arg_parser/lexer.py @@ -98,7 +98,7 @@ def _generate_next_value_(name, start, count, last_values): return name -TokenType = AutoStrEnum( # pytype: disable=wrong-arg-types +TokenType = AutoStrEnum( # type: ignore # pytype: disable=wrong-arg-types "TokenType", [ (name, enum.auto()) diff --git a/google/cloud/bigquery/magics/magics.py b/google/cloud/bigquery/magics/magics.py index ec0430518..1d8d8ed30 100644 --- a/google/cloud/bigquery/magics/magics.py +++ b/google/cloud/bigquery/magics/magics.py @@ -90,16 +90,16 @@ from concurrent import futures try: - import IPython - from IPython import display - from IPython.core import magic_arguments + import IPython # type: ignore + from IPython import display # type: ignore + from IPython.core import magic_arguments # type: ignore except ImportError: # pragma: NO COVER raise ImportError("This module can only be loaded in IPython.") from google.api_core import client_info from google.api_core import client_options from google.api_core.exceptions import NotFound -import google.auth +import google.auth # type: ignore from google.cloud import bigquery import google.cloud.bigquery.dataset from google.cloud.bigquery.dbapi import _helpers diff --git a/google/cloud/bigquery/model.py b/google/cloud/bigquery/model.py index 2d3f6660f..cdb411e08 100644 --- a/google/cloud/bigquery/model.py +++ b/google/cloud/bigquery/model.py @@ -20,8 +20,8 @@ from google.protobuf import json_format -import google.cloud._helpers -from google.api_core import datetime_helpers +import google.cloud._helpers # type: ignore +from google.api_core import datetime_helpers # type: ignore from google.cloud.bigquery import _helpers from google.cloud.bigquery_v2 import types from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration diff --git a/google/cloud/bigquery/opentelemetry_tracing.py b/google/cloud/bigquery/opentelemetry_tracing.py index b1a1027d2..748f2136d 100644 --- a/google/cloud/bigquery/opentelemetry_tracing.py +++ b/google/cloud/bigquery/opentelemetry_tracing.py @@ -14,7 +14,7 @@ import logging from contextlib import contextmanager -from google.api_core.exceptions import GoogleAPICallError +from google.api_core.exceptions import GoogleAPICallError # type: ignore logger = logging.getLogger(__name__) try: diff --git a/google/cloud/bigquery/query.py b/google/cloud/bigquery/query.py index 
708f5f47b..637be62be 100644 --- a/google/cloud/bigquery/query.py +++ b/google/cloud/bigquery/query.py @@ -367,14 +367,14 @@ class _AbstractQueryParameter(object): """ @classmethod - def from_api_repr(cls, resource: dict) -> "ScalarQueryParameter": + def from_api_repr(cls, resource: dict) -> "_AbstractQueryParameter": """Factory: construct parameter from JSON resource. Args: resource (Dict): JSON mapping of parameter Returns: - google.cloud.bigquery.query.ScalarQueryParameter + A new instance of _AbstractQueryParameter subclass. """ raise NotImplementedError @@ -471,7 +471,7 @@ def to_api_repr(self) -> dict: converter = _SCALAR_VALUE_TO_JSON_PARAM.get(self.type_) if converter is not None: value = converter(value) - resource = { + resource: Dict[str, Any] = { "parameterType": {"type": self.type_}, "parameterValue": {"value": value}, } @@ -734,7 +734,7 @@ def from_api_repr(cls, resource: dict) -> "StructQueryParameter": struct_values = resource["parameterValue"]["structValues"] for key, value in struct_values.items(): type_ = types[key] - converted = None + converted: Optional[Union[ArrayQueryParameter, StructQueryParameter]] = None if type_ == "STRUCT": struct_resource = { "name": key, diff --git a/google/cloud/bigquery/retry.py b/google/cloud/bigquery/retry.py index 8a86973cd..254b26608 100644 --- a/google/cloud/bigquery/retry.py +++ b/google/cloud/bigquery/retry.py @@ -14,7 +14,7 @@ from google.api_core import exceptions from google.api_core import retry -from google.auth import exceptions as auth_exceptions +from google.auth import exceptions as auth_exceptions # type: ignore import requests.exceptions diff --git a/google/cloud/bigquery/routine/routine.py b/google/cloud/bigquery/routine/routine.py index a776212c3..a66434300 100644 --- a/google/cloud/bigquery/routine/routine.py +++ b/google/cloud/bigquery/routine/routine.py @@ -18,7 +18,7 @@ from google.protobuf import json_format -import google.cloud._helpers +import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers import google.cloud.bigquery_v2.types from google.cloud.bigquery_v2.types import StandardSqlTableType diff --git a/google/cloud/bigquery/schema.py b/google/cloud/bigquery/schema.py index 225942234..2af61b672 100644 --- a/google/cloud/bigquery/schema.py +++ b/google/cloud/bigquery/schema.py @@ -16,7 +16,7 @@ import collections import enum -from typing import Iterable, Union +from typing import Any, Dict, Iterable, Union from google.cloud.bigquery_v2 import types @@ -106,7 +106,7 @@ def __init__( scale: Union[int, _DefaultSentinel] = _DEFAULT_VALUE, max_length: Union[int, _DefaultSentinel] = _DEFAULT_VALUE, ): - self._properties = { + self._properties: Dict[str, Any] = { "name": name, "type": field_type, } diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index 608218fdc..60c8593c7 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -21,37 +21,37 @@ import functools import operator import typing -from typing import Any, Dict, Iterable, Iterator, Optional, Tuple +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union import warnings try: - import pandas + import pandas # type: ignore except ImportError: # pragma: NO COVER pandas = None try: - import geopandas + import geopandas # type: ignore except ImportError: geopandas = None else: _COORDINATE_REFERENCE_SYSTEM = "EPSG:4326" try: - import shapely.geos + import shapely.geos # type: ignore except ImportError: shapely = None else: _read_wkt = 
shapely.geos.WKTReader(shapely.geos.lgeos).read try: - import pyarrow + import pyarrow # type: ignore except ImportError: # pragma: NO COVER pyarrow = None import google.api_core.exceptions from google.api_core.page_iterator import HTTPIterator -import google.cloud._helpers +import google.cloud._helpers # type: ignore from google.cloud.bigquery import _helpers from google.cloud.bigquery import _pandas_helpers from google.cloud.bigquery.exceptions import LegacyBigQueryStorageError @@ -130,7 +130,7 @@ def _view_use_legacy_sql_getter(table): class _TableBase: """Base class for Table-related classes with common functionality.""" - _PROPERTY_TO_API_FIELD = { + _PROPERTY_TO_API_FIELD: Dict[str, Union[str, List[str]]] = { "dataset_id": ["tableReference", "datasetId"], "project": ["tableReference", "projectId"], "table_id": ["tableReference", "tableId"], @@ -807,7 +807,7 @@ def view_query(self): view_use_legacy_sql = property(_view_use_legacy_sql_getter) - @view_use_legacy_sql.setter + @view_use_legacy_sql.setter # type: ignore # (redefinition from above) def view_use_legacy_sql(self, value): if not isinstance(value, bool): raise ValueError("Pass a boolean") @@ -1746,7 +1746,7 @@ def to_arrow( progress_bar.close() finally: if owns_bqstorage_client: - bqstorage_client._transport.grpc_channel.close() + bqstorage_client._transport.grpc_channel.close() # type: ignore if record_batches and bqstorage_client is not None: return pyarrow.Table.from_batches(record_batches) @@ -1763,7 +1763,7 @@ def to_dataframe_iterable( self, bqstorage_client: "bigquery_storage.BigQueryReadClient" = None, dtypes: Dict[str, Any] = None, - max_queue_size: int = _pandas_helpers._MAX_QUEUE_SIZE_DEFAULT, + max_queue_size: int = _pandas_helpers._MAX_QUEUE_SIZE_DEFAULT, # type: ignore ) -> "pandas.DataFrame": """Create an iterable of pandas DataFrames, to process the table as a stream. @@ -2307,8 +2307,6 @@ def __repr__(self): key_vals = ["{}={}".format(key, val) for key, val in self._key()] return "PartitionRange({})".format(", ".join(key_vals)) - __hash__ = None - class RangePartitioning(object): """Range-based partitioning configuration for a table. @@ -2387,8 +2385,6 @@ def __repr__(self): key_vals = ["{}={}".format(key, repr(val)) for key, val in self._key()] return "RangePartitioning({})".format(", ".join(key_vals)) - __hash__ = None - class TimePartitioningType(object): """Specifies the type of time partitioning to perform.""" @@ -2657,7 +2653,7 @@ def _rows_page_start(iterator, page, response): # pylint: enable=unused-argument -def _table_arg_to_table_ref(value, default_project=None): +def _table_arg_to_table_ref(value, default_project=None) -> TableReference: """Helper to convert a string or Table to TableReference. This function keeps TableReference and other kinds of objects unchanged. @@ -2669,7 +2665,7 @@ def _table_arg_to_table_ref(value, default_project=None): return value -def _table_arg_to_table(value, default_project=None): +def _table_arg_to_table(value, default_project=None) -> Table: """Helper to convert a string or TableReference to a Table. This function keeps Table and other kinds of objects unchanged. 
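About the two `__hash__ = None` deletions in `table.py` above: mypy rejects assigning `None` over the callable `object.__hash__`. Assuming both classes define `__eq__` (they appear to, via their `_key()` helpers), the deletion is behavior-neutral, because Python already sets `__hash__` to `None` implicitly on any class that defines `__eq__` without `__hash__`. A toy illustration:

class PartitionRange:
    """Toy stand-in for the class in google/cloud/bigquery/table.py."""

    def __init__(self, start: int, end: int) -> None:
        self.start = start
        self.end = end

    def __eq__(self, other: object) -> bool:
        return isinstance(other, PartitionRange) and (
            (self.start, self.end) == (other.start, other.end)
        )

    # An explicit `__hash__ = None` here would be redundant: defining __eq__
    # already replaced the inherited hash with None.

print(PartitionRange(0, 10) == PartitionRange(0, 10))  # True
# hash(PartitionRange(0, 10))  # would raise TypeError: unhashable type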
diff --git a/noxfile.py b/noxfile.py index 1879a5cd8..505911861 100644 --- a/noxfile.py +++ b/noxfile.py @@ -22,6 +22,7 @@ import nox +MYPY_VERSION = "mypy==0.910" PYTYPE_VERSION = "pytype==2021.4.9" BLACK_VERSION = "black==19.10b0" BLACK_PATHS = ("docs", "google", "samples", "tests", "noxfile.py", "setup.py") @@ -41,6 +42,7 @@ "lint", "lint_setup_py", "blacken", + "mypy", "pytype", "docs", ] @@ -113,9 +115,24 @@ def unit_noextras(session): default(session, install_extras=False) +@nox.session(python=DEFAULT_PYTHON_VERSION) +def mypy(session): + """Run type checks with mypy.""" + session.install("-e", ".[all]") + session.install("ipython") + session.install(MYPY_VERSION) + + # Just install the dependencies' type info directly, since "mypy --install-types" + # might require an additional pass. + session.install( + "types-protobuf", "types-python-dateutil", "types-requests", "types-setuptools", + ) + session.run("mypy", "google/cloud") + + @nox.session(python=DEFAULT_PYTHON_VERSION) def pytype(session): - """Run type checks.""" + """Run type checks with pytype.""" # An indirect dependency attrs==21.1.0 breaks the check, and installing a less # recent version avoids the error until a possibly better fix is found. # https://github.com/googleapis/python-bigquery/issues/655 diff --git a/tests/unit/test_opentelemetry_tracing.py b/tests/unit/test_opentelemetry_tracing.py index cc1ca7903..3021a3dbf 100644 --- a/tests/unit/test_opentelemetry_tracing.py +++ b/tests/unit/test_opentelemetry_tracing.py @@ -51,9 +51,16 @@ def setup(): memory_exporter = InMemorySpanExporter() span_processor = SimpleSpanProcessor(memory_exporter) tracer_provider.add_span_processor(span_processor) - trace.set_tracer_provider(tracer_provider) + + # OpenTelemetry API >= 0.12b0 does not allow overriding the tracer once + # initialized, thus directly override (and then restore) the internal global var. + orig_trace_provider = trace._TRACER_PROVIDER + trace._TRACER_PROVIDER = tracer_provider + yield memory_exporter + trace._TRACER_PROVIDER = orig_trace_provider + @pytest.mark.skipif(opentelemetry is None, reason="Require `opentelemetry`") def test_opentelemetry_not_installed(setup, monkeypatch):
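Usage note: with this noxfile change, `nox -s mypy` runs the new session end to end (editable install, type stubs, then `mypy google/cloud`). The test fixture change swaps OpenTelemetry's private `trace._TRACER_PROVIDER` global directly because, per the comment in the patch, `set_tracer_provider` will not replace a provider once one is set; restoring it after the `yield` keeps tests isolated. The same save/override/restore shape, condensed into a hypothetical standalone fixture (`make_test_provider` is a placeholder, not a real API):

import pytest
from opentelemetry import trace  # requires the opentelemetry packages

@pytest.fixture
def overridden_tracer_provider():
    original = trace._TRACER_PROVIDER             # private global, as in the patch
    trace._TRACER_PROVIDER = make_test_provider()  # placeholder factory
    yield trace._TRACER_PROVIDER
    trace._TRACER_PROVIDER = original             # teardown runs even if the test fails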