diff --git a/.gitignore b/.gitignore index e776c10e..18dae08c 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,5 @@ docs/_build/ # The test configuration file(s) could potentially contain credentials tests/config.json + +.env \ No newline at end of file diff --git a/README.md b/README.md index eb9e6248..e193f257 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,8 @@ # Microsoft Authentication Library (MSAL) for Python - | `dev` branch | Reference Docs | # of Downloads per different platforms | # of Downloads per recent MSAL versions | |---------------|---------------|----------------------------------------|-----------------------------------------| - [![Build status](https://api.travis-ci.org/AzureAD/microsoft-authentication-library-for-python.svg?branch=dev)](https://travis-ci.org/AzureAD/microsoft-authentication-library-for-python) | [![Documentation Status](https://readthedocs.org/projects/msal-python/badge/?version=latest)](https://msal-python.readthedocs.io/en/latest/?badge=latest) | [![Downloads](https://pepy.tech/badge/msal)](https://pypistats.org/packages/msal) | [![Download monthly](https://pepy.tech/badge/msal/month)](https://pepy.tech/project/msal) + [![Build status](https://github.com/AzureAD/microsoft-authentication-library-for-python/actions/workflows/python-package.yml/badge.svg?branch=dev)](https://github.com/AzureAD/microsoft-authentication-library-for-python/actions) | [![Documentation Status](https://readthedocs.org/projects/msal-python/badge/?version=latest)](https://msal-python.readthedocs.io/en/latest/?badge=latest) | [![Downloads](https://pepy.tech/badge/msal)](https://pypistats.org/packages/msal) | [![Download monthly](https://pepy.tech/badge/msal/month)](https://pepy.tech/project/msal) The Microsoft Authentication Library for Python enables applications to integrate with the [Microsoft identity platform](https://aka.ms/aaddevv2). It allows you to sign in users or apps with Microsoft identities ([Azure AD](https://azure.microsoft.com/services/active-directory/), [Microsoft Accounts](https://account.microsoft.com) and [Azure AD B2C](https://azure.microsoft.com/services/active-directory-b2c/) accounts) and obtain tokens to call Microsoft APIs such as [Microsoft Graph](https://graph.microsoft.io/) or your own APIs registered with the Microsoft identity platform. It is built using industry standard OAuth2 and OpenID Connect protocols diff --git a/msal/application.py b/msal/application.py index 35bd91a1..c7a3471f 100644 --- a/msal/application.py +++ b/msal/application.py @@ -9,6 +9,7 @@ import sys import warnings from threading import Lock +import os import requests @@ -20,10 +21,11 @@ from .token_cache import TokenCache import msal.telemetry from .region import _detect_region +from .throttled_http_client import ThrottledHttpClient # The __init__.py will import this. Not the other way around. -__version__ = "1.13.0" +__version__ = "1.14.0" logger = logging.getLogger(__name__) @@ -69,6 +71,46 @@ def _clean_up(result): return result +def _preferred_browser(): + """Register Edge and return a name suitable for subsequent webbrowser.get(...) + when appropriate. Otherwise return None. + """ + # On Linux, only Edge will provide device-based Conditional Access support + if sys.platform != "linux": # On other platforms, we have no browser preference + return None + browser_path = "/usr/bin/microsoft-edge" # Use a full path owned by sys admin + user_has_no_preference = "BROWSER" not in os.environ + user_wont_mind_edge = "microsoft-edge" in os.environ.get("BROWSER", "") # Note: + # BROWSER could contain "microsoft-edge" or "/path/to/microsoft-edge". + # Python documentation (https://docs.python.org/3/library/webbrowser.html) + # does not document the name being implicitly register, + # so there is no public API to know whether the ENV VAR browser would work. + # Therefore, we would not bother examine the env var browser's type. + # We would just register our own Edge instance. + if (user_has_no_preference or user_wont_mind_edge) and os.path.exists(browser_path): + try: + import webbrowser # Lazy import. Some distro may not have this. + browser_name = "msal-edge" # Avoid popular name "microsoft-edge" + # otherwise `BROWSER="microsoft-edge"; webbrowser.get("microsoft-edge")` + # would return a GenericBrowser instance which won't work. + try: + registration_available = isinstance( + webbrowser.get(browser_name), webbrowser.BackgroundBrowser) + except webbrowser.Error: + registration_available = False + if not registration_available: + logger.debug("Register %s with %s", browser_name, browser_path) + # By registering our own browser instance with our own name, + # rather than populating a process-wide BROWSER enn var, + # this approach does not have side effect on non-MSAL code path. + webbrowser.register( # Even double-register happens to work fine + browser_name, None, webbrowser.BackgroundBrowser(browser_path)) + return browser_name + except ImportError: + pass # We may still proceed + return None + + class ClientApplication(object): ACQUIRE_TOKEN_SILENT_ID = "84" @@ -295,6 +337,10 @@ def __init__( a = requests.adapters.HTTPAdapter(max_retries=1) self.http_client.mount("http://", a) self.http_client.mount("https://", a) + self.http_client = ThrottledHttpClient( + self.http_client, + {} # Hard code an in-memory cache, for now + ) self.app_name = app_name self.app_version = app_version @@ -371,7 +417,7 @@ def _get_regional_authority(self, central_authority): self._region_configured if is_region_specified else self._region_detected) if region_to_use: logger.info('Region to be used: {}'.format(repr(region_to_use))) - regional_host = ("{}.login.microsoft.com".format(region_to_use) + regional_host = ("{}.r.login.microsoftonline.com".format(region_to_use) if central_authority.instance in ( # The list came from https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/358/files#r629400328 "login.microsoftonline.com", @@ -392,6 +438,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False "x-client-sku": "MSAL.Python", "x-client-ver": __version__, "x-client-os": sys.platform, "x-client-cpu": "x64" if sys.maxsize > 2 ** 32 else "x86", + "x-ms-lib-capability": "retry-after, h429", } if self.app_name: default_headers['x-app-name'] = self.app_name @@ -1393,6 +1440,7 @@ def acquire_token_interactive( }, data=dict(kwargs.pop("data", {}), claims=claims), headers=telemetry_context.generate_headers(), + browser_name=_preferred_browser(), **kwargs)) telemetry_context.update_telemetry(response) return response diff --git a/msal/individual_cache.py b/msal/individual_cache.py new file mode 100644 index 00000000..4c6fa00e --- /dev/null +++ b/msal/individual_cache.py @@ -0,0 +1,286 @@ +from functools import wraps +import time +try: + from collections.abc import MutableMapping # Python 3.3+ +except ImportError: + from collections import MutableMapping # Python 2.7+ +import heapq +from threading import Lock + + +class _ExpiringMapping(MutableMapping): + _INDEX = "_index_" + + def __init__(self, mapping=None, capacity=None, expires_in=None, lock=None, + *args, **kwargs): + """Items in this mapping can have individual shelf life, + just like food items in your refrigerator have their different shelf life + determined by each food, not by the refrigerator. + + Expired items will be automatically evicted. + The clean-up will be done at each time when adding a new item, + or when looping or counting the entire mapping. + (This is better than being done indecisively by a background thread, + which might not always happen before your accessing the mapping.) + + This implementation uses no dependency other than Python standard library. + + :param MutableMapping mapping: + A dict-like key-value mapping, which needs to support __setitem__(), + __getitem__(), __delitem__(), get(), pop(). + + The default mapping is an in-memory dict. + + You could potentially supply a file-based dict-like object, too. + This implementation deliberately avoid mapping.__iter__(), + which could be slow on a file-based mapping. + + :param int capacity: + How many items this mapping will hold. + When you attempt to add new item into a full mapping, + it will automatically delete the item that is expiring soonest. + + The default value is None, which means there is no capacity limit. + + :param int expires_in: + How many seconds an item would expire and be purged from this mapping. + Also known as time-to-live (TTL). + You can also use :func:`~set()` to provide per-item expires_in value. + + :param Lock lock: + A locking mechanism with context manager interface. + If no lock is provided, a threading.Lock will be used. + But you may want to supply a different lock, + if your customized mapping is being shared differently. + """ + super(_ExpiringMapping, self).__init__(*args, **kwargs) + self._mapping = mapping if mapping is not None else {} + self._capacity = capacity + self._expires_in = expires_in + self._lock = Lock() if lock is None else lock + + def _validate_key(self, key): + if key == self._INDEX: + raise ValueError("key {} is a reserved keyword in {}".format( + key, self.__class__.__name__)) + + def set(self, key, value, expires_in): + # This method's name was chosen so that it matches its cousin __setitem__(), + # and it also complements the counterpart get(). + # The downside is such a name shadows the built-in type set in this file, + # but you can overcome that by defining a global alias for set. + """It sets the key-value pair into this mapping, with its per-item expires_in. + + It will take O(logN) time, because it will run some maintenance. + This worse-than-constant time is acceptable, because in a cache scenario, + __setitem__() would only be called during a cache miss, + which would already incur an expensive target function call anyway. + + By the way, most other methods of this mapping still have O(1) constant time. + """ + with self._lock: + self._set(key, value, expires_in) + + def _set(self, key, value, expires_in): + # This internal implementation powers both set() and __setitem__(), + # so that they don't depend on each other. + self._validate_key(key) + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + now = int(time.time()) + expires_at = now + expires_in + entry = [expires_at, now, key] + is_new_item = key not in timestamps + is_beyond_capacity = self._capacity and len(timestamps) >= self._capacity + if is_new_item and is_beyond_capacity: + self._drop_indexed_entry(timestamps, heapq.heappushpop(sequence, entry)) + else: # Simply add new entry. The old one would become a harmless orphan. + heapq.heappush(sequence, entry) + timestamps[key] = [expires_at, now] # It overwrites existing key, if any + self._mapping[key] = value + self._mapping[self._INDEX] = sequence, timestamps + + def _maintenance(self, sequence, timestamps): # O(logN) + """It will modify input sequence and timestamps in-place""" + now = int(time.time()) + while sequence: # Clean up expired items + expires_at, created_at, key = sequence[0] + if created_at <= now < expires_at: # Then all remaining items are fresh + break + self._drop_indexed_entry(timestamps, sequence[0]) # It could error out + heapq.heappop(sequence) # Only pop it after a successful _drop_indexed_entry() + while self._capacity is not None and len(timestamps) > self._capacity: + self._drop_indexed_entry(timestamps, sequence[0]) # It could error out + heapq.heappop(sequence) # Only pop it after a successful _drop_indexed_entry() + + def _drop_indexed_entry(self, timestamps, entry): + """For an entry came from index, drop it from timestamps and self._mapping""" + expires_at, created_at, key = entry + if [expires_at, created_at] == timestamps.get(key): # So it is not an orphan + self._mapping.pop(key, None) # It could raise exception + timestamps.pop(key, None) # This would probably always succeed + + def __setitem__(self, key, value): + """Implements the __setitem__(). + + Same characteristic as :func:`~set()`, + but use class-wide expires_in which was specified by :func:`~__init__()`. + """ + if self._expires_in is None: + raise ValueError("Need a numeric value for expires_in during __init__()") + with self._lock: + self._set(key, value, self._expires_in) + + def __getitem__(self, key): # O(1) + """If the item you requested already expires, KeyError will be raised.""" + self._validate_key(key) + with self._lock: + # Skip self._maintenance(), because it would need O(logN) time + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + expires_at, created_at = timestamps[key] # Would raise KeyError accordingly + now = int(time.time()) + if not created_at <= now < expires_at: + self._mapping.pop(key, None) + timestamps.pop(key, None) + self._mapping[self._INDEX] = sequence, timestamps + raise KeyError("{} {}".format( + key, + "expired" if now >= expires_at else "created in the future?", + )) + return self._mapping[key] # O(1) + + def __delitem__(self, key): # O(1) + """If the item you requested already expires, KeyError will be raised.""" + self._validate_key(key) + with self._lock: + # Skip self._maintenance(), because it would need O(logN) time + self._mapping.pop(key, None) # O(1) + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + del timestamps[key] # O(1) + self._mapping[self._INDEX] = sequence, timestamps + + def __len__(self): # O(logN) + """Drop all expired items and return the remaining length""" + with self._lock: + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + self._mapping[self._INDEX] = sequence, timestamps + return len(timestamps) # Faster than iter(self._mapping) when it is on disk + + def __iter__(self): + """Drop all expired items and return an iterator of the remaining items""" + with self._lock: + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + self._mapping[self._INDEX] = sequence, timestamps + return iter(timestamps) # Faster than iter(self._mapping) when it is on disk + + +class _IndividualCache(object): + # The code structure below can decorate both function and method. + # It is inspired by https://stackoverflow.com/a/9417088 + # We may potentially switch to build upon + # https://github.com/micheles/decorator/blob/master/docs/documentation.md#statement-of-the-problem + def __init__(self, mapping=None, key_maker=None, expires_in=None): + """Constructs a cache decorator that allows item-by-item control on + how to cache the return value of the decorated function. + + :param MutableMapping mapping: + The cached items will be stored inside. + You'd want to use a ExpiringMapping + if you plan to utilize the ``expires_in`` behavior. + + If nothing is provided, an in-memory dict will be used, + but it will provide no expiry functionality. + + .. note:: + + When using this class as a decorator, + your mapping needs to be available at "compile" time, + so it would typically be a global-, module- or class-level mapping:: + + module_mapping = {} + + @IndividualCache(mapping=module_mapping, ...) + def foo(): + ... + + If you want to use a mapping available only at run-time, + you have to manually decorate your function at run-time, too:: + + def foo(): + ... + + def bar(runtime_mapping): + foo = IndividualCache(mapping=runtime_mapping...)(foo) + + :param callable key_maker: + A callable which should have signature as + ``lambda function, args, kwargs: "return a string as key"``. + + If key_maker happens to return ``None``, the cache will be bypassed, + the underlying function will be invoked directly, + and the invoke result will not be cached either. + + :param callable expires_in: + The default value is ``None``, + which means the content being cached has no per-item expiry, + and will subject to the underlying mapping's global expiry time. + + It can be an integer indicating + how many seconds the result will be cached. + In particular, if the value is 0, + it means the result expires after zero second (i.e. immediately), + therefore the result will *not* be cached. + (Mind the difference between ``expires_in=0`` and ``expires_in=None``.) + + Or it can be a callable with the signature as + ``lambda function=function, args=args, kwargs=kwargs, result=result: 123`` + to calculate the expiry on the fly. + Its return value will be interpreted in the same way as above. + """ + self._mapping = mapping if mapping is not None else {} + self._key_maker = key_maker or (lambda function, args, kwargs: ( + function, # This default implementation uses function as part of key, + # so that the cache is partitioned by function. + # However, you could have many functions to use same namespace, + # so different decorators could share same cache. + args, + tuple(kwargs.items()), # raw kwargs is not hashable + )) + self._expires_in = expires_in + + def __call__(self, function): + + @wraps(function) + def wrapper(*args, **kwargs): + key = self._key_maker(function, args, kwargs) + if key is None: # Then bypass the cache + return function(*args, **kwargs) + + now = int(time.time()) + try: + return self._mapping[key] + except KeyError: + # We choose to NOT call function(...) in this block, otherwise + # potential exception from function(...) would become a confusing + # "During handling of the above exception, another exception occurred" + pass + value = function(*args, **kwargs) + + expires_in = self._expires_in( + function=function, + args=args, + kwargs=kwargs, + result=value, + ) if callable(self._expires_in) else self._expires_in + if expires_in == 0: + return value + if expires_in is None: + self._mapping[key] = value + else: + self._mapping.set(key, value, expires_in) + return value + + return wrapper + diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index 25c337c4..24e3f642 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -45,9 +45,14 @@ def is_wsl(): return platform_name == 'linux' and 'microsoft' in release -def _browse(auth_uri): # throws ImportError, possibly webbrowser.Error in future +def _browse(auth_uri, browser_name=None): # throws ImportError, webbrowser.Error + """Browse uri with named browser. Default browser is customizable by $BROWSER""" import webbrowser # Lazy import. Some distro may not have this. - browser_opened = webbrowser.open(auth_uri) # Use default browser. Customizable by $BROWSER + if browser_name: + browser_opened = webbrowser.get(browser_name).open(auth_uri) + else: + # This one can survive BROWSER=nonexist, while get(None).open(...) can not + browser_opened = webbrowser.open(auth_uri) # In WSL which doesn't have www-browser, try launching browser with PowerShell if not browser_opened and is_wsl(): @@ -147,6 +152,7 @@ def get_port(self): def get_auth_response(self, auth_uri=None, timeout=None, state=None, welcome_template=None, success_template=None, error_template=None, auth_uri_callback=None, + browser_name=None, ): """Wait and return the auth response. Raise RuntimeError when timeout. @@ -173,6 +179,12 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, A function with the shape of lambda auth_uri: ... When a browser was unable to be launch, this function will be called, so that the app could tell user to manually visit the auth_uri. + :param str browser_name: + If you did + ``webbrowser.register("xyz", None, BackgroundBrowser("/path/to/browser"))`` + beforehand, you can pass in the name "xyz" to use that browser. + The default value ``None`` means using default browser, + which is customizable by env var $BROWSER. :return: The auth response of the first leg of Auth Code flow, typically {"code": "...", "state": "..."} or {"error": "...", ...} @@ -190,7 +202,7 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, logger.info("Open a browser on this device to visit: %s" % _uri) browser_opened = False try: - browser_opened = _browse(_uri) + browser_opened = _browse(_uri, browser_name=browser_name) except: # Had to use broad except, because the potential # webbrowser.Error is purposely undefined outside of _browse(). # Absorb and proceed. Because browser could be manually run elsewhere. diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 04a6b70d..305061cf 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -579,8 +579,10 @@ def obtain_token_by_browser( timeout=None, welcome_template=None, success_template=None, + error_template=None, auth_params=None, auth_uri_callback=None, + browser_name=None, **kwargs): """A native app can use this method to obtain token via a local browser. @@ -613,6 +615,14 @@ def obtain_token_by_browser( These parameters will be sent to authorization_endpoint. :param int timeout: In seconds. None means wait indefinitely. + + :param str browser_name: + If you did + ``webbrowser.register("xyz", None, BackgroundBrowser("/path/to/browser"))`` + beforehand, you can pass in the name "xyz" to use that browser. + The default value ``None`` means using default browser, + which is customizable by env var $BROWSER. + :return: Same as :func:`~obtain_token_by_auth_code_flow()` """ _redirect_uri = urlparse(redirect_uri or "http://127.0.0.1:0") @@ -638,7 +648,9 @@ def obtain_token_by_browser( timeout=timeout, welcome_template=welcome_template, success_template=success_template, + error_template=error_template, auth_uri_callback=auth_uri_callback, + browser_name=browser_name, ) except PermissionError: if 0 < listen_port < 1024: diff --git a/msal/region.py b/msal/region.py index 6ad84c45..dacd49d7 100644 --- a/msal/region.py +++ b/msal/region.py @@ -5,14 +5,9 @@ def _detect_region(http_client=None): - region = _detect_region_of_azure_function() # It is cheap, so we do it always - if http_client and not region: + if http_client: return _detect_region_of_azure_vm(http_client) # It could hang for minutes - return region - - -def _detect_region_of_azure_function(): - return os.environ.get("REGION_NAME") + return None def _detect_region_of_azure_vm(http_client): diff --git a/msal/throttled_http_client.py b/msal/throttled_http_client.py new file mode 100644 index 00000000..d30eda5e --- /dev/null +++ b/msal/throttled_http_client.py @@ -0,0 +1,140 @@ +from threading import Lock +from hashlib import sha256 + +from .individual_cache import _IndividualCache as IndividualCache +from .individual_cache import _ExpiringMapping as ExpiringMapping + + +# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4 +DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" + + +def _hash(raw): + return sha256(repr(raw).encode("utf-8")).hexdigest() + + +def _parse_http_429_5xx_retry_after(result=None, **ignored): + """Return seconds to throttle""" + assert result is not None, """ + The signature defines it with a default value None, + only because the its shape is already decided by the + IndividualCache's.__call__(). + In actual code path, the result parameter here won't be None. + """ + response = result + lowercase_headers = {k.lower(): v for k, v in getattr( + # Historically, MSAL's HttpResponse does not always have headers + response, "headers", {}).items()} + if not (response.status_code == 429 or response.status_code >= 500 + or "retry-after" in lowercase_headers): + return 0 # Quick exit + default = 60 # Recommended at the end of + # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview + retry_after = int(lowercase_headers.get("retry-after", default)) + try: + # AAD's retry_after uses integer format only + # https://stackoverflow.microsoft.com/questions/264931/264932 + delay_seconds = int(retry_after) + except ValueError: + delay_seconds = default + return min(3600, delay_seconds) + + +def _extract_data(kwargs, key, default=None): + data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string + return data.get(key) if isinstance(data, dict) else default + + +class ThrottledHttpClient(object): + def __init__(self, http_client, http_cache): + """Throttle the given http_client by storing and retrieving data from cache. + + This wrapper exists so that our patching post() and get() would prevent + re-patching side effect when/if same http_client being reused. + """ + expiring_mapping = ExpiringMapping( # It will automatically clean up + mapping=http_cache if http_cache is not None else {}, + capacity=1024, # To prevent cache blowing up especially for CCA + lock=Lock(), # TODO: This should ideally also allow customization + ) + + _post = http_client.post # We'll patch _post, and keep original post() intact + + _post = IndividualCache( + # Internal specs requires throttling on at least token endpoint, + # here we have a generic patch for POST on all endpoints. + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: + "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format( + args[0], # It is the url, typically containing authority and tenant + _extract_data(kwargs, "client_id"), # Per internal specs + _extract_data(kwargs, "scope"), # Per internal specs + _hash( + # The followings are all approximations of the "account" concept + # to support per-account throttling. + # TODO: We may want to disable it for confidential client, though + _extract_data(kwargs, "refresh_token", # "account" during refresh + _extract_data(kwargs, "code", # "account" of auth code grant + _extract_data(kwargs, "username")))), # "account" of ROPC + ), + expires_in=_parse_http_429_5xx_retry_after, + )(_post) + + _post = IndividualCache( # It covers the "UI required cache" + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format( + args[0], # It is the url, typically containing authority and tenant + _hash( + # Here we use literally all parameters, even those short-lived + # parameters containing timestamps (WS-Trust or POP assertion), + # because they will automatically be cleaned up by ExpiringMapping. + # + # Furthermore, there is no need to implement + # "interactive requests would reset the cache", + # because acquire_token_silent()'s would be automatically unblocked + # due to token cache layer operates on top of http cache layer. + # + # And, acquire_token_silent(..., force_refresh=True) will NOT + # bypass http cache, because there is no real gain from that. + # We won't bother implement it, nor do we want to encourage + # acquire_token_silent(..., force_refresh=True) pattern. + str(kwargs.get("params")) + str(kwargs.get("data"))), + ), + expires_in=lambda result=None, data=None, **ignored: + 60 + if result.status_code == 400 + # Here we choose to cache exact HTTP 400 errors only (rather than 4xx) + # because they are the ones defined in OAuth2 + # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2) + # Other 4xx errors might have different requirements e.g. + # "407 Proxy auth required" would need a key including http headers. + and not( # Exclude Device Flow cause its retry is expected and regulated + isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT + ) + and "retry-after" not in set( # Leave it to the Retry-After decorator + h.lower() for h in getattr(result, "headers", {}).keys()) + else 0, + )(_post) + + self.post = _post + + self.get = IndividualCache( # Typically those discovery GETs + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format( + args[0], # It is the url, sometimes containing inline params + _hash(kwargs.get("params", "")), + ), + expires_in=lambda result=None, **ignored: + 3600*24 if 200 <= result.status_code < 300 else 0, + )(http_client.get) + + self._http_client = http_client + + # The following 2 methods have been defined dynamically by __init__() + #def post(self, *args, **kwargs): pass + #def get(self, *args, **kwargs): pass + + def close(self): + """MSAL won't need this. But we allow throttled_http_client.close() anyway""" + return self._http_client.close() + diff --git a/msal/token_cache.py b/msal/token_cache.py index 5b31b299..2ed819d7 100644 --- a/msal/token_cache.py +++ b/msal/token_cache.py @@ -122,6 +122,19 @@ def wipe(dictionary, sensitive_fields): # Masks sensitive info default=str, # A workaround when assertion is in bytes in Python 3 )) + def __parse_account(self, response, id_token_claims): + """Return client_info and home_account_id""" + if "client_info" in response: # It happens when client_info and profile are in request + client_info = json.loads(decode_part(response["client_info"])) + if "uid" in client_info and "utid" in client_info: + return client_info, "{uid}.{utid}".format(**client_info) + # https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/387 + if id_token_claims: # This would be an end user on ADFS-direct scenario + sub = id_token_claims["sub"] # "sub" always exists, per OIDC specs + return {"uid": sub}, sub + # client_credentials flow will reach this code path + return {}, None + def __add(self, event, now=None): # event typically contains: client_id, scope, token_endpoint, # response, params, data, grant_type @@ -138,14 +151,7 @@ def __add(self, event, now=None): id_token_claims = ( decode_id_token(id_token, client_id=event["client_id"]) if id_token else {}) - client_info = {} - home_account_id = None # It would remain None in client_credentials flow - if "client_info" in response: # We asked for it, and AAD will provide it - client_info = json.loads(decode_part(response["client_info"])) - home_account_id = "{uid}.{utid}".format(**client_info) - elif id_token_claims: # This would be an end user on ADFS-direct scenario - client_info["uid"] = id_token_claims.get("sub") - home_account_id = id_token_claims.get("sub") + client_info, home_account_id = self.__parse_account(response, id_token_claims) target = ' '.join(event.get("scope") or []) # Per schema, we don't sort it diff --git a/requirements.txt b/requirements.txt index 9c558e35..d078afb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ . +python-dotenv diff --git a/tests/test_application.py b/tests/test_application.py index ea98b16f..5a92c8d4 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -5,7 +5,7 @@ import msal from msal.application import _merge_claims_challenge_and_capabilities from tests import unittest -from tests.test_token_cache import TokenCacheTestCase +from tests.test_token_cache import build_id_token, build_response from tests.http_client import MinimalHttpClient, MinimalResponse from msal.telemetry import CLIENT_CURRENT_TELEMETRY, CLIENT_LAST_TELEMETRY @@ -66,7 +66,7 @@ def setUp(self): "client_id": self.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token="an expired AT to trigger refresh", expires_in=-99, uid=self.uid, utid=self.utid, refresh_token=self.rt), }) # The add(...) helper populates correct home_account_id for future searching @@ -125,9 +125,9 @@ def setUp(self): "client_id": self.preexisting_family_app_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token="Siblings won't share AT. test_remove_account() will.", - id_token=TokenCacheTestCase.build_id_token(aud=self.preexisting_family_app_id), + id_token=build_id_token(aud=self.preexisting_family_app_id), uid=self.uid, utid=self.utid, refresh_token=self.frt, foci="1"), }) # The add(...) helper populates correct home_account_id for future searching @@ -153,8 +153,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): "client_id": app.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, refresh_token=rt), + "response": build_response(uid=self.uid, utid=self.utid, refresh_token=rt), }) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(url, data=None, **kwargs): @@ -168,7 +167,7 @@ def tester(url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") return MinimalResponse( - status_code=200, text=json.dumps(TokenCacheTestCase.build_response( + status_code=200, text=json.dumps(build_response( uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) @@ -246,7 +245,7 @@ def setUp(self): "scope": self.scopes, "token_endpoint": "https://{}/common/oauth2/v2.0/token".format( self.environment_in_cache), - "response": TokenCacheTestCase.build_response( + "response": build_response( uid=uid, utid=utid, access_token=self.access_token, refresh_token="some refresh token"), }) # The add(...) helper populates correct home_account_id for future searching @@ -342,7 +341,7 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): "client_id": self.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token=access_token, expires_in=expires_in, refresh_in=refresh_in, uid=self.uid, utid=self.utid, refresh_token=self.rt), @@ -424,7 +423,7 @@ def populate_cache(self, cache, access_token="at"): "client_id": self.client_id, "scope": self.scopes, "token_endpoint": "{}/oauth2/v2.0/token".format(self.authority_url), - "response": TokenCacheTestCase.build_response( + "response": build_response( access_token=access_token, uid=self.uid, utid=self.utid, refresh_token=self.rt), }) @@ -571,9 +570,9 @@ def test_get_accounts(self): "scope": scopes, "token_endpoint": "https://{}/{}/oauth2/v2.0/token".format(environment, tenant), - "response": TokenCacheTestCase.build_response( + "response": build_response( uid=uid, utid=utid, access_token="at", refresh_token="rt", - id_token=TokenCacheTestCase.build_id_token( + id_token=build_id_token( aud=client_id, sub="oid_in_" + tenant, preferred_username=username, diff --git a/tests/test_e2e.py b/tests/test_e2e.py index e5c0c129..20afaa0a 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -1,3 +1,16 @@ +"""If the following ENV VAR are available, many end-to-end test cases would run. +LAB_APP_CLIENT_SECRET=... +LAB_OBO_CLIENT_SECRET=... +LAB_APP_CLIENT_ID=... +LAB_OBO_PUBLIC_CLIENT_ID=... +LAB_OBO_CONFIDENTIAL_CLIENT_ID=... +""" +try: + from dotenv import load_dotenv # Use this only in local dev machine + load_dotenv() # take environment variables from .env. +except: + pass + import logging import os import json @@ -124,10 +137,15 @@ def assertCacheWorksForApp(self, result_from_wire, scope): def _test_username_password(self, authority=None, client_id=None, username=None, password=None, scope=None, client_secret=None, # Since MSAL 1.11, confidential client has ROPC too + azure_region=None, + http_client=None, **ignored): assert authority and client_id and username and password and scope self.app = msal.ClientApplication( - client_id, authority=authority, http_client=MinimalHttpClient(), + client_id, authority=authority, + http_client=http_client or MinimalHttpClient(), + azure_region=azure_region, # Regional endpoint does not support ROPC. + # Here we just use it to test a regional app won't break ROPC. client_credential=client_secret) result = self.app.acquire_token_by_username_password( username, password, scopes=scope) @@ -528,11 +546,16 @@ def _test_acquire_token_by_auth_code_flow( error_description=result.get("error_description"))) self.assertCacheWorksForUser(result, scope, username=None) - def _test_acquire_token_obo(self, config_pca, config_cca): + def _test_acquire_token_obo(self, config_pca, config_cca, + azure_region=None, # Regional endpoint does not really support OBO. + # Here we just test regional apps won't adversely break OBO + http_client=None, + ): # 1. An app obtains a token representing a user, for our mid-tier service pca = msal.PublicClientApplication( config_pca["client_id"], authority=config_pca["authority"], - http_client=MinimalHttpClient()) + azure_region=azure_region, + http_client=http_client or MinimalHttpClient()) pca_result = pca.acquire_token_by_username_password( config_pca["username"], config_pca["password"], @@ -547,7 +570,8 @@ def _test_acquire_token_obo(self, config_pca, config_cca): config_cca["client_id"], client_credential=config_cca["client_secret"], authority=config_cca["authority"], - http_client=MinimalHttpClient(), + azure_region=azure_region, + http_client=http_client or MinimalHttpClient(), # token_cache= ..., # Default token cache is all-tokens-store-in-memory. # That's fine if OBO app uses short-lived msal instance per session. # Otherwise, the OBO app need to implement a one-cache-per-user setup. @@ -690,14 +714,17 @@ def test_acquire_token_obo(self): self._test_acquire_token_obo(config_pca, config_cca) def test_acquire_token_by_client_secret(self): - # This is copied from ArlingtonCloudTestCase's same test case - try: - config = self.get_lab_user(usertype="cloud", publicClient="no") - except requests.exceptions.HTTPError: - self.skipTest("The lab does not provide confidential app for testing") - else: - config["client_secret"] = self.get_lab_user_secret("TBD") # TODO - self._test_acquire_token_by_client_secret(**config) + # Vastly different than ArlingtonCloudTestCase.test_acquire_token_by_client_secret() + _app = self.get_lab_app_object( + publicClient="no", signinAudience="AzureAdMyOrg") + self._test_acquire_token_by_client_secret( + client_id=_app["appId"], + client_secret=self.get_lab_user_secret( + _app["clientSecret"].split("/")[-1]), + authority="{}{}.onmicrosoft.com".format( + _app["authority"], _app["labName"].lower().rstrip(".com")), + scope=["https://graph.microsoft.com/.default"], + ) @unittest.skipUnless( os.getenv("LAB_OBO_CLIENT_SECRET"), @@ -762,6 +789,7 @@ def test_b2c_acquire_token_by_ropc(self): class WorldWideRegionalEndpointTestCase(LabBasedTestCase): region = "westus" + timeout = 2 # Short timeout makes this test case responsive on non-VM def test_acquire_token_for_client_should_hit_regional_endpoint(self): """This is the only grant supported by regional endpoint, for now""" @@ -782,7 +810,7 @@ def test_acquire_token_for_client_should_hit_regional_endpoint(self): status_code=400, text='{"error": "mock"}')) as mocked_method: self.app.acquire_token_for_client(scopes) mocked_method.assert_called_with( - 'https://westus.login.microsoft.com/{}/oauth2/v2.0/token'.format( + 'https://westus.r.login.microsoftonline.com/{}/oauth2/v2.0/token'.format( self.app.authority.tenant), params=ANY, data=ANY, headers=ANY) result = self.app.acquire_token_for_client( @@ -792,15 +820,6 @@ def test_acquire_token_for_client_should_hit_regional_endpoint(self): self.assertIn('access_token', result) self.assertCacheWorksForApp(result, scopes) - -class RegionalEndpointViaEnvVarTestCase(WorldWideRegionalEndpointTestCase): - - def setUp(self): - os.environ["REGION_NAME"] = "eastus" - - def tearDown(self): - del os.environ["REGION_NAME"] - @unittest.skipUnless( os.getenv("LAB_OBO_CLIENT_SECRET"), "Need LAB_OBO_CLIENT_SECRET from https://aka.ms/GetLabSecret?Secret=TodoListServiceV2-OBO") @@ -826,7 +845,11 @@ def test_cca_obo_should_bypass_regional_endpoint_therefore_still_work(self): config_pca["password"] = self.get_lab_user_secret(config_pca["lab_name"]) config_pca["scope"] = ["api://%s/read" % config_cca["client_id"]] - self._test_acquire_token_obo(config_pca, config_cca) + self._test_acquire_token_obo( + config_pca, config_cca, + azure_region=self.region, + http_client=MinimalHttpClient(timeout=self.timeout), + ) @unittest.skipUnless( os.getenv("LAB_OBO_CLIENT_SECRET"), @@ -843,7 +866,10 @@ def test_cca_ropc_should_bypass_regional_endpoint_therefore_still_work(self): config["client_id"] = os.getenv("LAB_OBO_CONFIDENTIAL_CLIENT_ID") config["scope"] = ["https://graph.microsoft.com/.default"] config["client_secret"] = os.getenv("LAB_OBO_CLIENT_SECRET") - self._test_username_password(**config) + self._test_username_password( + azure_region=self.region, + http_client=MinimalHttpClient(timeout=self.timeout), + **config) class ArlingtonCloudTestCase(LabBasedTestCase): diff --git a/tests/test_individual_cache.py b/tests/test_individual_cache.py new file mode 100644 index 00000000..38bd572d --- /dev/null +++ b/tests/test_individual_cache.py @@ -0,0 +1,93 @@ +from time import sleep +from random import random +import unittest +from msal.individual_cache import _ExpiringMapping as ExpiringMapping +from msal.individual_cache import _IndividualCache as IndividualCache + + +class TestExpiringMapping(unittest.TestCase): + def setUp(self): + self.mapping = {} + self.m = ExpiringMapping(mapping=self.mapping, capacity=2, expires_in=1) + + def test_should_disallow_accessing_reserved_keyword(self): + with self.assertRaises(ValueError): + self.m.get(ExpiringMapping._INDEX) + + def test_setitem(self): + self.assertEqual(0, len(self.m)) + self.m["thing one"] = "one" + self.assertIn(ExpiringMapping._INDEX, self.mapping, "Index created") + self.assertEqual(1, len(self.m), "It contains one item (excluding index)") + self.assertEqual("one", self.m["thing one"]) + self.assertEqual(["thing one"], list(self.m)) + + def test_set(self): + self.assertEqual(0, len(self.m)) + self.m.set("thing two", "two", 2) + self.assertIn(ExpiringMapping._INDEX, self.mapping, "Index created") + self.assertEqual(1, len(self.m), "It contains one item (excluding index)") + self.assertEqual("two", self.m["thing two"]) + self.assertEqual(["thing two"], list(self.m)) + + def test_len_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + self.assertEqual(0, len(self.m)) + + def test_iter_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + self.assertEqual([], list(self.m)) + + def test_get_should_purge(self): + self.m["thing one"] = "one" + sleep(1) + with self.assertRaises(KeyError): + self.m["thing one"] + + def test_various_expiring_time(self): + self.assertEqual(0, len(self.m)) + self.m["thing one"] = "one" + self.m.set("thing two", "two", 2) + self.assertEqual(2, len(self.m), "It contains 2 items") + sleep(1) + self.assertEqual(["thing two"], list(self.m), "One expires, another remains") + + def test_old_item_can_be_updated_with_new_expiry_time(self): + self.assertEqual(0, len(self.m)) + self.m["thing"] = "one" + self.m.set("thing", "two", 2) + self.assertEqual(1, len(self.m), "It contains 1 item") + self.assertEqual("two", self.m["thing"], 'Already been updated to "two"') + sleep(1) + self.assertEqual("two", self.m["thing"], "Not yet expires") + sleep(1) + self.assertEqual(0, len(self.m)) + + def test_oversized_input_should_purge_most_aging_item(self): + self.assertEqual(0, len(self.m)) + self.m["thing one"] = "one" + self.m.set("thing two", "two", 2) + self.assertEqual(2, len(self.m), "It contains 2 items") + self.m["thing three"] = "three" + self.assertEqual(2, len(self.m), "It contains 2 items") + self.assertNotIn("thing one", self.m) + + +class TestIndividualCache(unittest.TestCase): + mapping = {} + + @IndividualCache(mapping=mapping) + def foo(self, a, b, c=None, d=None): + return random() # So that we'd know whether a new response is received + + def test_memorize_a_function_call(self): + self.assertNotEqual(self.foo(1, 1), self.foo(2, 2)) + self.assertEqual( + self.foo(1, 2, c=3, d=4), + self.foo(1, 2, c=3, d=4), + "Subsequent run should obtain same result from cache") + # Note: In Python 3.7+, dict is ordered, so the following is typically True: + #self.assertNotEqual(self.foo(a=1, b=2), self.foo(b=2, a=1)) + diff --git a/tests/test_throttled_http_client.py b/tests/test_throttled_http_client.py new file mode 100644 index 00000000..75408330 --- /dev/null +++ b/tests/test_throttled_http_client.py @@ -0,0 +1,179 @@ +# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases +from time import sleep +from random import random +import logging +from msal.throttled_http_client import ThrottledHttpClient +from tests import unittest +from tests.http_client import MinimalResponse + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) + + +class DummyHttpResponse(MinimalResponse): + def __init__(self, headers=None, **kwargs): + self.headers = {} if headers is None else headers + super(DummyHttpResponse, self).__init__(**kwargs) + + +class DummyHttpClient(object): + def __init__(self, status_code=None, response_headers=None): + self._status_code = status_code + self._response_headers = response_headers + + def _build_dummy_response(self): + return DummyHttpResponse( + status_code=self._status_code, + headers=self._response_headers, + text=random(), # So that we'd know whether a new response is received + ) + + def post(self, url, params=None, data=None, headers=None, **kwargs): + return self._build_dummy_response() + + def get(self, url, params=None, headers=None, **kwargs): + return self._build_dummy_response() + + def close(self): + raise CloseMethodCalled("Not used by MSAL, but our customers may use it") + + +class CloseMethodCalled(Exception): + pass + + +class TestHttpDecoration(unittest.TestCase): + + def test_throttled_http_client_should_not_alter_original_http_client(self): + http_cache = {} + original_http_client = DummyHttpClient() + original_get = original_http_client.get + original_post = original_http_client.post + throttled_http_client = ThrottledHttpClient(original_http_client, http_cache) + goal = """The implementation should wrap original http_client + and keep it intact, instead of monkey-patching it""" + self.assertNotEqual(throttled_http_client, original_http_client, goal) + self.assertEqual(original_post, original_http_client.post) + self.assertEqual(original_get, original_http_client.get) + + def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + self, http_client, retry_after): + http_cache = {} + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com") # We implemented POST only + resp2 = http_client.post("https://example.com") # We implemented POST only + logger.debug(http_cache) + self.assertEqual(resp1.text, resp2.text, "Should return a cached response") + sleep(retry_after + 1) + resp3 = http_client.post("https://example.com") # We implemented POST only + self.assertNotEqual(resp1.text, resp3.text, "Should return a new response") + + def test_429_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self): + retry_after = 1 + self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + DummyHttpClient( + status_code=429, response_headers={"Retry-After": retry_after}), + retry_after) + + def test_5xx_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self): + retry_after = 1 + self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + DummyHttpClient( + status_code=503, response_headers={"Retry-After": retry_after}), + retry_after) + + def test_400_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self): + """Retry-After is supposed to only shown in http 429/5xx, + but we choose to support Retry-After for arbitrary http response.""" + retry_after = 1 + self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds( + DummyHttpClient( + status_code=400, response_headers={"Retry-After": retry_after}), + retry_after) + + def test_one_RetryAfter_request_should_block_a_similar_request(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=429, response_headers={"Retry-After": 2}) + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com", data={ + "scope": "one", "claims": "bar", "grant_type": "authorization_code"}) + resp2 = http_client.post("https://example.com", data={ + "scope": "one", "claims": "foo", "grant_type": "password"}) + logger.debug(http_cache) + self.assertEqual(resp1.text, resp2.text, "Should return a cached response") + + def test_one_RetryAfter_request_should_not_block_a_different_request(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=429, response_headers={"Retry-After": 2}) + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com", data={"scope": "one"}) + resp2 = http_client.post("https://example.com", data={"scope": "two"}) + logger.debug(http_cache) + self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") + + def test_one_invalid_grant_should_block_a_similar_request(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=400) # It covers invalid_grant and interaction_required + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.post("https://example.com", data={"claims": "foo"}) + logger.debug(http_cache) + resp1_again = http_client.post("https://example.com", data={"claims": "foo"}) + self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response") + resp2 = http_client.post("https://example.com", data={"claims": "bar"}) + self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") + resp2_again = http_client.post("https://example.com", data={"claims": "bar"}) + self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response") + + def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self): + """ + Need not test multiple FOCI app's acquire_token_silent() here. By design, + one FOCI app's successful populating token cache would result in another + FOCI app's acquire_token_silent() to hit a token without invoking http request. + """ + + def test_forcefresh_behavior(self): + """ + The implementation let token cache and http cache operate in different + layers. They do not couple with each other. + Therefore, acquire_token_silent(..., force_refresh=True) + would bypass the token cache yet technically still hit the http cache. + + But that is OK, cause the customer need no force_refresh in the first place. + After a successful AT/RT acquisition, AT/RT will be in the token cache, + and a normal acquire_token_silent(...) without force_refresh would just work. + This was discussed in https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview/pullrequest/3618?_a=files + """ + + def test_http_get_200_should_be_cached(self): + http_cache = {} + http_client = DummyHttpClient( + status_code=200) # It covers UserRealm discovery and OIDC discovery + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.get("https://example.com?foo=bar") + resp2 = http_client.get("https://example.com?foo=bar") + logger.debug(http_cache) + self.assertEqual(resp1.text, resp2.text, "Should return a cached response") + + def test_device_flow_retry_should_not_be_cached(self): + DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code" + http_cache = {} + http_client = DummyHttpClient(status_code=400) + http_client = ThrottledHttpClient(http_client, http_cache) + resp1 = http_client.get( + "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT}) + resp2 = http_client.get( + "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT}) + logger.debug(http_cache) + self.assertNotEqual(resp1.text, resp2.text, "Should return a new response") + + def test_throttled_http_client_should_provide_close(self): + http_cache = {} + http_client = DummyHttpClient(status_code=200) + http_client = ThrottledHttpClient(http_client, http_cache) + with self.assertRaises(CloseMethodCalled): + http_client.close() + diff --git a/tests/test_token_cache.py b/tests/test_token_cache.py index 3cce0c82..2fe486c2 100644 --- a/tests/test_token_cache.py +++ b/tests/test_token_cache.py @@ -11,52 +11,56 @@ logging.basicConfig(level=logging.DEBUG) -class TokenCacheTestCase(unittest.TestCase): +# NOTE: These helpers were once implemented as static methods in TokenCacheTestCase. +# That would cause other test files' "from ... import TokenCacheTestCase" +# to re-run all test cases in this file. +# Now we avoid that, by defining these helpers in module level. +def build_id_token( + iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None, + **claims): # AAD issues "preferred_username", ADFS issues "upn" + return "header.%s.signature" % base64.b64encode(json.dumps(dict({ + "iss": iss, + "sub": sub, + "aud": aud, + "exp": exp or (time.time() + 100), + "iat": iat or time.time(), + }, **claims)).encode()).decode('utf-8') + - @staticmethod - def build_id_token( - iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None, - **claims): # AAD issues "preferred_username", ADFS issues "upn" - return "header.%s.signature" % base64.b64encode(json.dumps(dict({ - "iss": iss, - "sub": sub, - "aud": aud, - "exp": exp or (time.time() + 100), - "iat": iat or time.time(), - }, **claims)).encode()).decode('utf-8') +def build_response( # simulate a response from AAD + uid=None, utid=None, # If present, they will form client_info + access_token=None, expires_in=3600, token_type="some type", + **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... + ): + response = {} + if uid and utid: # Mimic the AAD behavior for "client_info=1" request + response["client_info"] = base64.b64encode(json.dumps({ + "uid": uid, "utid": utid, + }).encode()).decode('utf-8') + if access_token: + response.update({ + "access_token": access_token, + "expires_in": expires_in, + "token_type": token_type, + }) + response.update(kwargs) # Pass-through key-value pairs as top-level fields + return response - @staticmethod - def build_response( # simulate a response from AAD - uid=None, utid=None, # If present, they will form client_info - access_token=None, expires_in=3600, token_type="some type", - **kwargs # Pass-through: refresh_token, foci, id_token, error, refresh_in, ... - ): - response = {} - if uid and utid: # Mimic the AAD behavior for "client_info=1" request - response["client_info"] = base64.b64encode(json.dumps({ - "uid": uid, "utid": utid, - }).encode()).decode('utf-8') - if access_token: - response.update({ - "access_token": access_token, - "expires_in": expires_in, - "token_type": token_type, - }) - response.update(kwargs) # Pass-through key-value pairs as top-level fields - return response + +class TokenCacheTestCase(unittest.TestCase): def setUp(self): self.cache = TokenCache() def testAddByAad(self): client_id = "my_client_id" - id_token = self.build_id_token( + id_token = build_id_token( oid="object1234", preferred_username="John Doe", aud=client_id) self.cache.add({ "client_id": client_id, "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": self.build_response( + "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), @@ -125,12 +129,12 @@ def testAddByAad(self): def testAddByAdfs(self): client_id = "my_client_id" - id_token = self.build_id_token(aud=client_id, upn="JaneDoe@example.com") + id_token = build_id_token(aud=client_id, upn="JaneDoe@example.com") self.cache.add({ "client_id": client_id, "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token", - "response": self.build_response( + "response": build_response( uid=None, utid=None, # ADFS will provide no client_info expires_in=3600, access_token="an access token", id_token=id_token, refresh_token="a refresh token"), @@ -204,7 +208,7 @@ def test_key_id_is_also_recorded(self): "client_id": "my_client_id", "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": self.build_response( + "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, access_token="an access token", refresh_token="a refresh token"), @@ -219,7 +223,7 @@ def test_refresh_in_should_be_recorded_as_refresh_on(self): # Sounds weird. Yep "client_id": "my_client_id", "scope": ["s2", "s1", "s3"], # Not in particular order "token_endpoint": "https://login.example.com/contoso/v2/token", - "response": self.build_response( + "response": build_response( uid="uid", utid="utid", # client_info expires_in=3600, refresh_in=1800, access_token="an access token", ), #refresh_token="a refresh token"),