From b520b3ea9474a4b380d70816ad3131daa567d79a Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 4 Aug 2021 11:00:16 -0700 Subject: [PATCH] Implementing CCS Routing info X-AnchorMailbox's value is case-insensitive Both auth code flow and interactive flow switch to client_info Add upn:username for ROPC per recent discussion --- msal/application.py | 43 ++++++++++++++++++++++---- tests/test_ccs.py | 73 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_e2e.py | 2 +- 3 files changed, 112 insertions(+), 6 deletions(-) create mode 100644 tests/test_ccs.py diff --git a/msal/application.py b/msal/application.py index c7a3471f..d7c4c147 100644 --- a/msal/application.py +++ b/msal/application.py @@ -14,6 +14,7 @@ import requests from .oauth2cli import Client, JwtAssertionCreator +from .oauth2cli.oidc import decode_part from .authority import Authority from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request @@ -111,6 +112,34 @@ def _preferred_browser(): return None +class _ClientWithCcsRoutingInfo(Client): + + def initiate_auth_code_flow(self, **kwargs): + return super(_ClientWithCcsRoutingInfo, self).initiate_auth_code_flow( + client_info=1, # To be used as CSS Routing info + **kwargs) + + def obtain_token_by_auth_code_flow( + self, auth_code_flow, auth_response, **kwargs): + # Note: the obtain_token_by_browser() is also covered by this + assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict) + headers = kwargs.pop("headers", {}) + client_info = json.loads( + decode_part(auth_response["client_info"]) + ) if auth_response.get("client_info") else {} + if "uid" in client_info and "utid" in client_info: + # Note: The value of X-AnchorMailbox is also case-insensitive + headers["X-AnchorMailbox"] = "Oid:{uid}@{utid}".format(**client_info) + return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_auth_code_flow( + auth_code_flow, auth_response, headers=headers, **kwargs) + + def obtain_token_by_username_password(self, username, password, **kwargs): + headers = kwargs.pop("headers", {}) + headers["X-AnchorMailbox"] = "upn:{}".format(username) + return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_username_password( + username, password, headers=headers, **kwargs) + + class ClientApplication(object): ACQUIRE_TOKEN_SILENT_ID = "84" @@ -481,7 +510,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False authority.device_authorization_endpoint or urljoin(authority.token_endpoint, "devicecode"), } - central_client = Client( + central_client = _ClientWithCcsRoutingInfo( central_configuration, self.client_id, http_client=self.http_client, @@ -506,7 +535,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False regional_authority.device_authorization_endpoint or urljoin(regional_authority.token_endpoint, "devicecode"), } - regional_client = Client( + regional_client = _ClientWithCcsRoutingInfo( regional_configuration, self.client_id, http_client=self.http_client, @@ -577,7 +606,7 @@ def initiate_auth_code_flow( 3. and then relay this dict and subsequent auth response to :func:`~acquire_token_by_auth_code_flow()`. """ - client = Client( + client = _ClientWithCcsRoutingInfo( {"authorization_endpoint": self.authority.authorization_endpoint}, self.client_id, http_client=self.http_client) @@ -654,7 +683,7 @@ def get_authorization_request_url( self.http_client ) if authority else self.authority - client = Client( + client = _ClientWithCcsRoutingInfo( {"authorization_endpoint": the_authority.authorization_endpoint}, self.client_id, http_client=self.http_client) @@ -1178,6 +1207,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token( key=lambda e: int(e.get("last_modification_time", "0")), reverse=True): logger.debug("Cache attempts an RT") + headers = telemetry_context.generate_headers() + if "home_account_id" in query: # Then use it as CCS Routing info + headers["X-AnchorMailbox"] = "Oid:{}".format( # case-insensitive value + query["home_account_id"].replace(".", "@")) response = client.obtain_token_by_refresh_token( entry, rt_getter=lambda token_item: token_item["secret"], on_removing_rt=lambda rt_item: None, # Disable RT removal, @@ -1189,7 +1222,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token( skip_account_creation=True, # To honor a concurrent remove_account() )), scope=scopes, - headers=telemetry_context.generate_headers(), + headers=headers, data=dict( kwargs.pop("data", {}), claims=_merge_claims_challenge_and_capabilities( diff --git a/tests/test_ccs.py b/tests/test_ccs.py new file mode 100644 index 00000000..8b801773 --- /dev/null +++ b/tests/test_ccs.py @@ -0,0 +1,73 @@ +import unittest +try: + from unittest.mock import patch, ANY +except: + from mock import patch, ANY + +from tests.http_client import MinimalResponse +from tests.test_token_cache import build_response + +import msal + + +class TestCcsRoutingInfoTestCase(unittest.TestCase): + + def test_acquire_token_by_auth_code_flow(self): + app = msal.ClientApplication("client_id") + state = "foo" + flow = app.initiate_auth_code_flow( + ["some", "scope"], login_hint="johndoe@contoso.com", state=state) + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + app.acquire_token_by_auth_code_flow(flow, { + "state": state, + "code": "bar", + "client_info": # MSAL asks for client_info, so it would be available + "eyJ1aWQiOiJhYTkwNTk0OS1hMmI4LTRlMGEtOGFlYS1iMzJlNTNjY2RiNDEiLCJ1dGlkIjoiNzJmOTg4YmYtODZmMS00MWFmLTkxYWItMmQ3Y2QwMTFkYjQ3In0", + }) + self.assertEqual( + "Oid:aa905949-a2b8-4e0a-8aea-b32e53ccdb41@72f988bf-86f1-41af-91ab-2d7cd011db47", + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from client_info") + + # I've manually tested acquire_token_interactive. No need to automate it, + # because it and acquire_token_by_auth_code_flow() share same code path. + + def test_acquire_token_silent(self): + uid = "foo" + utid = "bar" + client_id = "my_client_id" + scopes = ["some", "scope"] + authority_url = "https://login.microsoftonline.com/common" + token_cache = msal.TokenCache() + token_cache.add({ # Pre-populate the cache + "client_id": client_id, + "scope": scopes, + "token_endpoint": "{}/oauth2/v2.0/token".format(authority_url), + "response": build_response( + access_token="an expired AT to trigger refresh", expires_in=-99, + uid=uid, utid=utid, refresh_token="this is a RT"), + }) # The add(...) helper populates correct home_account_id for future searching + app = msal.ClientApplication( + client_id, authority=authority_url, token_cache=token_cache) + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + account = {"home_account_id": "{}.{}".format(uid, utid)} + app.acquire_token_silent(["scope"], account) + self.assertEqual( + "Oid:{}@{}".format( # Server accepts case-insensitive value + uid, utid), # It would look like "Oid:foo@bar" + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from home_account_id") + + def test_acquire_token_by_username_password(self): + app = msal.ClientApplication("client_id") + username = "johndoe@contoso.com" + with patch.object(app.http_client, "post", return_value=MinimalResponse( + status_code=400, text='{"error": "mock"}')) as mocked_method: + app.acquire_token_by_username_password(username, "password", ["scope"]) + self.assertEqual( + "upn:" + username, + mocked_method.call_args[1].get("headers", {}).get('X-AnchorMailbox'), + "CSS routing info should be derived from client_info") + diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 20afaa0a..2defecd6 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -516,8 +516,8 @@ def _test_acquire_token_by_auth_code_flow( client_id, authority=authority, http_client=MinimalHttpClient()) with AuthCodeReceiver(port=port) as receiver: flow = self.app.initiate_auth_code_flow( + scope, redirect_uri="http://localhost:%d" % receiver.get_port(), - scopes=scope, ) auth_response = receiver.get_auth_response( auth_uri=flow["auth_uri"], state=flow["state"], timeout=60,