Skip to content

Commit 8ff0de5

Browse files
feat: experimental service account iam endpoint flow for id token (#1258)
* feat: experimental service account iam endpoint flow for id token * update * update * update test * address comment
1 parent 71b02aa commit 8ff0de5

File tree

4 files changed

+183
-6
lines changed

4 files changed

+183
-6
lines changed

google/oauth2/_client.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
_JSON_CONTENT_TYPE = "application/json"
4141
_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
4242
_REFRESH_GRANT_TYPE = "refresh_token"
43+
_IAM_IDTOKEN_ENDPOINT = (
44+
"https://iamcredentials.googleapis.com/v1/"
45+
+ "projects/-/serviceAccounts/{}:generateIdToken"
46+
)
4347

4448

4549
def _handle_error_response(response_data, retryable_error):
@@ -313,6 +317,44 @@ def jwt_grant(request, token_uri, assertion, can_retry=True):
313317
return access_token, expiry, response_data
314318

315319

320+
def call_iam_generate_id_token_endpoint(request, signer_email, audience, access_token):
321+
"""Call iam.generateIdToken endpoint to get ID token.
322+
323+
Args:
324+
request (google.auth.transport.Request): A callable used to make
325+
HTTP requests.
326+
signer_email (str): The signer email used to form the IAM
327+
generateIdToken endpoint.
328+
audience (str): The audience for the ID token.
329+
access_token (str): The access token used to call the IAM endpoint.
330+
331+
Returns:
332+
Tuple[str, datetime]: The ID token and expiration.
333+
"""
334+
body = {"audience": audience, "includeEmail": "true"}
335+
336+
response_data = _token_endpoint_request(
337+
request,
338+
_IAM_IDTOKEN_ENDPOINT.format(signer_email),
339+
body,
340+
access_token=access_token,
341+
use_json=True,
342+
)
343+
344+
try:
345+
id_token = response_data["token"]
346+
except KeyError as caught_exc:
347+
new_exc = exceptions.RefreshError(
348+
"No ID token in response.", response_data, retryable=False
349+
)
350+
six.raise_from(new_exc, caught_exc)
351+
352+
payload = jwt.decode(id_token, verify=False)
353+
expiry = datetime.datetime.utcfromtimestamp(payload["exp"])
354+
355+
return id_token, expiry
356+
357+
316358
def id_token_jwt_grant(request, token_uri, assertion, can_retry=True):
317359
"""Implements the JWT Profile for OAuth 2.0 Authorization Grants, but
318360
requests an OpenID Connect ID Token instead of an access token.

google/oauth2/service_account.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ def __init__(
554554
self._token_uri = token_uri
555555
self._target_audience = target_audience
556556
self._quota_project_id = quota_project_id
557+
self._use_iam_endpoint = False
557558

558559
if additional_claims is not None:
559560
self._additional_claims = additional_claims
@@ -639,6 +640,31 @@ def with_target_audience(self, target_audience):
639640
quota_project_id=self.quota_project_id,
640641
)
641642

643+
def _with_use_iam_endpoint(self, use_iam_endpoint):
644+
"""Create a copy of these credentials with the use_iam_endpoint value.
645+
646+
Args:
647+
use_iam_endpoint (bool): If True, IAM generateIdToken endpoint will
648+
be used instead of the token_uri. Note that
649+
iam.serviceAccountTokenCreator role is required to use the IAM
650+
endpoint. The default value is False. This feature is currently
651+
experimental and subject to change without notice.
652+
653+
Returns:
654+
google.auth.service_account.IDTokenCredentials: A new credentials
655+
instance.
656+
"""
657+
cred = self.__class__(
658+
self._signer,
659+
service_account_email=self._service_account_email,
660+
token_uri=self._token_uri,
661+
target_audience=self._target_audience,
662+
additional_claims=self._additional_claims.copy(),
663+
quota_project_id=self.quota_project_id,
664+
)
665+
cred._use_iam_endpoint = use_iam_endpoint
666+
return cred
667+
642668
@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
643669
def with_quota_project(self, quota_project_id):
644670
return self.__class__(
@@ -692,14 +718,50 @@ def _make_authorization_grant_assertion(self):
692718

693719
return token
694720

721+
def _refresh_with_iam_endpoint(self, request):
722+
"""Use IAM generateIdToken endpoint to obtain an ID token.
723+
724+
It works as follows:
725+
726+
1. First we create a self signed jwt with
727+
https://www.googleapis.com/auth/iam being the scope.
728+
729+
2. Next we use the self signed jwt as the access token, and make a POST
730+
request to IAM generateIdToken endpoint. The request body is:
731+
{
732+
"audience": self._target_audience,
733+
"includeEmail": "true"
734+
}
735+
TODO: add "set_azp_to_email": "true" once it's ready from server side.
736+
https://github.com/googleapis/google-auth-library-python/issues/1263
737+
738+
If the request is succesfully, it will return {"token":"the ID token"},
739+
and we can extract the ID token and compute its expiry.
740+
"""
741+
jwt_credentials = jwt.Credentials.from_signing_credentials(
742+
self,
743+
None,
744+
additional_claims={"scope": "https://www.googleapis.com/auth/iam"},
745+
)
746+
jwt_credentials.refresh(request)
747+
self.token, self.expiry = _client.call_iam_generate_id_token_endpoint(
748+
request,
749+
self.signer_email,
750+
self._target_audience,
751+
jwt_credentials.token.decode(),
752+
)
753+
695754
@_helpers.copy_docstring(credentials.Credentials)
696755
def refresh(self, request):
697-
assertion = self._make_authorization_grant_assertion()
698-
access_token, expiry, _ = _client.id_token_jwt_grant(
699-
request, self._token_uri, assertion
700-
)
701-
self.token = access_token
702-
self.expiry = expiry
756+
if self._use_iam_endpoint:
757+
self._refresh_with_iam_endpoint(request)
758+
else:
759+
assertion = self._make_authorization_grant_assertion()
760+
access_token, expiry, _ = _client.id_token_jwt_grant(
761+
request, self._token_uri, assertion
762+
)
763+
self.token = access_token
764+
self.expiry = expiry
703765

704766
@property
705767
def service_account_email(self):

tests/oauth2/test__client.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,50 @@ def test_jwt_grant_no_access_token():
305305
assert not excinfo.value.retryable
306306

307307

308+
def test_call_iam_generate_id_token_endpoint():
309+
now = _helpers.utcnow()
310+
id_token_expiry = _helpers.datetime_to_secs(now)
311+
id_token = jwt.encode(SIGNER, {"exp": id_token_expiry}).decode("utf-8")
312+
request = make_request({"token": id_token})
313+
314+
token, expiry = _client.call_iam_generate_id_token_endpoint(
315+
request, "fake_email", "fake_audience", "fake_access_token"
316+
)
317+
318+
assert (
319+
request.call_args[1]["url"]
320+
== "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/fake_email:generateIdToken"
321+
)
322+
assert request.call_args[1]["headers"]["Content-Type"] == "application/json"
323+
assert (
324+
request.call_args[1]["headers"]["Authorization"] == "Bearer fake_access_token"
325+
)
326+
response_body = json.loads(request.call_args[1]["body"])
327+
assert response_body["audience"] == "fake_audience"
328+
assert response_body["includeEmail"] == "true"
329+
330+
# Check result
331+
assert token == id_token
332+
# JWT does not store microseconds
333+
now = now.replace(microsecond=0)
334+
assert expiry == now
335+
336+
337+
def test_call_iam_generate_id_token_endpoint_no_id_token():
338+
request = make_request(
339+
{
340+
# No access token.
341+
"error": "no token"
342+
}
343+
)
344+
345+
with pytest.raises(exceptions.RefreshError) as excinfo:
346+
_client.call_iam_generate_id_token_endpoint(
347+
request, "fake_email", "fake_audience", "fake_access_token"
348+
)
349+
assert excinfo.match("No ID token in response")
350+
351+
308352
def test_id_token_jwt_grant():
309353
now = _helpers.utcnow()
310354
id_token_expiry = _helpers.datetime_to_secs(now)

tests/oauth2/test_service_account.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def test_from_service_account_info(self):
428428
assert credentials.service_account_email == SERVICE_ACCOUNT_INFO["client_email"]
429429
assert credentials._token_uri == SERVICE_ACCOUNT_INFO["token_uri"]
430430
assert credentials._target_audience == self.TARGET_AUDIENCE
431+
assert not credentials._use_iam_endpoint
431432

432433
def test_from_service_account_file(self):
433434
info = SERVICE_ACCOUNT_INFO.copy()
@@ -440,6 +441,7 @@ def test_from_service_account_file(self):
440441
assert credentials._signer.key_id == info["private_key_id"]
441442
assert credentials._token_uri == info["token_uri"]
442443
assert credentials._target_audience == self.TARGET_AUDIENCE
444+
assert not credentials._use_iam_endpoint
443445

444446
def test_default_state(self):
445447
credentials = self.make_credentials()
@@ -466,6 +468,11 @@ def test_with_target_audience(self):
466468
new_credentials = credentials.with_target_audience("https://new.example.com")
467469
assert new_credentials._target_audience == "https://new.example.com"
468470

471+
def test__with_use_iam_endpoint(self):
472+
credentials = self.make_credentials()
473+
new_credentials = credentials._with_use_iam_endpoint(True)
474+
assert new_credentials._use_iam_endpoint
475+
469476
def test_with_quota_project(self):
470477
credentials = self.make_credentials()
471478
new_credentials = credentials.with_quota_project("project-foo")
@@ -517,6 +524,28 @@ def test_refresh_success(self, id_token_jwt_grant):
517524
# expired)
518525
assert credentials.valid
519526

527+
@mock.patch(
528+
"google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True
529+
)
530+
def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
531+
credentials = self.make_credentials()
532+
credentials._use_iam_endpoint = True
533+
token = "id_token"
534+
call_iam_generate_id_token_endpoint.return_value = (
535+
token,
536+
_helpers.utcnow() + datetime.timedelta(seconds=500),
537+
)
538+
request = mock.Mock()
539+
credentials.refresh(request)
540+
req, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
541+
0
542+
]
543+
assert req == request
544+
assert signer_email == "[email protected]"
545+
assert target_audience == "https://example.com"
546+
decoded_access_token = jwt.decode(access_token, verify=False)
547+
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
548+
520549
@mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
521550
def test_before_request_refreshes(self, id_token_jwt_grant):
522551
credentials = self.make_credentials()

0 commit comments

Comments
 (0)