diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index 16548dd4ad..c09251de3b 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -15,6 +15,7 @@ from __future__ import annotations import inspect +import threading import typing from typing import ( Any, @@ -465,6 +466,8 @@ def from_glob_path( from_glob_path.__doc__ = inspect.getdoc(bigframes.session.Session.from_glob_path) +_default_location_lock = threading.Lock() + def _set_default_session_location_if_possible(query): # Set the location as per the query if this is the first query the user is @@ -475,31 +478,34 @@ def _set_default_session_location_if_possible(query): # If query is a table name, then it would be the location of the table. # If query is a SQL with a table, then it would be table's location. # If query is a SQL with no table, then it would be the BQ default location. - if ( - config.options.bigquery._session_started - or config.options.bigquery.location - or config.options.bigquery.use_regional_endpoints - ): - return - - clients_provider = bigframes.session.clients.ClientsProvider( - project=config.options.bigquery.project, - location=config.options.bigquery.location, - use_regional_endpoints=config.options.bigquery.use_regional_endpoints, - credentials=config.options.bigquery.credentials, - application_name=config.options.bigquery.application_name, - bq_kms_key_name=config.options.bigquery.kms_key_name, - client_endpoints_override=config.options.bigquery.client_endpoints_override, - ) - - bqclient = clients_provider.bqclient - - if bigframes.session._io.bigquery.is_query(query): - # Intentionally run outside of the session so that we can detect the - # location before creating the session. Since it's a dry_run, labels - # aren't necessary. - job = bqclient.query(query, bigquery.QueryJobConfig(dry_run=True)) - config.options.bigquery.location = job.location - else: - table = bqclient.get_table(query) - config.options.bigquery.location = table.location + global _default_location_lock + + with _default_location_lock: + if ( + config.options.bigquery._session_started + or config.options.bigquery.location + or config.options.bigquery.use_regional_endpoints + ): + return + + clients_provider = bigframes.session.clients.ClientsProvider( + project=config.options.bigquery.project, + location=config.options.bigquery.location, + use_regional_endpoints=config.options.bigquery.use_regional_endpoints, + credentials=config.options.bigquery.credentials, + application_name=config.options.bigquery.application_name, + bq_kms_key_name=config.options.bigquery.kms_key_name, + client_endpoints_override=config.options.bigquery.client_endpoints_override, + ) + + bqclient = clients_provider.bqclient + + if bigframes.session._io.bigquery.is_query(query): + # Intentionally run outside of the session so that we can detect the + # location before creating the session. Since it's a dry_run, labels + # aren't necessary. + job = bqclient.query(query, bigquery.QueryJobConfig(dry_run=True)) + config.options.bigquery.location = job.location + else: + table = bqclient.get_table(query) + config.options.bigquery.location = table.location diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index 5ef974d565..a8e1ab71f1 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -15,12 +15,12 @@ """Clients manages the connection to Google APIs.""" import os +import threading import typing from typing import Optional import google.api_core.client_info import google.api_core.client_options -import google.api_core.exceptions import google.api_core.gapic_v1.client_info import google.auth.credentials import google.cloud.bigquery as bigquery @@ -84,6 +84,9 @@ def __init__( if credentials is None: credentials, credentials_project = _get_default_credentials_with_project() + # Ensure an access token is available. + credentials.refresh(google.auth.transport.requests.Request()) + # Prefer the project in this order: # 1. Project explicitly specified by the user # 2. Project set in the environment @@ -127,19 +130,30 @@ def __init__( self._client_endpoints_override = client_endpoints_override # cloud clients initialized for lazy load + self._bqclient_lock = threading.Lock() self._bqclient = None + + self._bqconnectionclient_lock = threading.Lock() self._bqconnectionclient: Optional[ google.cloud.bigquery_connection_v1.ConnectionServiceClient ] = None + + self._bqstoragereadclient_lock = threading.Lock() self._bqstoragereadclient: Optional[ google.cloud.bigquery_storage_v1.BigQueryReadClient ] = None + + self._bqstoragewriteclient_lock = threading.Lock() self._bqstoragewriteclient: Optional[ google.cloud.bigquery_storage_v1.BigQueryWriteClient ] = None + + self._cloudfunctionsclient_lock = threading.Lock() self._cloudfunctionsclient: Optional[ google.cloud.functions_v2.FunctionServiceClient ] = None + + self._resourcemanagerclient_lock = threading.Lock() self._resourcemanagerclient: Optional[ google.cloud.resourcemanager_v3.ProjectsClient ] = None @@ -166,6 +180,7 @@ def _create_bigquery_client(self): project=self._project, location=self._location, ) + if self._bq_kms_key_name: # Note: Key configuration only applies automatically to load and query jobs, not copy jobs. encryption_config = bigquery.EncryptionConfiguration( @@ -186,114 +201,126 @@ def _create_bigquery_client(self): @property def bqclient(self): - if not self._bqclient: - self._bqclient = self._create_bigquery_client() + with self._bqclient_lock: + if not self._bqclient: + self._bqclient = self._create_bigquery_client() return self._bqclient @property def bqconnectionclient(self): - if not self._bqconnectionclient: - bqconnection_options = None - if "bqconnectionclient" in self._client_endpoints_override: - bqconnection_options = google.api_core.client_options.ClientOptions( - api_endpoint=self._client_endpoints_override["bqconnectionclient"] - ) + with self._bqconnectionclient_lock: + if not self._bqconnectionclient: + bqconnection_options = None + if "bqconnectionclient" in self._client_endpoints_override: + bqconnection_options = google.api_core.client_options.ClientOptions( + api_endpoint=self._client_endpoints_override[ + "bqconnectionclient" + ] + ) - bqconnection_info = google.api_core.gapic_v1.client_info.ClientInfo( - user_agent=self._application_name - ) - self._bqconnectionclient = ( - google.cloud.bigquery_connection_v1.ConnectionServiceClient( - client_info=bqconnection_info, - client_options=bqconnection_options, - credentials=self._credentials, + bqconnection_info = google.api_core.gapic_v1.client_info.ClientInfo( + user_agent=self._application_name + ) + self._bqconnectionclient = ( + google.cloud.bigquery_connection_v1.ConnectionServiceClient( + client_info=bqconnection_info, + client_options=bqconnection_options, + credentials=self._credentials, + ) ) - ) return self._bqconnectionclient @property def bqstoragereadclient(self): - if not self._bqstoragereadclient: - bqstorage_options = None - if "bqstoragereadclient" in self._client_endpoints_override: - bqstorage_options = google.api_core.client_options.ClientOptions( - api_endpoint=self._client_endpoints_override["bqstoragereadclient"] - ) - elif self._use_regional_endpoints: - bqstorage_options = google.api_core.client_options.ClientOptions( - api_endpoint=_BIGQUERYSTORAGE_REGIONAL_ENDPOINT.format( - location=self._location + with self._bqstoragereadclient_lock: + if not self._bqstoragereadclient: + bqstorage_options = None + if "bqstoragereadclient" in self._client_endpoints_override: + bqstorage_options = google.api_core.client_options.ClientOptions( + api_endpoint=self._client_endpoints_override[ + "bqstoragereadclient" + ] + ) + elif self._use_regional_endpoints: + bqstorage_options = google.api_core.client_options.ClientOptions( + api_endpoint=_BIGQUERYSTORAGE_REGIONAL_ENDPOINT.format( + location=self._location + ) ) - ) - bqstorage_info = google.api_core.gapic_v1.client_info.ClientInfo( - user_agent=self._application_name - ) - self._bqstoragereadclient = ( - google.cloud.bigquery_storage_v1.BigQueryReadClient( - client_info=bqstorage_info, - client_options=bqstorage_options, - credentials=self._credentials, + bqstorage_info = google.api_core.gapic_v1.client_info.ClientInfo( + user_agent=self._application_name + ) + self._bqstoragereadclient = ( + google.cloud.bigquery_storage_v1.BigQueryReadClient( + client_info=bqstorage_info, + client_options=bqstorage_options, + credentials=self._credentials, + ) ) - ) return self._bqstoragereadclient @property def bqstoragewriteclient(self): - if not self._bqstoragewriteclient: - bqstorage_options = None - if "bqstoragewriteclient" in self._client_endpoints_override: - bqstorage_options = google.api_core.client_options.ClientOptions( - api_endpoint=self._client_endpoints_override["bqstoragewriteclient"] - ) - elif self._use_regional_endpoints: - bqstorage_options = google.api_core.client_options.ClientOptions( - api_endpoint=_BIGQUERYSTORAGE_REGIONAL_ENDPOINT.format( - location=self._location + with self._bqstoragewriteclient_lock: + if not self._bqstoragewriteclient: + bqstorage_options = None + if "bqstoragewriteclient" in self._client_endpoints_override: + bqstorage_options = google.api_core.client_options.ClientOptions( + api_endpoint=self._client_endpoints_override[ + "bqstoragewriteclient" + ] + ) + elif self._use_regional_endpoints: + bqstorage_options = google.api_core.client_options.ClientOptions( + api_endpoint=_BIGQUERYSTORAGE_REGIONAL_ENDPOINT.format( + location=self._location + ) ) - ) - bqstorage_info = google.api_core.gapic_v1.client_info.ClientInfo( - user_agent=self._application_name - ) - self._bqstoragewriteclient = ( - google.cloud.bigquery_storage_v1.BigQueryWriteClient( - client_info=bqstorage_info, - client_options=bqstorage_options, - credentials=self._credentials, + bqstorage_info = google.api_core.gapic_v1.client_info.ClientInfo( + user_agent=self._application_name + ) + self._bqstoragewriteclient = ( + google.cloud.bigquery_storage_v1.BigQueryWriteClient( + client_info=bqstorage_info, + client_options=bqstorage_options, + credentials=self._credentials, + ) ) - ) return self._bqstoragewriteclient @property def cloudfunctionsclient(self): - if not self._cloudfunctionsclient: - functions_info = google.api_core.gapic_v1.client_info.ClientInfo( - user_agent=self._application_name - ) - self._cloudfunctionsclient = ( - google.cloud.functions_v2.FunctionServiceClient( - client_info=functions_info, - credentials=self._credentials, + with self._cloudfunctionsclient_lock: + if not self._cloudfunctionsclient: + functions_info = google.api_core.gapic_v1.client_info.ClientInfo( + user_agent=self._application_name + ) + self._cloudfunctionsclient = ( + google.cloud.functions_v2.FunctionServiceClient( + client_info=functions_info, + credentials=self._credentials, + ) ) - ) return self._cloudfunctionsclient @property def resourcemanagerclient(self): - if not self._resourcemanagerclient: - resourcemanager_info = google.api_core.gapic_v1.client_info.ClientInfo( - user_agent=self._application_name - ) - self._resourcemanagerclient = ( - google.cloud.resourcemanager_v3.ProjectsClient( - credentials=self._credentials, client_info=resourcemanager_info + with self._resourcemanagerclient_lock: + if not self._resourcemanagerclient: + resourcemanager_info = google.api_core.gapic_v1.client_info.ClientInfo( + user_agent=self._application_name + ) + self._resourcemanagerclient = ( + google.cloud.resourcemanager_v3.ProjectsClient( + credentials=self._credentials, client_info=resourcemanager_info + ) ) - ) return self._resourcemanagerclient