diff --git a/bigframes/_config/bigquery_options.py b/bigframes/_config/bigquery_options.py index 3a6008eaa8..d591ea85b3 100644 --- a/bigframes/_config/bigquery_options.py +++ b/bigframes/_config/bigquery_options.py @@ -16,10 +16,11 @@ from __future__ import annotations -from typing import Literal, Optional +from typing import Literal, Optional, Sequence, Tuple import warnings import google.auth.credentials +import requests.adapters import bigframes.enums import bigframes.exceptions as bfe @@ -90,6 +91,9 @@ def __init__( allow_large_results: bool = False, ordering_mode: Literal["strict", "partial"] = "strict", client_endpoints_override: Optional[dict] = None, + requests_transport_adapters: Sequence[ + Tuple[str, requests.adapters.BaseAdapter] + ] = (), ): self._credentials = credentials self._project = project @@ -100,6 +104,7 @@ def __init__( self._kms_key_name = kms_key_name self._skip_bq_connection_check = skip_bq_connection_check self._allow_large_results = allow_large_results + self._requests_transport_adapters = requests_transport_adapters self._session_started = False # Determines the ordering strictness for the session. self._ordering_mode = _validate_ordering_mode(ordering_mode) @@ -379,3 +384,43 @@ def client_endpoints_override(self, value: dict): ) self._client_endpoints_override = value + + @property + def requests_transport_adapters( + self, + ) -> Sequence[Tuple[str, requests.adapters.BaseAdapter]]: + """Transport adapters for requests-based REST clients such as the + google-cloud-bigquery package. + + For more details, see the explanation in `requests guide to transport + adapters + `_. + + **Examples:** + + Increase the connection pool size using the requests `HTTPAdapter + `_. + + >>> import bigframes.pandas as bpd + >>> bpd.options.bigquery.requests_transport_adapters = ( + ... ("http://", requests.adapters.HTTPAdapter(pool_maxsize=100)), + ... ("https://", requests.adapters.HTTPAdapter(pool_maxsize=100)), + ... ) # doctest: +SKIP + + Returns: + Sequence[Tuple[str, requests.adapters.BaseAdapter]]: + Prefixes and corresponding transport adapters to `mount + `_ + in requests-based REST clients. + """ + return self._requests_transport_adapters + + @requests_transport_adapters.setter + def requests_transport_adapters( + self, value: Sequence[Tuple[str, requests.adapters.BaseAdapter]] + ) -> None: + if self._session_started and self._requests_transport_adapters != value: + raise ValueError( + SESSION_STARTED_MESSAGE.format(attribute="requests_transport_adapters") + ) + self._requests_transport_adapters = value diff --git a/bigframes/pandas/io/api.py b/bigframes/pandas/io/api.py index c09251de3b..b2ce5f211e 100644 --- a/bigframes/pandas/io/api.py +++ b/bigframes/pandas/io/api.py @@ -496,6 +496,7 @@ def _set_default_session_location_if_possible(query): application_name=config.options.bigquery.application_name, bq_kms_key_name=config.options.bigquery.kms_key_name, client_endpoints_override=config.options.bigquery.client_endpoints_override, + requests_transport_adapters=config.options.bigquery.requests_transport_adapters, ) bqclient = clients_provider.bqclient diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 46d71a079e..c24dca554a 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -172,6 +172,7 @@ def __init__( application_name=context.application_name, bq_kms_key_name=self._bq_kms_key_name, client_endpoints_override=context.client_endpoints_override, + requests_transport_adapters=context.requests_transport_adapters, ) # TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494 diff --git a/bigframes/session/clients.py b/bigframes/session/clients.py index 86312eb9ba..d680b94b8a 100644 --- a/bigframes/session/clients.py +++ b/bigframes/session/clients.py @@ -17,18 +17,20 @@ import os import threading import typing -from typing import Optional +from typing import Optional, Sequence, Tuple import google.api_core.client_info import google.api_core.client_options import google.api_core.gapic_v1.client_info import google.auth.credentials +import google.auth.transport.requests import google.cloud.bigquery as bigquery import google.cloud.bigquery_connection_v1 import google.cloud.bigquery_storage_v1 import google.cloud.functions_v2 import google.cloud.resourcemanager_v3 import pydata_google_auth +import requests import bigframes.constants import bigframes.version @@ -79,6 +81,10 @@ def __init__( application_name: Optional[str] = None, bq_kms_key_name: Optional[str] = None, client_endpoints_override: dict = {}, + *, + requests_transport_adapters: Sequence[ + Tuple[str, requests.adapters.BaseAdapter] + ] = (), ): credentials_project = None if credentials is None: @@ -124,6 +130,7 @@ def __init__( ) self._location = location self._use_regional_endpoints = use_regional_endpoints + self._requests_transport_adapters = requests_transport_adapters self._credentials = credentials self._bq_kms_key_name = bq_kms_key_name @@ -173,12 +180,21 @@ def _create_bigquery_client(self): user_agent=self._application_name ) + requests_session = google.auth.transport.requests.AuthorizedSession( + self._credentials + ) + for prefix, adapter in self._requests_transport_adapters: + requests_session.mount(prefix, adapter) + bq_client = bigquery.Client( client_info=bq_info, client_options=bq_options, - credentials=self._credentials, project=self._project, location=self._location, + # Instead of credentials, use _http so that users can override + # requests options with transport adapters. See internal issue + # b/419106112. + _http=requests_session, ) # If a new enough client library is available, we opt-in to the faster diff --git a/tests/system/small/test_pandas_options.py b/tests/system/small/test_pandas_options.py index d59b6d66b5..55e5036a42 100644 --- a/tests/system/small/test_pandas_options.py +++ b/tests/system/small/test_pandas_options.py @@ -279,16 +279,18 @@ def test_credentials_need_reauthentication( # Call get_global_session() *after* read_gbq so that our location detection # has a chance to work. session = bpd.get_global_session() - assert session.bqclient._credentials.valid + assert session.bqclient._http.credentials.valid with monkeypatch.context() as m: # Simulate expired credentials to trigger the credential refresh flow - m.setattr(session.bqclient._credentials, "expiry", datetime.datetime.utcnow()) - assert not session.bqclient._credentials.valid + m.setattr( + session.bqclient._http.credentials, "expiry", datetime.datetime.utcnow() + ) + assert not session.bqclient._http.credentials.valid # Simulate an exception during the credential refresh flow m.setattr( - session.bqclient._credentials, + session.bqclient._http.credentials, "refresh", mock.Mock(side_effect=google.auth.exceptions.RefreshError()), ) diff --git a/tests/unit/_config/test_bigquery_options.py b/tests/unit/_config/test_bigquery_options.py index b8f3a612d4..686499aa75 100644 --- a/tests/unit/_config/test_bigquery_options.py +++ b/tests/unit/_config/test_bigquery_options.py @@ -38,6 +38,7 @@ ("skip_bq_connection_check", False, True), ("client_endpoints_override", {}, {"bqclient": "endpoint_address"}), ("ordering_mode", "strict", "partial"), + ("requests_transport_adapters", object(), object()), ], ) def test_setter_raises_if_session_started(attribute, original_value, new_value): diff --git a/tests/unit/session/test_clients.py b/tests/unit/session/test_clients.py index 6b0d8583a5..5304c99466 100644 --- a/tests/unit/session/test_clients.py +++ b/tests/unit/session/test_clients.py @@ -15,25 +15,22 @@ import os import pathlib import tempfile -from typing import Optional +from typing import cast, Optional import unittest.mock as mock -import google.api_core.client_info -import google.api_core.client_options -import google.api_core.exceptions -import google.api_core.gapic_v1.client_info import google.auth.credentials import google.cloud.bigquery import google.cloud.bigquery_connection_v1 import google.cloud.bigquery_storage_v1 import google.cloud.functions_v2 import google.cloud.resourcemanager_v3 +import requests.adapters import bigframes.session.clients as clients import bigframes.version -def create_clients_provider(application_name: Optional[str] = None): +def create_clients_provider(application_name: Optional[str] = None, **kwargs): credentials = mock.create_autospec(google.auth.credentials.Credentials) return clients.ClientsProvider( project="test-project", @@ -42,6 +39,7 @@ def create_clients_provider(application_name: Optional[str] = None): credentials=credentials, application_name=application_name, bq_kms_key_name="projects/my-project/locations/us/keyRings/myKeyRing/cryptoKeys/myKey", + **kwargs, ) @@ -136,6 +134,24 @@ def assert_clients_wo_user_agent( ) +def test_requests_transport_adapters_pool_maxsize(monkeypatch): + monkeypatch_client_constructors(monkeypatch) + requests_transport_adapters = ( + ("http://", requests.adapters.HTTPAdapter(pool_maxsize=123)), + ("https://", requests.adapters.HTTPAdapter(pool_maxsize=123)), + ) # doctest: +SKIP + provider = create_clients_provider( + requests_transport_adapters=requests_transport_adapters + ) + + _, kwargs = cast(mock.Mock, provider.bqclient).call_args + requests_session = kwargs.get("_http") + adapter: requests.adapters.HTTPAdapter = requests_session.get_adapter( + "https://bigquery.googleapis.com/" + ) + assert adapter._pool_maxsize == 123 # type: ignore + + def test_user_agent_default(monkeypatch): monkeypatch_client_constructors(monkeypatch) provider = create_clients_provider(application_name=None)