Skip to content
This repository was archived by the owner on Dec 31, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore

from google.cloud.kms_v1.types import resources
from google.cloud.kms_v1.types import service
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
Expand All @@ -91,6 +93,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
always_use_jwt_access (Optional[bool]): Whether self signed JWT
should be used for service account credentials.
"""
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ":" not in host:
Expand Down Expand Up @@ -119,6 +123,14 @@ def __init__(
**scopes_kwargs, quota_project_id=quota_project_id
)

# If the credentials is service account credentials, then always try to use self signed JWT.
if (
always_use_jwt_access
and isinstance(credentials, service_account.Credentials)
and hasattr(service_account.Credentials, "with_always_use_jwt_access")
):
credentials = credentials.with_always_use_jwt_access(True)

# Save the credentials.
self._credentials = credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
always_use_jwt_access=True,
)

if not self._grpc_channel:
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/gapic/kms_v1/test_key_management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,17 @@ def test__get_default_mtls_endpoint():
)


@pytest.mark.parametrize(
"client_class", [KeyManagementServiceClient, KeyManagementServiceAsyncClient,]
)
def test_key_management_service_client_service_account_always_use_jwt(client_class):
creds = service_account.Credentials(None, None, None)
if hasattr(service_account.Credentials, "with_always_use_jwt_access"):
assert not creds._always_use_jwt_access
client = client_class(credentials=creds)
assert client.transport._credentials._always_use_jwt_access


@pytest.mark.parametrize(
"client_class", [KeyManagementServiceClient, KeyManagementServiceAsyncClient,]
)
Expand Down