From d52459563defadffdb5fa25060d7c98593abaa87 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Sun, 28 Jan 2024 15:30:22 -0800 Subject: [PATCH 1/2] Tolerate ID token time errors --- oauth2cli/__init__.py | 2 +- oauth2cli/oidc.py | 80 ++++++++++++++++++++++++++++++++++++------- tests/test_oidc.py | 21 ++++++++++++ 3 files changed, 89 insertions(+), 14 deletions(-) create mode 100644 tests/test_oidc.py diff --git a/oauth2cli/__init__.py b/oauth2cli/__init__.py index 60bf2595..d9978726 100644 --- a/oauth2cli/__init__.py +++ b/oauth2cli/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.4.0" -from .oidc import Client +from .oidc import Client, IdTokenError from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. from .authcode import AuthCodeReceiver diff --git a/oauth2cli/oidc.py b/oauth2cli/oidc.py index d4d3a927..01ee7894 100644 --- a/oauth2cli/oidc.py +++ b/oauth2cli/oidc.py @@ -5,9 +5,13 @@ import string import warnings import hashlib +import logging from . import oauth2 + +logger = logging.getLogger(__name__) + def decode_part(raw, encoding="utf-8"): """Decode a part of the JWT. @@ -32,6 +36,45 @@ def decode_part(raw, encoding="utf-8"): base64decode = decode_part # Obsolete. For backward compatibility only. +def _epoch_to_local(epoch): + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(epoch)) + +class IdTokenError(RuntimeError): # We waised RuntimeError before, so keep it + """In unlikely event of an ID token is malformed, this exception will be raised.""" + def __init__(self, reason, now, claims): + super(IdTokenError, self).__init__( + "%s Current epoch = %s. The id_token was approximately: %s" % ( + reason, _epoch_to_local(now), json.dumps(dict( + claims, + iat=_epoch_to_local(claims["iat"]) if claims.get("iat") else None, + exp=_epoch_to_local(claims["exp"]) if claims.get("exp") else None, + ), indent=2))) + +class _IdTokenTimeError(IdTokenError): # This is not intended to be raised and caught + _SUGGESTION = "Make sure your computer's time and time zone are both correct." + def __init__(self, reason, now, claims): + super(_IdTokenTimeError, self).__init__(reason+ " " + self._SUGGESTION, now, claims) + def log(self): + # Influenced by JWT specs https://tools.ietf.org/html/rfc7519#section-4.1.5 + # and OIDC specs https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation + # We used to raise this error, but now we just log it as warning, because: + # 1. If it is caused by incorrect local machine time, + # then the token(s) are still correct and probably functioning, + # so, there is no point to error out. + # 2. If it is caused by incorrect IdP time, then it is IdP's fault, + # There is not much a client can do, so, we might as well return the token(s) + # and let downstream components to decide what to do. + logger.warning(str(self)) + +class IdTokenIssuerError(IdTokenError): + pass + +class IdTokenAudienceError(IdTokenError): + pass + +class IdTokenNonceError(IdTokenError): + pass + def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None): """Decodes and validates an id_token and returns its claims as a dictionary. @@ -41,41 +84,52 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None) `maybe more `_ """ decoded = json.loads(decode_part(id_token.split('.')[1])) - err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation + # Based on https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation _now = int(now or time.time()) skew = 120 # 2 minutes - TIME_SUGGESTION = "Make sure your computer's time and time zone are both correct." + if _now + skew < decoded.get("nbf", _now - 1): # nbf is optional per JWT specs # This is not an ID token validation, but a JWT validation # https://tools.ietf.org/html/rfc7519#section-4.1.5 - err = "0. The ID token is not yet valid. " + TIME_SUGGESTION + _IdTokenTimeError("0. The ID token is not yet valid.", _now, decoded).log() + if issuer and issuer != decoded["iss"]: # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse - err = ('2. The Issuer Identifier for the OpenID Provider, "%s", ' + raise IdTokenIssuerError( + '2. The Issuer Identifier for the OpenID Provider, "%s", ' "(which is typically obtained during Discovery), " - "MUST exactly match the value of the iss (issuer) Claim.") % issuer + "MUST exactly match the value of the iss (issuer) Claim." % issuer, + _now, + decoded) + if client_id: valid_aud = client_id in decoded["aud"] if isinstance( decoded["aud"], list) else client_id == decoded["aud"] if not valid_aud: - err = ( + raise IdTokenAudienceError( "3. The aud (audience) claim must contain this client's client_id " '"%s", case-sensitively. Was your client_id in wrong casing?' # Some IdP accepts wrong casing request but issues right casing IDT - ) % client_id + % client_id, + _now, + decoded) + # Per specs: # 6. If the ID Token is received via direct communication between # the Client and the Token Endpoint (which it is during _obtain_token()), # the TLS server validation MAY be used to validate the issuer # in place of checking the token signature. + if _now - skew > decoded["exp"]: - err = "9. The ID token already expires. " + TIME_SUGGESTION + _IdTokenTimeError("9. The ID token already expires.", _now, decoded).log() + if nonce and nonce != decoded.get("nonce"): - err = ("11. Nonce must be the same value " - "as the one that was sent in the Authentication Request.") - if err: - raise RuntimeError("%s Current epoch = %s. The id_token was: %s" % ( - err, _now, json.dumps(decoded, indent=2))) + raise IdTokenNonceError( + "11. Nonce must be the same value " + "as the one that was sent in the Authentication Request.", + _now, + decoded) + return decoded diff --git a/tests/test_oidc.py b/tests/test_oidc.py new file mode 100644 index 00000000..d6a929bc --- /dev/null +++ b/tests/test_oidc.py @@ -0,0 +1,21 @@ +from tests import unittest + +import oauth2cli + + +class TestIdToken(unittest.TestCase): + EXPIRED_ID_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJpc3N1ZXIiLCJpYXQiOjE3MDY1NzA3MzIsImV4cCI6MTY3NDk0ODMzMiwiYXVkIjoiZm9vIiwic3ViIjoic3ViamVjdCJ9.wyWNFxnE35SMP6FpxnWZmWQAy4KD0No_Q1rUy5bNnLs" + + def test_id_token_should_tolerate_time_error(self): + self.assertEqual(oauth2cli.oidc.decode_id_token(self.EXPIRED_ID_TOKEN), { + "iss": "issuer", + "iat": 1706570732, + "exp": 1674948332, # 2023-1-28 + "aud": "foo", + "sub": "subject", + }, "id_token is decoded correctly, without raising exception") + + def test_id_token_should_error_out_on_client_id_error(self): + with self.assertRaises(oauth2cli.IdTokenError): + oauth2cli.oidc.decode_id_token(self.EXPIRED_ID_TOKEN, client_id="not foo") + From 386ea2e02a533373ab2d557da6d5aa55a748d7d3 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Thu, 25 Jan 2024 23:48:09 -0800 Subject: [PATCH 2/2] Tolerate ID token time errors --- docs/index.rst | 9 +++++++++ msal/__init__.py | 3 +-- tests/test_oidc.py | 5 +++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 2129e106..11dd9b05 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -148,3 +148,12 @@ New in MSAL Python 1.26 .. automethod:: __init__ + +Exceptions +---------- +These are exceptions that MSAL Python may raise. +You should not need to create them directly. +You may want to catch them to provide a better error message to your end users. + +.. autoclass:: msal.IdTokenError + diff --git a/msal/__init__.py b/msal/__init__.py index 09b7a504..5c5292fa 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -31,7 +31,6 @@ ConfidentialClientApplication, PublicClientApplication, ) -from .oauth2cli.oidc import Prompt +from .oauth2cli.oidc import Prompt, IdTokenError from .token_cache import TokenCache, SerializableTokenCache from .auth_scheme import PopAuthScheme - diff --git a/tests/test_oidc.py b/tests/test_oidc.py index d6a929bc..297dfeb5 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -1,6 +1,7 @@ from tests import unittest -import oauth2cli +import msal +from msal import oauth2cli class TestIdToken(unittest.TestCase): @@ -16,6 +17,6 @@ def test_id_token_should_tolerate_time_error(self): }, "id_token is decoded correctly, without raising exception") def test_id_token_should_error_out_on_client_id_error(self): - with self.assertRaises(oauth2cli.IdTokenError): + with self.assertRaises(msal.IdTokenError): oauth2cli.oidc.decode_id_token(self.EXPIRED_ID_TOKEN, client_id="not foo")