diff --git a/msal/application.py b/msal/application.py
index 4f68fc20..d6fb131a 100644
--- a/msal/application.py
+++ b/msal/application.py
@@ -21,6 +21,7 @@
 from .token_cache import TokenCache
 import msal.telemetry
 from .region import _detect_region
+from .throttled_http_client import ThrottledHttpClient
 
 
 # The __init__.py will import this. Not the other way around.
@@ -336,6 +337,10 @@ def __init__(
             a = requests.adapters.HTTPAdapter(max_retries=1)
             self.http_client.mount("http://", a)
             self.http_client.mount("https://", a)
+        self.http_client = ThrottledHttpClient(
+            self.http_client,
+            {}  # Hard code an in-memory cache, for now
+            )
 
         self.app_name = app_name
         self.app_version = app_version
@@ -433,6 +438,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
             "x-client-sku": "MSAL.Python", "x-client-ver": __version__,
             "x-client-os": sys.platform,
             "x-client-cpu": "x64" if sys.maxsize > 2 ** 32 else "x86",
+            "x-ms-lib-capability": "retry-after, h429",
         }
         if self.app_name:
            default_headers['x-app-name'] = self.app_name
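
Note (illustration, not part of the patch): the wiring above means that, after __init__(), self.http_client is a ThrottledHttpClient, so every later POST/GET issued by the application flows through its caching decorators, and the hard-coded {} gives each application instance its own in-memory HTTP cache. A minimal sketch of the same wiring in isolation, assuming a requests-like client:

    import requests
    from msal.throttled_http_client import ThrottledHttpClient

    session = requests.Session()
    throttled = ThrottledHttpClient(session, {})  # {} is this instance's private http cache
    # throttled.post(...) and throttled.get(...) now go through the throttling/caching
    # decorators, while session.post and session.get themselves remain unpatched.
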
+ """ + super(_ExpiringMapping, self).__init__(*args, **kwargs) + self._mapping = mapping if mapping is not None else {} + self._capacity = capacity + self._expires_in = expires_in + self._lock = Lock() if lock is None else lock + + def _validate_key(self, key): + if key == self._INDEX: + raise ValueError("key {} is a reserved keyword in {}".format( + key, self.__class__.__name__)) + + def set(self, key, value, expires_in): + # This method's name was chosen so that it matches its cousin __setitem__(), + # and it also complements the counterpart get(). + # The downside is such a name shadows the built-in type set in this file, + # but you can overcome that by defining a global alias for set. + """It sets the key-value pair into this mapping, with its per-item expires_in. + + It will take O(logN) time, because it will run some maintenance. + This worse-than-constant time is acceptable, because in a cache scenario, + __setitem__() would only be called during a cache miss, + which would already incur an expensive target function call anyway. + + By the way, most other methods of this mapping still have O(1) constant time. + """ + with self._lock: + self._set(key, value, expires_in) + + def _set(self, key, value, expires_in): + # This internal implementation powers both set() and __setitem__(), + # so that they don't depend on each other. + self._validate_key(key) + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + now = int(time.time()) + expires_at = now + expires_in + entry = [expires_at, now, key] + is_new_item = key not in timestamps + is_beyond_capacity = self._capacity and len(timestamps) >= self._capacity + if is_new_item and is_beyond_capacity: + self._drop_indexed_entry(timestamps, heapq.heappushpop(sequence, entry)) + else: # Simply add new entry. The old one would become a harmless orphan. + heapq.heappush(sequence, entry) + timestamps[key] = [expires_at, now] # It overwrites existing key, if any + self._mapping[key] = value + self._mapping[self._INDEX] = sequence, timestamps + + def _maintenance(self, sequence, timestamps): # O(logN) + """It will modify input sequence and timestamps in-place""" + now = int(time.time()) + while sequence: # Clean up expired items + expires_at, created_at, key = sequence[0] + if created_at <= now < expires_at: # Then all remaining items are fresh + break + self._drop_indexed_entry(timestamps, sequence[0]) # It could error out + heapq.heappop(sequence) # Only pop it after a successful _drop_indexed_entry() + while self._capacity is not None and len(timestamps) > self._capacity: + self._drop_indexed_entry(timestamps, sequence[0]) # It could error out + heapq.heappop(sequence) # Only pop it after a successful _drop_indexed_entry() + + def _drop_indexed_entry(self, timestamps, entry): + """For an entry came from index, drop it from timestamps and self._mapping""" + expires_at, created_at, key = entry + if [expires_at, created_at] == timestamps.get(key): # So it is not an orphan + self._mapping.pop(key, None) # It could raise exception + timestamps.pop(key, None) # This would probably always succeed + + def __setitem__(self, key, value): + """Implements the __setitem__(). + + Same characteristic as :func:`~set()`, + but use class-wide expires_in which was specified by :func:`~__init__()`. 
+ """ + if self._expires_in is None: + raise ValueError("Need a numeric value for expires_in during __init__()") + with self._lock: + self._set(key, value, self._expires_in) + + def __getitem__(self, key): # O(1) + """If the item you requested already expires, KeyError will be raised.""" + self._validate_key(key) + with self._lock: + # Skip self._maintenance(), because it would need O(logN) time + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + expires_at, created_at = timestamps[key] # Would raise KeyError accordingly + now = int(time.time()) + if not created_at <= now < expires_at: + self._mapping.pop(key, None) + timestamps.pop(key, None) + self._mapping[self._INDEX] = sequence, timestamps + raise KeyError("{} {}".format( + key, + "expired" if now >= expires_at else "created in the future?", + )) + return self._mapping[key] # O(1) + + def __delitem__(self, key): # O(1) + """If the item you requested already expires, KeyError will be raised.""" + self._validate_key(key) + with self._lock: + # Skip self._maintenance(), because it would need O(logN) time + self._mapping.pop(key, None) # O(1) + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + del timestamps[key] # O(1) + self._mapping[self._INDEX] = sequence, timestamps + + def __len__(self): # O(logN) + """Drop all expired items and return the remaining length""" + with self._lock: + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + self._mapping[self._INDEX] = sequence, timestamps + return len(timestamps) # Faster than iter(self._mapping) when it is on disk + + def __iter__(self): + """Drop all expired items and return an iterator of the remaining items""" + with self._lock: + sequence, timestamps = self._mapping.get(self._INDEX, ([], {})) + self._maintenance(sequence, timestamps) # O(logN) + self._mapping[self._INDEX] = sequence, timestamps + return iter(timestamps) # Faster than iter(self._mapping) when it is on disk + + +class _IndividualCache(object): + # The code structure below can decorate both function and method. + # It is inspired by https://stackoverflow.com/a/9417088 + # We may potentially switch to build upon + # https://github.com/micheles/decorator/blob/master/docs/documentation.md#statement-of-the-problem + def __init__(self, mapping=None, key_maker=None, expires_in=None): + """Constructs a cache decorator that allows item-by-item control on + how to cache the return value of the decorated function. + + :param MutableMapping mapping: + The cached items will be stored inside. + You'd want to use a ExpiringMapping + if you plan to utilize the ``expires_in`` behavior. + + If nothing is provided, an in-memory dict will be used, + but it will provide no expiry functionality. + + .. note:: + + When using this class as a decorator, + your mapping needs to be available at "compile" time, + so it would typically be a global-, module- or class-level mapping:: + + module_mapping = {} + + @IndividualCache(mapping=module_mapping, ...) + def foo(): + ... + + If you want to use a mapping available only at run-time, + you have to manually decorate your function at run-time, too:: + + def foo(): + ... + + def bar(runtime_mapping): + foo = IndividualCache(mapping=runtime_mapping...)(foo) + + :param callable key_maker: + A callable which should have signature as + ``lambda function, args, kwargs: "return a string as key"``. 
+
+
+class _IndividualCache(object):
+    # The code structure below can decorate both functions and methods.
+    # It is inspired by https://stackoverflow.com/a/9417088
+    # We may potentially switch to build upon
+    # https://github.com/micheles/decorator/blob/master/docs/documentation.md#statement-of-the-problem
+    def __init__(self, mapping=None, key_maker=None, expires_in=None):
+        """Constructs a cache decorator that allows item-by-item control on
+        how to cache the return value of the decorated function.
+
+        :param MutableMapping mapping:
+            The cached items will be stored inside.
+            You'd want to use an ExpiringMapping
+            if you plan to utilize the ``expires_in`` behavior.
+
+            If nothing is provided, an in-memory dict will be used,
+            but it will provide no expiry functionality.
+
+            .. note::
+
+                When using this class as a decorator,
+                your mapping needs to be available at "compile" time,
+                so it would typically be a global-, module- or class-level mapping::
+
+                    module_mapping = {}
+
+                    @IndividualCache(mapping=module_mapping, ...)
+                    def foo():
+                        ...
+
+                If you want to use a mapping available only at run-time,
+                you have to manually decorate your function at run-time, too::
+
+                    def foo():
+                        ...
+
+                    def bar(runtime_mapping):
+                        foo = IndividualCache(mapping=runtime_mapping...)(foo)
+
+        :param callable key_maker:
+            A callable which should have a signature like
+            ``lambda function, args, kwargs: "return a string as key"``.
+
+            If key_maker happens to return ``None``, the cache will be bypassed,
+            the underlying function will be invoked directly,
+            and the invocation result will not be cached either.
+
+        :param callable expires_in:
+            The default value is ``None``,
+            which means the content being cached has no per-item expiry,
+            and will be subject to the underlying mapping's global expiry time.
+
+            It can be an integer indicating
+            how many seconds the result will be cached.
+            In particular, if the value is 0,
+            it means the result expires after zero seconds (i.e. immediately),
+            therefore the result will *not* be cached.
+            (Mind the difference between ``expires_in=0`` and ``expires_in=None``.)
+
+            Or it can be a callable with a signature like
+            ``lambda function=function, args=args, kwargs=kwargs, result=result: 123``
+            to calculate the expiry on the fly.
+            Its return value will be interpreted in the same way as above.
+        """
+        self._mapping = mapping if mapping is not None else {}
+        self._key_maker = key_maker or (lambda function, args, kwargs: (
+            function,  # This default implementation uses the function as part of the key,
+                       # so that the cache is partitioned by function.
+                       # However, you could have many functions use the same namespace,
+                       # so different decorators could share the same cache.
+            args,
+            tuple(kwargs.items()),  # raw kwargs is not hashable
+            ))
+        self._expires_in = expires_in
+
+    def __call__(self, function):
+
+        @wraps(function)
+        def wrapper(*args, **kwargs):
+            key = self._key_maker(function, args, kwargs)
+            if key is None:  # Then bypass the cache
+                return function(*args, **kwargs)
+
+            now = int(time.time())
+            try:
+                return self._mapping[key]
+            except KeyError:
+                # We choose to NOT call function(...) in this block, otherwise
+                # a potential exception from function(...) would become a confusing
+                # "During handling of the above exception, another exception occurred"
+                pass
+            value = function(*args, **kwargs)
+
+            expires_in = self._expires_in(
+                function=function,
+                args=args,
+                kwargs=kwargs,
+                result=value,
+                ) if callable(self._expires_in) else self._expires_in
+            if expires_in == 0:
+                return value
+            if expires_in is None:
+                self._mapping[key] = value
+            else:
+                self._mapping.set(key, value, expires_in)
+            return value
+
+        return wrapper
+
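
Note (illustration, not part of the patch): the two primitives above are designed to compose — the decorator computes a cache key and a per-result lifetime, then stores the result into the expiring mapping. A minimal sketch with made-up function, key format, and TTL values:

    from msal.individual_cache import _ExpiringMapping as ExpiringMapping
    from msal.individual_cache import _IndividualCache as IndividualCache

    throttling_cache = ExpiringMapping(mapping={}, capacity=128)

    @IndividualCache(
        mapping=throttling_cache,
        key_maker=lambda function, args, kwargs: "add {}".format(args),
        expires_in=lambda result=None, **ignored: 5 if result % 2 else 0,  # Cache odd sums for 5 seconds
        )
    def add(a, b):
        return a + b

    assert add(1, 2) == 3  # Computed, then cached for 5 seconds under key "add (1, 2)"
    assert add(2, 2) == 4  # Computed but not cached, because expires_in evaluated to 0
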
+ """ + response = result + lowercase_headers = {k.lower(): v for k, v in getattr( + # Historically, MSAL's HttpResponse does not always have headers + response, "headers", {}).items()} + if not (response.status_code == 429 or response.status_code >= 500 + or "retry-after" in lowercase_headers): + return 0 # Quick exit + default = 60 # Recommended at the end of + # https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview + retry_after = int(lowercase_headers.get("retry-after", default)) + try: + # AAD's retry_after uses integer format only + # https://stackoverflow.microsoft.com/questions/264931/264932 + delay_seconds = int(retry_after) + except ValueError: + delay_seconds = default + return min(3600, delay_seconds) + + +def _extract_data(kwargs, key, default=None): + data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string + return data.get(key) if isinstance(data, dict) else default + + +class ThrottledHttpClient(object): + def __init__(self, http_client, http_cache): + """Throttle the given http_client by storing and retrieving data from cache. + + This wrapper exists so that our patching post() and get() would prevent + re-patching side effect when/if same http_client being reused. + """ + expiring_mapping = ExpiringMapping( # It will automatically clean up + mapping=http_cache if http_cache is not None else {}, + capacity=1024, # To prevent cache blowing up especially for CCA + lock=Lock(), # TODO: This should ideally also allow customization + ) + + _post = http_client.post # We'll patch _post, and keep original post() intact + + _post = IndividualCache( + # Internal specs requires throttling on at least token endpoint, + # here we have a generic patch for POST on all endpoints. + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: + "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format( + args[0], # It is the url, typically containing authority and tenant + _extract_data(kwargs, "client_id"), # Per internal specs + _extract_data(kwargs, "scope"), # Per internal specs + _hash( + # The followings are all approximations of the "account" concept + # to support per-account throttling. + # TODO: We may want to disable it for confidential client, though + _extract_data(kwargs, "refresh_token", # "account" during refresh + _extract_data(kwargs, "code", # "account" of auth code grant + _extract_data(kwargs, "username")))), # "account" of ROPC + ), + expires_in=_parse_http_429_5xx_retry_after, + )(_post) + + _post = IndividualCache( # It covers the "UI required cache" + mapping=expiring_mapping, + key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format( + args[0], # It is the url, typically containing authority and tenant + _hash( + # Here we use literally all parameters, even those short-lived + # parameters containing timestamps (WS-Trust or POP assertion), + # because they will automatically be cleaned up by ExpiringMapping. + # + # Furthermore, there is no need to implement + # "interactive requests would reset the cache", + # because acquire_token_silent()'s would be automatically unblocked + # due to token cache layer operates on top of http cache layer. + # + # And, acquire_token_silent(..., force_refresh=True) will NOT + # bypass http cache, because there is no real gain from that. + # We won't bother implement it, nor do we want to encourage + # acquire_token_silent(..., force_refresh=True) pattern. 
+
+        _post = IndividualCache(  # It covers the "UI required cache"
+            mapping=expiring_mapping,
+            key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
+                args[0],  # It is the url, typically containing authority and tenant
+                _hash(
+                    # Here we use literally all parameters, even those short-lived
+                    # parameters containing timestamps (WS-Trust or POP assertion),
+                    # because they will automatically be cleaned up by ExpiringMapping.
+                    #
+                    # Furthermore, there is no need to implement
+                    # "interactive requests would reset the cache",
+                    # because acquire_token_silent() calls would be automatically unblocked,
+                    # given that the token cache layer operates on top of the http cache layer.
+                    #
+                    # And, acquire_token_silent(..., force_refresh=True) will NOT
+                    # bypass the http cache, because there is no real gain from that.
+                    # We won't bother implementing it, nor do we want to encourage the
+                    # acquire_token_silent(..., force_refresh=True) pattern.
+                    str(kwargs.get("params")) + str(kwargs.get("data"))),
+                ),
+            expires_in=lambda result=None, kwargs=None, **ignored:
+                60
+                if result.status_code == 400
+                # Here we choose to cache exact HTTP 400 errors only (rather than all 4xx)
+                # because they are the ones defined in OAuth2
+                # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2).
+                # Other 4xx errors might have different requirements, e.g.
+                # "407 Proxy auth required" would need a key including http headers.
+                and not (  # Exclude Device Flow, because its retry is expected and regulated
+                    isinstance(kwargs.get("data"), dict)
+                    and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT
+                    )
+                and "retry-after" not in set(  # Leave it to the Retry-After decorator
+                    h.lower() for h in getattr(result, "headers", {}).keys())
+                else 0,
+            )(_post)
+
+        self.post = _post
+
+        self.get = IndividualCache(  # Typically those discovery GETs
+            mapping=expiring_mapping,
+            key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
+                args[0],  # It is the url, sometimes containing inline params
+                _hash(kwargs.get("params", "")),
+                ),
+            expires_in=lambda result=None, **ignored:
+                3600*24 if 200 <= result.status_code < 300 else 0,
+            )(http_client.get)
+
+    # The following 2 methods have been defined dynamically by __init__()
+    #def post(self, *args, **kwargs): pass
+    #def get(self, *args, **kwargs): pass
+
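
Note (illustration, not part of the patch): the throttling window returned by _parse_http_429_5xx_retry_after() is derived purely from the response; the Response stub below is made up for demonstration.

    from collections import namedtuple
    from msal.throttled_http_client import _parse_http_429_5xx_retry_after

    Response = namedtuple("Response", ["status_code", "headers"])

    assert _parse_http_429_5xx_retry_after(result=Response(200, {})) == 0    # Not throttled
    assert _parse_http_429_5xx_retry_after(result=Response(429, {})) == 60   # Default window
    assert _parse_http_429_5xx_retry_after(result=Response(503, {"Retry-After": "2"})) == 2
    assert _parse_http_429_5xx_retry_after(result=Response(429, {"Retry-After": "86400"})) == 3600  # Capped at 1 hour
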
diff --git a/tests/test_individual_cache.py b/tests/test_individual_cache.py
new file mode 100644
index 00000000..38bd572d
--- /dev/null
+++ b/tests/test_individual_cache.py
@@ -0,0 +1,93 @@
+from time import sleep
+from random import random
+import unittest
+from msal.individual_cache import _ExpiringMapping as ExpiringMapping
+from msal.individual_cache import _IndividualCache as IndividualCache
+
+
+class TestExpiringMapping(unittest.TestCase):
+    def setUp(self):
+        self.mapping = {}
+        self.m = ExpiringMapping(mapping=self.mapping, capacity=2, expires_in=1)
+
+    def test_should_disallow_accessing_reserved_keyword(self):
+        with self.assertRaises(ValueError):
+            self.m.get(ExpiringMapping._INDEX)
+
+    def test_setitem(self):
+        self.assertEqual(0, len(self.m))
+        self.m["thing one"] = "one"
+        self.assertIn(ExpiringMapping._INDEX, self.mapping, "Index created")
+        self.assertEqual(1, len(self.m), "It contains one item (excluding index)")
+        self.assertEqual("one", self.m["thing one"])
+        self.assertEqual(["thing one"], list(self.m))
+
+    def test_set(self):
+        self.assertEqual(0, len(self.m))
+        self.m.set("thing two", "two", 2)
+        self.assertIn(ExpiringMapping._INDEX, self.mapping, "Index created")
+        self.assertEqual(1, len(self.m), "It contains one item (excluding index)")
+        self.assertEqual("two", self.m["thing two"])
+        self.assertEqual(["thing two"], list(self.m))
+
+    def test_len_should_purge(self):
+        self.m["thing one"] = "one"
+        sleep(1)
+        self.assertEqual(0, len(self.m))
+
+    def test_iter_should_purge(self):
+        self.m["thing one"] = "one"
+        sleep(1)
+        self.assertEqual([], list(self.m))
+
+    def test_get_should_purge(self):
+        self.m["thing one"] = "one"
+        sleep(1)
+        with self.assertRaises(KeyError):
+            self.m["thing one"]
+
+    def test_various_expiring_time(self):
+        self.assertEqual(0, len(self.m))
+        self.m["thing one"] = "one"
+        self.m.set("thing two", "two", 2)
+        self.assertEqual(2, len(self.m), "It contains 2 items")
+        sleep(1)
+        self.assertEqual(["thing two"], list(self.m), "One expires, another remains")
+
+    def test_old_item_can_be_updated_with_new_expiry_time(self):
+        self.assertEqual(0, len(self.m))
+        self.m["thing"] = "one"
+        self.m.set("thing", "two", 2)
+        self.assertEqual(1, len(self.m), "It contains 1 item")
+        self.assertEqual("two", self.m["thing"], 'Already been updated to "two"')
+        sleep(1)
+        self.assertEqual("two", self.m["thing"], "Not yet expired")
+        sleep(1)
+        self.assertEqual(0, len(self.m))
+
+    def test_oversized_input_should_purge_most_aging_item(self):
+        self.assertEqual(0, len(self.m))
+        self.m["thing one"] = "one"
+        self.m.set("thing two", "two", 2)
+        self.assertEqual(2, len(self.m), "It contains 2 items")
+        self.m["thing three"] = "three"
+        self.assertEqual(2, len(self.m), "It contains 2 items")
+        self.assertNotIn("thing one", self.m)
+
+
+class TestIndividualCache(unittest.TestCase):
+    mapping = {}
+
+    @IndividualCache(mapping=mapping)
+    def foo(self, a, b, c=None, d=None):
+        return random()  # So that we'd know whether a new result is computed
+
+    def test_memorize_a_function_call(self):
+        self.assertNotEqual(self.foo(1, 1), self.foo(2, 2))
+        self.assertEqual(
+            self.foo(1, 2, c=3, d=4),
+            self.foo(1, 2, c=3, d=4),
+            "Subsequent runs should obtain the same result from the cache")
+        # Note: In Python 3.7+, dict is ordered, so the following is typically True:
+        #self.assertNotEqual(self.foo(a=1, b=2), self.foo(b=2, a=1))
+
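
Note (illustration, not part of the patch): the _ExpiringMapping exercised above can also sit on top of a file-based dict-like object, as its __init__ docstring suggests; a sketch assuming Python's shelve module (the file name is made up):

    import shelve
    from msal.individual_cache import _ExpiringMapping as ExpiringMapping

    with shelve.open("throttle_cache_demo") as persisted_dict:  # Creates a local file
        mapping = ExpiringMapping(mapping=persisted_dict, expires_in=300)
        mapping["user_realm https://example.com"] = {"account_type": "Managed"}
        assert len(mapping) == 1  # Stays 1 until the 300-second shelf life elapses
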
diff --git a/tests/test_throttled_http_client.py b/tests/test_throttled_http_client.py
new file mode 100644
index 00000000..9a65efc1
--- /dev/null
+++ b/tests/test_throttled_http_client.py
@@ -0,0 +1,165 @@
+# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
+from time import sleep
+from random import random
+import logging
+from msal.throttled_http_client import ThrottledHttpClient
+from tests import unittest
+from tests.http_client import MinimalResponse
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.DEBUG)
+
+
+class DummyHttpResponse(MinimalResponse):
+    def __init__(self, headers=None, **kwargs):
+        self.headers = {} if headers is None else headers
+        super(DummyHttpResponse, self).__init__(**kwargs)
+
+
+class DummyHttpClient(object):
+    def __init__(self, status_code=None, response_headers=None):
+        self._status_code = status_code
+        self._response_headers = response_headers
+
+    def _build_dummy_response(self):
+        return DummyHttpResponse(
+            status_code=self._status_code,
+            headers=self._response_headers,
+            text=random(),  # So that we'd know whether a new response is received
+            )
+
+    def post(self, url, params=None, data=None, headers=None, **kwargs):
+        return self._build_dummy_response()
+
+    def get(self, url, params=None, headers=None, **kwargs):
+        return self._build_dummy_response()
+
+
+class TestHttpDecoration(unittest.TestCase):
+
+    def test_throttled_http_client_should_not_alter_original_http_client(self):
+        http_cache = {}
+        original_http_client = DummyHttpClient()
+        original_get = original_http_client.get
+        original_post = original_http_client.post
+        throttled_http_client = ThrottledHttpClient(original_http_client, http_cache)
+        goal = """The implementation should wrap the original http_client
+            and keep it intact, instead of monkey-patching it"""
+        self.assertNotEqual(throttled_http_client, original_http_client, goal)
+        self.assertEqual(original_post, original_http_client.post)
+        self.assertEqual(original_get, original_http_client.get)
+
+    def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
+            self, http_client, retry_after):
+        http_cache = {}
+        http_client = ThrottledHttpClient(http_client, http_cache)
+        resp1 = http_client.post("https://example.com")  # We implemented POST only
+        resp2 = http_client.post("https://example.com")  # We implemented POST only
+        logger.debug(http_cache)
+        self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
+        sleep(retry_after + 1)
+        resp3 = http_client.post("https://example.com")  # We implemented POST only
+        self.assertNotEqual(resp1.text, resp3.text, "Should return a new response")
+
+    def test_429_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
+        retry_after = 1
+        self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
+            DummyHttpClient(
+                status_code=429, response_headers={"Retry-After": retry_after}),
+            retry_after)
+
+    def test_5xx_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
+        retry_after = 1
+        self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
+            DummyHttpClient(
+                status_code=503, response_headers={"Retry-After": retry_after}),
+            retry_after)
+
+    def test_400_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
+        """Retry-After is only supposed to show up in http 429/5xx responses,
+        but we choose to honor Retry-After for arbitrary http responses."""
+        retry_after = 1
+        self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
+            DummyHttpClient(
+                status_code=400, response_headers={"Retry-After": retry_after}),
+            retry_after)
+
+    def test_one_RetryAfter_request_should_block_a_similar_request(self):
+        http_cache = {}
+        http_client = DummyHttpClient(
+            status_code=429, response_headers={"Retry-After": 2})
+        http_client = ThrottledHttpClient(http_client, http_cache)
+        resp1 = http_client.post("https://example.com", data={
+            "scope": "one", "claims": "bar", "grant_type": "authorization_code"})
+        resp2 = http_client.post("https://example.com", data={
+            "scope": "one", "claims": "foo", "grant_type": "password"})
+        logger.debug(http_cache)
+        self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
+
+    def test_one_RetryAfter_request_should_not_block_a_different_request(self):
+        http_cache = {}
+        http_client = DummyHttpClient(
+            status_code=429, response_headers={"Retry-After": 2})
+        http_client = ThrottledHttpClient(http_client, http_cache)
+        resp1 = http_client.post("https://example.com", data={"scope": "one"})
+        resp2 = http_client.post("https://example.com", data={"scope": "two"})
+        logger.debug(http_cache)
+        self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
+
+    def test_one_invalid_grant_should_block_a_similar_request(self):
+        http_cache = {}
+        http_client = DummyHttpClient(
+            status_code=400)  # It covers invalid_grant and interaction_required
+        http_client = ThrottledHttpClient(http_client, http_cache)
+        resp1 = http_client.post("https://example.com", data={"claims": "foo"})
+        logger.debug(http_cache)
+        resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
+        self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
+        resp2 = http_client.post("https://example.com", data={"claims": "bar"})
+        self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
+        resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
+        self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")
+
+    def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):
+        """
+        There is no need to test multiple FOCI apps' acquire_token_silent() here.
+        By design, one FOCI app successfully populating the token cache would allow
+        another FOCI app's acquire_token_silent() to hit a token
+        without issuing an http request.
+        """
+
+    def test_forcefresh_behavior(self):
+        """
+        The implementation lets the token cache and the http cache operate in different
+        layers. They do not couple with each other.
+        Therefore, acquire_token_silent(..., force_refresh=True)
+        would bypass the token cache yet technically still hit the http cache.
+
+        But that is OK, because the customer needs no force_refresh in the first place.
+        After a successful AT/RT acquisition, the AT/RT will be in the token cache,
+        and a normal acquire_token_silent(...) without force_refresh would just work.
+        This was discussed in https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview/pullrequest/3618?_a=files
+        """
+
+    def test_http_get_200_should_be_cached(self):
+        http_cache = {}
+        http_client = DummyHttpClient(
+            status_code=200)  # It covers UserRealm discovery and OIDC discovery
+        http_client = ThrottledHttpClient(http_client, http_cache)
+        resp1 = http_client.get("https://example.com?foo=bar")
+        resp2 = http_client.get("https://example.com?foo=bar")
+        logger.debug(http_cache)
+        self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
+
+    def test_device_flow_retry_should_not_be_cached(self):
+        DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
+        http_cache = {}
+        http_client = DummyHttpClient(status_code=400)
+        http_client = ThrottledHttpClient(http_client, http_cache)
+        resp1 = http_client.post(  # Device Flow polls the token endpoint via POST
+            "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
+        resp2 = http_client.post(
+            "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
+        logger.debug(http_cache)
+        self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
+
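
Note (illustration, not part of the patch): a self-contained rerun of the "GET 200 should be cached" behavior outside the test suite, using a made-up stub client in place of tests.http_client.MinimalResponse:

    from random import random
    from msal.throttled_http_client import ThrottledHttpClient

    class StubResponse(object):
        def __init__(self):
            self.status_code = 200
            self.headers = {}
            self.text = str(random())  # Changes on every real (non-cached) request

    class StubClient(object):
        def post(self, url, params=None, data=None, headers=None, **kwargs):
            return StubResponse()
        def get(self, url, params=None, headers=None, **kwargs):
            return StubResponse()

    client = ThrottledHttpClient(StubClient(), {})
    first = client.get("https://example.com/openid-configuration")
    second = client.get("https://example.com/openid-configuration")
    assert first.text == second.text  # The second call was served from the http cache
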