Skip to content

Commit 91d3c65

Browse files
BigTailWolflsirac
andauthored
feat: adding universe domain support for downscroped credentials (#1463)
* feat: adding universe domain support for downscroped credentials * fix lint * address comments * Update tests/test_downscoped.py Co-authored-by: Leo <[email protected]> --------- Co-authored-by: Leo <[email protected]>
1 parent 969808f commit 91d3c65

File tree

3 files changed

+98
-8
lines changed

3 files changed

+98
-8
lines changed

packages/google-auth/google/auth/credentials.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from google.auth import metrics
2525
from google.auth._refresh_worker import RefreshThreadManager
2626

27+
DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
28+
2729

2830
class Credentials(metaclass=abc.ABCMeta):
2931
"""Base class for all credentials.
@@ -57,7 +59,7 @@ def __init__(self):
5759
"""Optional[dict]: Cache of a trust boundary response which has a list
5860
of allowed regions and an encoded string representation of credentials
5961
trust boundary."""
60-
self._universe_domain = "googleapis.com"
62+
self._universe_domain = DEFAULT_UNIVERSE_DOMAIN
6163
"""Optional[str]: The universe domain value, default is googleapis.com
6264
"""
6365

packages/google-auth/google/auth/downscoped.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
# The token exchange requested_token_type. This is always an access_token.
6464
_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
6565
# The STS token URL used to exchanged a short lived access token for a downscoped one.
66-
_STS_TOKEN_URL = "https://sts.googleapis.com/v1/token"
66+
_STS_TOKEN_URL_PATTERN = "https://sts.{}/v1/token"
6767
# The subject token type to use when exchanging a short lived access token for a
6868
# downscoped token.
6969
_STS_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token"
@@ -437,7 +437,11 @@ class Credentials(credentials.CredentialsWithQuotaProject):
437437
"""
438438

439439
def __init__(
440-
self, source_credentials, credential_access_boundary, quota_project_id=None
440+
self,
441+
source_credentials,
442+
credential_access_boundary,
443+
quota_project_id=None,
444+
universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN,
441445
):
442446
"""Instantiates a downscoped credentials object using the provided source
443447
credentials and credential access boundary rules.
@@ -456,6 +460,7 @@ def __init__(
456460
the upper bound of the permissions that are available on that resource and an
457461
optional condition to further restrict permissions.
458462
quota_project_id (Optional[str]): The optional quota project ID.
463+
universe_domain (Optional[str]): The universe domain value, default is googleapis.com
459464
Raises:
460465
google.auth.exceptions.RefreshError: If the source credentials
461466
return an error on token refresh.
@@ -467,7 +472,10 @@ def __init__(
467472
self._source_credentials = source_credentials
468473
self._credential_access_boundary = credential_access_boundary
469474
self._quota_project_id = quota_project_id
470-
self._sts_client = sts.Client(_STS_TOKEN_URL)
475+
self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN
476+
self._sts_client = sts.Client(
477+
_STS_TOKEN_URL_PATTERN.format(self.universe_domain)
478+
)
471479

472480
@_helpers.copy_docstring(credentials.Credentials)
473481
def refresh(self, request):

packages/google-auth/tests/test_downscoped.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.auth import downscoped
2626
from google.auth import exceptions
2727
from google.auth import transport
28+
from google.auth.credentials import DEFAULT_UNIVERSE_DOMAIN
2829
from google.auth.credentials import TokenState
2930

3031

@@ -447,7 +448,11 @@ def test_to_json(self):
447448

448449
class TestCredentials(object):
449450
@staticmethod
450-
def make_credentials(source_credentials=SourceCredentials(), quota_project_id=None):
451+
def make_credentials(
452+
source_credentials=SourceCredentials(),
453+
quota_project_id=None,
454+
universe_domain=None,
455+
):
451456
availability_condition = make_availability_condition(
452457
EXPRESSION, TITLE, DESCRIPTION
453458
)
@@ -458,7 +463,10 @@ def make_credentials(source_credentials=SourceCredentials(), quota_project_id=No
458463
credential_access_boundary = make_credential_access_boundary(rules)
459464

460465
return downscoped.Credentials(
461-
source_credentials, credential_access_boundary, quota_project_id
466+
source_credentials,
467+
credential_access_boundary,
468+
quota_project_id,
469+
universe_domain,
462470
)
463471

464472
@staticmethod
@@ -473,10 +481,12 @@ def make_mock_request(data, status=http_client.OK):
473481
return request
474482

475483
@staticmethod
476-
def assert_request_kwargs(request_kwargs, headers, request_data):
484+
def assert_request_kwargs(
485+
request_kwargs, headers, request_data, token_endpoint=TOKEN_EXCHANGE_ENDPOINT
486+
):
477487
"""Asserts the request was called with the expected parameters.
478488
"""
479-
assert request_kwargs["url"] == TOKEN_EXCHANGE_ENDPOINT
489+
assert request_kwargs["url"] == token_endpoint
480490
assert request_kwargs["method"] == "POST"
481491
assert request_kwargs["headers"] == headers
482492
assert request_kwargs["body"] is not None
@@ -496,6 +506,33 @@ def test_default_state(self):
496506
assert not credentials.expired
497507
# No quota project ID set.
498508
assert not credentials.quota_project_id
509+
assert credentials.universe_domain == DEFAULT_UNIVERSE_DOMAIN
510+
511+
def test_default_state_with_explicit_none_value(self):
512+
credentials = self.make_credentials(universe_domain=None)
513+
514+
# No token acquired yet.
515+
assert not credentials.token
516+
assert not credentials.valid
517+
# Expiration hasn't been set yet.
518+
assert not credentials.expiry
519+
assert not credentials.expired
520+
# No quota project ID set.
521+
assert not credentials.quota_project_id
522+
assert credentials.universe_domain == DEFAULT_UNIVERSE_DOMAIN
523+
524+
def test_create_with_customized_universe_domain(self):
525+
test_universe_domain = "foo.com"
526+
credentials = self.make_credentials(universe_domain=test_universe_domain)
527+
# No token acquired yet.
528+
assert not credentials.token
529+
assert not credentials.valid
530+
# Expiration hasn't been set yet.
531+
assert not credentials.expiry
532+
assert not credentials.expired
533+
# No quota project ID set.
534+
assert not credentials.quota_project_id
535+
assert credentials.universe_domain == test_universe_domain
499536

500537
def test_with_quota_project(self):
501538
credentials = self.make_credentials()
@@ -506,6 +543,49 @@ def test_with_quota_project(self):
506543

507544
assert quota_project_creds.quota_project_id == "project-foo"
508545

546+
@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
547+
def test_refresh_on_custom_universe(self, unused_utcnow):
548+
test_universe_domain = "foo.com"
549+
response = SUCCESS_RESPONSE.copy()
550+
# Test custom expiration to confirm expiry is set correctly.
551+
response["expires_in"] = 2800
552+
expected_expiry = datetime.datetime.min + datetime.timedelta(
553+
seconds=response["expires_in"]
554+
)
555+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
556+
request_data = {
557+
"grant_type": GRANT_TYPE,
558+
"subject_token": "ACCESS_TOKEN_1",
559+
"subject_token_type": SUBJECT_TOKEN_TYPE,
560+
"requested_token_type": REQUESTED_TOKEN_TYPE,
561+
"options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)),
562+
}
563+
request = self.make_mock_request(status=http_client.OK, data=response)
564+
source_credentials = SourceCredentials()
565+
credentials = self.make_credentials(
566+
source_credentials=source_credentials, universe_domain=test_universe_domain
567+
)
568+
token_exchange_endpoint = downscoped._STS_TOKEN_URL_PATTERN.format(
569+
test_universe_domain
570+
)
571+
572+
# Spy on calls to source credentials refresh to confirm the expected request
573+
# instance is used.
574+
with mock.patch.object(
575+
source_credentials, "refresh", wraps=source_credentials.refresh
576+
) as wrapped_souce_cred_refresh:
577+
credentials.refresh(request)
578+
579+
self.assert_request_kwargs(
580+
request.call_args[1], headers, request_data, token_exchange_endpoint
581+
)
582+
assert credentials.valid
583+
assert credentials.expiry == expected_expiry
584+
assert not credentials.expired
585+
assert credentials.token == response["access_token"]
586+
# Confirm source credentials called with the same request instance.
587+
wrapped_souce_cred_refresh.assert_called_with(request)
588+
509589
@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
510590
def test_refresh(self, unused_utcnow):
511591
response = SUCCESS_RESPONSE.copy()

0 commit comments

Comments
 (0)