From cbe98f327af3147a903c7ccfea440b2d3a5af983 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Tue, 25 Feb 2020 05:12:02 -0800 Subject: [PATCH 01/33] Initial commit --- msal/application.py | 18 ++++-- msal/authority.py | 76 ++++++++++++++++-------- msal/mex.py | 11 +++- msal/oauth2cli/http/__init__.py | 2 + msal/oauth2cli/http/http_client.py | 86 ++++++++++++++++++++++++++++ msal/oauth2cli/http/http_response.py | 16 ++++++ msal/oauth2cli/oauth2.py | 35 +++++++---- msal/wstrust_request.py | 13 ++++- 8 files changed, 209 insertions(+), 48 deletions(-) create mode 100644 msal/oauth2cli/http/__init__.py create mode 100644 msal/oauth2cli/http/http_client.py create mode 100644 msal/oauth2cli/http/http_response.py diff --git a/msal/application.py b/msal/application.py index 01f25461..cbb34fc3 100644 --- a/msal/application.py +++ b/msal/application.py @@ -16,6 +16,7 @@ from .wstrust_request import send_request as wst_send_request from .wstrust_response import * from .token_cache import TokenCache +from msal.oauth2cli.http import DefaultHttpClient # The __init__.py will import this. Not the other way around. @@ -92,7 +93,7 @@ def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, token_cache=None, - verify=True, proxies=None, timeout=None, + verify=True, proxies=None, http_client=None, timeout=None, client_claims=None, app_name=None, app_version=None): """Create an instance of application. @@ -163,12 +164,14 @@ def __init__( self.client_claims = client_claims self.verify = verify self.proxies = proxies + from msal.oauth2cli.http import DefaultHttpClient + self.http_client = http_client if http_client else DefaultHttpClient(self.verify, self.proxies) self.timeout = timeout self.app_name = app_name self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - validate_authority, verify=verify, proxies=proxies, timeout=timeout) + validate_authority, verify=verify, proxies=proxies, timeout=timeout, http_client = self.http_client) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) @@ -218,6 +221,7 @@ def _build_client(self, client_credential, authority): on_obtaining_tokens=self.token_cache.add, on_removing_rt=self.token_cache.remove_rt, on_updating_rt=self.token_cache.update_rt, + http_client=self.http_client, verify=self.verify, proxies=self.proxies, timeout=self.timeout) def get_authorization_request_url( @@ -367,13 +371,17 @@ def _find_msal_accounts(self, environment): def _get_authority_aliases(self, instance): if not self.authority_groups: - resp = requests.get( - "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + resp = self.http_client.request("GET", "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, verify=self.verify, proxies=self.proxies, timeout=self.timeout) + + # resp = requests.get( + # "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + # headers={'Accept': 'application/json'}, + # verify=self.verify, proxies=self.proxies, timeout=self.timeout) resp.raise_for_status() self.authority_groups = [ - set(group['aliases']) for group in resp.json()['metadata']] + set(group['aliases']) for group in resp.content.json()['metadata']] for group in self.authority_groups: if instance in group: return [alias for alias in group if alias != instance] diff --git a/msal/authority.py b/msal/authority.py index d8221eca..4bf5741f 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -34,7 +34,7 @@ class Authority(object): _domains_without_user_realm_discovery = set([]) def __init__(self, authority_url, validate_authority=True, - verify=True, proxies=None, timeout=None, + verify=True, proxies=None, timeout=None, http_client=None ): """Creates an authority instance, and also validates it. @@ -47,6 +47,7 @@ def __init__(self, authority_url, validate_authority=True, self.verify = verify self.proxies = proxies self.timeout = timeout + self.http_client = http_client authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( @@ -87,18 +88,56 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): # "federation_protocol", "cloud_audience_urn", # "federation_metadata_url", "federation_active_auth_url", etc. if self.instance not in self.__class__._domains_without_user_realm_discovery: - resp = response or requests.get( - "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( - netloc=self.instance, username=username), - headers={'Accept':'application/json', - 'client-request-id': correlation_id}, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + resp = response or self.http_client.request("GET", + "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( + netloc=self.instance, username=username), headers={'Accept':'application/json', + 'client-request-id': correlation_id}, timeout= self.timeout) if resp.status_code != 404: resp.raise_for_status() - return resp.json() + return resp.content.json() + # resp = response or requests.get( + # "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( + # netloc=self.instance, username=username), + # headers={'Accept':'application/json', + # 'client-request-id': correlation_id}, + # verify=self.verify, proxies=self.proxies, timeout=self.timeout) + # if resp.status_code != 404: + # resp.raise_for_status() + # return resp.json() self.__class__._domains_without_user_realm_discovery.add(self.instance) return {} # This can guide the caller to fall back normal ROPC flow + def instance_discovery(self, url, **kwargs): + resp = self.http_client.request("GET", 'https://{}/common/discovery/instance'.format( + WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too + # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 + # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 + ), params={'authorization_endpoint': url, 'api-version': '1.0'}, + **kwargs) + + return resp.content.json() + # return requests.get( # Note: This URL seemingly returns V1 endpoint only + # 'https://{}/common/discovery/instance'.format( + # WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too + # # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 + # # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 + # ), + # params={'authorization_endpoint': url, 'api-version': '1.0'}, + # **kwargs).json() + + def tenant_discovery(self, tenant_discovery_endpoint, **kwargs): + # Returns Openid Configuration + resp = self.http_client.request("GET", tenant_discovery_endpoint, + **kwargs) + payload = resp.content.json() + + # resp = requests.get(tenant_discovery_endpoint, **kwargs) + # payload = resp.json() + if 'authorization_endpoint' in payload and 'token_endpoint' in payload: + return payload + raise MsalServiceError(status_code=resp.status_code, **payload) + + def canonicalize(authority_url): # Returns (url_parsed_result, hostname_in_lowercase, tenant) @@ -113,21 +152,8 @@ def canonicalize(authority_url): % authority_url) return authority, authority.hostname, parts[1] -def instance_discovery(url, **kwargs): - return requests.get( # Note: This URL seemingly returns V1 endpoint only - 'https://{}/common/discovery/instance'.format( - WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too - # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 - # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 - ), - params={'authorization_endpoint': url, 'api-version': '1.0'}, - **kwargs).json() - -def tenant_discovery(tenant_discovery_endpoint, **kwargs): - # Returns Openid Configuration - resp = requests.get(tenant_discovery_endpoint, **kwargs) - payload = resp.json() - if 'authorization_endpoint' in payload and 'token_endpoint' in payload: - return payload - raise MsalServiceError(status_code=resp.status_code, **payload) + + + + diff --git a/msal/mex.py b/msal/mex.py index caf5e3ed..c10f5617 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -34,6 +34,7 @@ except ImportError: from xml.etree import ElementTree as ET +from .oauth2cli.http import DefaultHttpClient import requests @@ -42,9 +43,13 @@ def _xpath_of_root(route_to_leaf): return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1)) def send_request(mex_endpoint, **kwargs): - mex_document = requests.get( - mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, - **kwargs).text + http_client = DefaultHttpClient() + resp = http_client.request("GET", mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, + **kwargs) + mex_document = resp.content.text + # mex_document = requests.get( + # mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, + # **kwargs).text return Mex(mex_document).get_wstrust_username_password_endpoint() diff --git a/msal/oauth2cli/http/__init__.py b/msal/oauth2cli/http/__init__.py new file mode 100644 index 00000000..8624042a --- /dev/null +++ b/msal/oauth2cli/http/__init__.py @@ -0,0 +1,2 @@ +from .http_client import DefaultHttpClient +from .http_response import Response \ No newline at end of file diff --git a/msal/oauth2cli/http/http_client.py b/msal/oauth2cli/http/http_client.py new file mode 100644 index 00000000..fc0f52d4 --- /dev/null +++ b/msal/oauth2cli/http/http_client.py @@ -0,0 +1,86 @@ +from requests import Request, Session +from .http_response import Response + +import requests + +class HttpClient(object): + """ + An abstract class representing an HTTP client. + """ + def request(self, method, url, params=None, data=None, headers=None, auth=None, + timeout=None, allow_redirects=False): + """ + Make an HTTP request. + """ + +class DefaultHttpClient(HttpClient): + """ + General purpose HTTP Client for interacting with the Twilio API + """ + def __init__(self, default_headers=None, verify=None, proxy=None): + """ + Constructor for the TwilioHttpClient + + :param bool pool_connections + :param request_hooks + :param int timeout: Timeout for the requests. + Timeout should never be zero (0) or less. + :param logger + :param dict proxy: Http proxy for the requests session + """ + self.session = Session() + self.session.headers.update(default_headers or {}) + self.session.verify = verify + self.session.proxy = proxy + + + def request(self, method, url, params=None, data=None, headers=None, auth=None, timeout=None, + allow_redirects=False, **kwargs): + """ + Make an HTTP Request with parameters provided. + + :param str method: The HTTP method to use + :param str url: The URL to request + :param dict params: Query parameters to append to the URL + :param dict data: Parameters to go in the body of the HTTP request + :param dict headers: HTTP Headers to send with the request + :param tuple auth: Basic Auth arguments + :param float timeout: Socket/Read timeout for the request + :param boolean allow_redirects: Whether or not to allow redirects + See the requests documentation for explanation of all these parameters + + :return: An http response + :rtype: A :class:`Response ` object + """ + # if timeout is not None and timeout <= 0: + # raise ValueError(timeout) + # + # kwargs = { + # 'method': method.upper(), + # 'url': url, + # 'params': params, + # 'data': data, + # 'headers': headers, + # 'auth': auth, + # 'hooks': self.request_hooks + # } + # + # if params: + # self.logger.info('{method} Request: {url}?{query}'.format(query=urlencode(params), **kwargs)) + # self.logger.info('PARAMS: {params}'.format(**kwargs)) + # else: + # self.logger.info('{method} Request: {url}'.format(**kwargs)) + # if data: + # self.logger.info('PAYLOAD: {data}'.format(**kwargs)) + # + # self.last_response = None + session = self.session or Session() + if method == "POST": + response = session.post(headers=headers, params=params, data=data, auth=auth, + timeout=timeout, **kwargs) + elif method == "GET": + response = requests.get(url= url, headers=headers, params=params, timeout=timeout) + + self.last_response = Response(int(response.status_code), response.text) + + return self.last_response \ No newline at end of file diff --git a/msal/oauth2cli/http/http_response.py b/msal/oauth2cli/http/http_response.py new file mode 100644 index 00000000..12867416 --- /dev/null +++ b/msal/oauth2cli/http/http_response.py @@ -0,0 +1,16 @@ +class Response(object): + """ + + """ + def __init__(self, status_code, text): + self.content = text + self.cached = False + self.status_code = status_code + self.ok = self.status_code < 400 + + @property + def text(self): + return self.content + + def __repr__(self): + return 'HTTP {} {}'.format(self.status_code, self.content) \ No newline at end of file diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 9a947390..188d3728 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -12,7 +12,8 @@ import base64 import sys -import requests +from .http import DefaultHttpClient +from .http import Response string_types = (str,) if sys.version_info[0] >= 3 else (basestring, ) @@ -40,6 +41,7 @@ def __init__( client_assertion_type=None, # type: Optional[str] default_headers=None, # type: Optional[dict] default_body=None, # type: Optional[dict] + http_client=None, verify=True, # type: Union[str, True, False, None] proxies=None, # type: Optional[dict] timeout=None, # type: Union[tuple, float, None] @@ -85,10 +87,10 @@ def __init__( if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) - self.session = s = requests.Session() - s.headers.update(default_headers or {}) - s.verify = verify - s.proxies = proxies or {} + if not http_client: + self.http_client = DefaultHttpClient(verify, proxies or {}) + else: + self.http_client = http_client self.timeout = timeout def _build_auth_request_params(self, response_type, **kwargs): @@ -148,18 +150,22 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} _headers.update(headers or {}) - resp = (post or self.session.post)( - self.configuration["token_endpoint"], + resp = (post or self.http_client.request)("POST", self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, **kwargs) + # resp = (post or self.session.post)( + # self.configuration["token_endpoint"], + # headers=_headers, params=params, data=_data, auth=auth, + # timeout=timeout or self.timeout, + # **kwargs) if resp.status_code >= 500: resp.raise_for_status() # TODO: Will probably retry here try: # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says # even an error response will be a valid json structure, # so we simply return it here, without needing to invent an exception. - return resp.json() + return resp.content.json() except ValueError: self.logger.exception( "Token response is not in json format: %s", resp.text) @@ -218,10 +224,15 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): DAE = "device_authorization_endpoint" if not self.configuration.get(DAE): raise ValueError("You need to provide device authorization endpoint") - flow = self.session.post(self.configuration[DAE], - data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - timeout=timeout or self.timeout, - **kwargs).json() + resp = self.http_client.request("POST", self.configuration[DAE], + data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, + timeout=timeout or self.timeout, + **kwargs) + flow = resp.content.json() + # flow = self.session.post(self.configuration[DAE], + # data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, + # timeout=timeout or self.timeout, + # **kwargs).json() flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) flow["expires_at"] = time.time() + flow["expires_in"] # We invent this diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 84c03848..5b7f95f3 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -33,6 +33,7 @@ from .mex import Mex from .wstrust_response import parse_response +from .oauth2cli.http import DefaultHttpClient logger = logging.getLogger(__name__) @@ -51,15 +52,21 @@ def send_request( "Unsupported soap action: %s" % soap_action) data = _build_rst( username, password, cloud_audience_urn, endpoint_address, soap_action) - resp = requests.post(endpoint_address, data=data, headers={ + http_client = DefaultHttpClient() + resp = http_client.request("POST", endpoint_address, + headers={ 'Content-type':'application/soap+xml; charset=utf-8', 'SOAPAction': soap_action, }, **kwargs) + # resp = requests.post(endpoint_address, data=data, headers={ + # 'Content-type':'application/soap+xml; charset=utf-8', + # 'SOAPAction': soap_action, + # }, **kwargs) if resp.status_code >= 400: - logger.debug("Unsuccessful WsTrust request receives: %s", resp.text) + logger.debug("Unsuccessful WsTrust request receives: %s", resp.content.text) # It turns out ADFS uses 5xx status code even with client-side incorrect password error # resp.raise_for_status() - return parse_response(resp.text) + return parse_response(resp.content.text) def escape_password(password): return (password.replace('&', '&').replace('"', '"') From e64332fca25bd16f39c65700700f134fb09ff7e9 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Wed, 4 Mar 2020 23:41:44 -0800 Subject: [PATCH 02/33] Iteration 1 --- msal/application.py | 11 ++------ msal/authority.py | 15 ++-------- msal/oauth2cli/http/http_client.py | 41 ++++++---------------------- msal/oauth2cli/http/http_response.py | 3 -- msal/oauth2cli/oauth2.py | 4 +-- 5 files changed, 14 insertions(+), 60 deletions(-) diff --git a/msal/application.py b/msal/application.py index cbb34fc3..3beede0e 100644 --- a/msal/application.py +++ b/msal/application.py @@ -164,14 +164,13 @@ def __init__( self.client_claims = client_claims self.verify = verify self.proxies = proxies - from msal.oauth2cli.http import DefaultHttpClient - self.http_client = http_client if http_client else DefaultHttpClient(self.verify, self.proxies) + self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxy=self.proxies) self.timeout = timeout self.app_name = app_name self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - validate_authority, verify=verify, proxies=proxies, timeout=timeout, http_client = self.http_client) + validate_authority, verify=verify, proxies=proxies, timeout=timeout, http_client=self.http_client) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) @@ -374,12 +373,6 @@ def _get_authority_aliases(self, instance): resp = self.http_client.request("GET", "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, verify=self.verify, proxies=self.proxies, timeout=self.timeout) - - # resp = requests.get( - # "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", - # headers={'Accept': 'application/json'}, - # verify=self.verify, proxies=self.proxies, timeout=self.timeout) - resp.raise_for_status() self.authority_groups = [ set(group['aliases']) for group in resp.content.json()['metadata']] for group in self.authority_groups: diff --git a/msal/authority.py b/msal/authority.py index 4bf5741f..48b16397 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -4,7 +4,6 @@ from urlparse import urlparse import logging -import requests from .exceptions import MsalServiceError @@ -54,7 +53,7 @@ def __init__(self, authority_url, validate_authority=True, len(parts) == 3 and parts[2].lower().startswith("b2c_")) if (tenant != "adfs" and (not is_b2c) and validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS): - payload = instance_discovery( + payload = self.instance_discovery( "https://{}{}/oauth2/v2.0/authorize".format( self.instance, authority.path), verify=verify, proxies=proxies, timeout=timeout) @@ -74,7 +73,7 @@ def __init__(self, authority_url, validate_authority=True, authority.path, # In B2C scenario, it is "/tenant/policy" "" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint )) - openid_config = tenant_discovery( + openid_config = self.tenant_discovery( tenant_discovery_endpoint, verify=verify, proxies=proxies, timeout=timeout) logger.debug("openid_config = %s", openid_config) @@ -93,17 +92,7 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): netloc=self.instance, username=username), headers={'Accept':'application/json', 'client-request-id': correlation_id}, timeout= self.timeout) if resp.status_code != 404: - resp.raise_for_status() return resp.content.json() - # resp = response or requests.get( - # "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( - # netloc=self.instance, username=username), - # headers={'Accept':'application/json', - # 'client-request-id': correlation_id}, - # verify=self.verify, proxies=self.proxies, timeout=self.timeout) - # if resp.status_code != 404: - # resp.raise_for_status() - # return resp.json() self.__class__._domains_without_user_realm_discovery.add(self.instance) return {} # This can guide the caller to fall back normal ROPC flow diff --git a/msal/oauth2cli/http/http_client.py b/msal/oauth2cli/http/http_client.py index fc0f52d4..accceb3f 100644 --- a/msal/oauth2cli/http/http_client.py +++ b/msal/oauth2cli/http/http_client.py @@ -17,7 +17,7 @@ class DefaultHttpClient(HttpClient): """ General purpose HTTP Client for interacting with the Twilio API """ - def __init__(self, default_headers=None, verify=None, proxy=None): + def __init__(self, default_headers={}, verify=None, proxy=None): """ Constructor for the TwilioHttpClient @@ -33,7 +33,6 @@ def __init__(self, default_headers=None, verify=None, proxy=None): self.session.verify = verify self.session.proxy = proxy - def request(self, method, url, params=None, data=None, headers=None, auth=None, timeout=None, allow_redirects=False, **kwargs): """ @@ -50,37 +49,13 @@ def request(self, method, url, params=None, data=None, headers=None, auth=None, See the requests documentation for explanation of all these parameters :return: An http response - :rtype: A :class:`Response ` object + :rtype: A :class:`Response ` object """ - # if timeout is not None and timeout <= 0: - # raise ValueError(timeout) - # - # kwargs = { - # 'method': method.upper(), - # 'url': url, - # 'params': params, - # 'data': data, - # 'headers': headers, - # 'auth': auth, - # 'hooks': self.request_hooks - # } - # - # if params: - # self.logger.info('{method} Request: {url}?{query}'.format(query=urlencode(params), **kwargs)) - # self.logger.info('PARAMS: {params}'.format(**kwargs)) - # else: - # self.logger.info('{method} Request: {url}'.format(**kwargs)) - # if data: - # self.logger.info('PAYLOAD: {data}'.format(**kwargs)) - # - # self.last_response = None - session = self.session or Session() if method == "POST": - response = session.post(headers=headers, params=params, data=data, auth=auth, - timeout=timeout, **kwargs) + response = self.session.post(url=url, headers=headers, params=params, data=data, auth=auth, + timeout=timeout, **kwargs) elif method == "GET": - response = requests.get(url= url, headers=headers, params=params, timeout=timeout) - - self.last_response = Response(int(response.status_code), response.text) - - return self.last_response \ No newline at end of file + response = self.session.get(url=url, headers=headers, params=params, timeout=timeout, data=data, auth=auth) + response.raise_for_status() + response = Response(int(response.status_code), response) + return response diff --git a/msal/oauth2cli/http/http_response.py b/msal/oauth2cli/http/http_response.py index 12867416..41770475 100644 --- a/msal/oauth2cli/http/http_response.py +++ b/msal/oauth2cli/http/http_response.py @@ -11,6 +11,3 @@ def __init__(self, status_code, text): @property def text(self): return self.content - - def __repr__(self): - return 'HTTP {} {}'.format(self.status_code, self.content) \ No newline at end of file diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 188d3728..e3f6bfa5 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -88,7 +88,7 @@ def __init__( self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) if not http_client: - self.http_client = DefaultHttpClient(verify, proxies or {}) + self.http_client = DefaultHttpClient(verify=verify, proxy=proxies or {}, default_headers= {}) else: self.http_client = http_client self.timeout = timeout @@ -150,7 +150,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} _headers.update(headers or {}) - resp = (post or self.http_client.request)("POST", self.configuration["token_endpoint"], + resp = self.http_client.request("POST", self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, **kwargs) From aed0e8d9dd53e24903fcdc4ecca08592984771ca Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 5 Mar 2020 17:28:41 -0800 Subject: [PATCH 03/33] Iteration 2 rectifying tests --- msal/application.py | 4 +- msal/authority.py | 9 ++-- msal/oauth2cli/http/http_client.py | 17 ++++++- msal/oauth2cli/http/http_request.py | 76 +++++++++++++++++++++++++++++ msal/oauth2cli/oauth2.py | 19 +------- tests/test_application.py | 42 ++++++++-------- 6 files changed, 122 insertions(+), 45 deletions(-) create mode 100644 msal/oauth2cli/http/http_request.py diff --git a/msal/application.py b/msal/application.py index 24641e42..ff0d41c0 100644 --- a/msal/application.py +++ b/msal/application.py @@ -374,7 +374,7 @@ def _get_authority_aliases(self, instance): headers={'Accept': 'application/json'}, verify=self.verify, proxies=self.proxies, timeout=self.timeout) self.authority_groups = [ - set(group['aliases']) for group in resp.content.json()['metadata']] + set(group['aliases']) for group in resp.content['metadata']] for group in self.authority_groups: if instance in group: return [alias for alias in group if alias != instance] @@ -506,7 +506,7 @@ def acquire_token_silent_with_error( the_authority = Authority( "https://" + alias + "/" + self.authority.tenant, validate_authority=False, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + verify=self.verify, proxies=self.proxies, timeout=self.timeout, http_client=self.http_client) result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, the_authority, force_refresh=force_refresh, correlation_id=correlation_id, diff --git a/msal/authority.py b/msal/authority.py index 48b16397..eb7de754 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -1,10 +1,11 @@ +from msal.oauth2cli.http import DefaultHttpClient + try: from urllib.parse import urlparse except ImportError: # Fall back to Python 2 from urlparse import urlparse import logging - from .exceptions import MsalServiceError @@ -46,7 +47,7 @@ def __init__(self, authority_url, validate_authority=True, self.verify = verify self.proxies = proxies self.timeout = timeout - self.http_client = http_client + self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxy=self.proxies) authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( @@ -92,7 +93,7 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): netloc=self.instance, username=username), headers={'Accept':'application/json', 'client-request-id': correlation_id}, timeout= self.timeout) if resp.status_code != 404: - return resp.content.json() + return resp.content self.__class__._domains_without_user_realm_discovery.add(self.instance) return {} # This can guide the caller to fall back normal ROPC flow @@ -118,7 +119,7 @@ def tenant_discovery(self, tenant_discovery_endpoint, **kwargs): # Returns Openid Configuration resp = self.http_client.request("GET", tenant_discovery_endpoint, **kwargs) - payload = resp.content.json() + payload = resp.content # resp = requests.get(tenant_discovery_endpoint, **kwargs) # payload = resp.json() diff --git a/msal/oauth2cli/http/http_client.py b/msal/oauth2cli/http/http_client.py index accceb3f..bff90625 100644 --- a/msal/oauth2cli/http/http_client.py +++ b/msal/oauth2cli/http/http_client.py @@ -51,11 +51,24 @@ def request(self, method, url, params=None, data=None, headers=None, auth=None, :return: An http response :rtype: A :class:`Response ` object """ + content = None if method == "POST": response = self.session.post(url=url, headers=headers, params=params, data=data, auth=auth, timeout=timeout, **kwargs) + if response.status_code >=500: + response.raise_for_status() + try: + # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says + # even an error response will be a valid json structure, + # so we simply return it here, without needing to invent an exception. + content = response.json() + except ValueError: + self.logger.exception( + "Token response is not in json format: %s", response.text) + raise elif method == "GET": response = self.session.get(url=url, headers=headers, params=params, timeout=timeout, data=data, auth=auth) - response.raise_for_status() - response = Response(int(response.status_code), response) + response.raise_for_status() + content = response.json() + response = Response(int(response.status_code), content) return response diff --git a/msal/oauth2cli/http/http_request.py b/msal/oauth2cli/http/http_request.py new file mode 100644 index 00000000..600de4ce --- /dev/null +++ b/msal/oauth2cli/http/http_request.py @@ -0,0 +1,76 @@ + + +class Request(object): + """ + An HTTP request. + """ + ANY = '*' + + def __init__(self, + method=ANY, + url=ANY, + auth=ANY, + params=ANY, + data=ANY, + headers=ANY, + **kwargs): + self.method = method.upper() + self.url = url + self.auth = auth + self.params = params + self.data = data + self.headers = headers + + @classmethod + def attribute_equal(cls, lhs, rhs): + if lhs == cls.ANY or rhs == cls.ANY: + # ANY matches everything + return True + + lhs = lhs or None + rhs = rhs or None + + return lhs == rhs + + def __eq__(self, other): + if not isinstance(other, Request): + return False + + return self.attribute_equal(self.method, other.method) and \ + self.attribute_equal(self.url, other.url) and \ + self.attribute_equal(self.auth, other.auth) and \ + self.attribute_equal(self.params, other.params) and \ + self.attribute_equal(self.data, other.data) and \ + self.attribute_equal(self.headers, other.headers) + + def __str__(self): + auth = '' + if self.auth and self.auth != self.ANY: + auth = '{} '.format(self.auth) + + params = '' + if self.params and self.params != self.ANY: + params = '?{}'.format(urlencode(self.params, doseq=True)) + + data = '' + if self.data and self.data != self.ANY: + if self.method == 'GET': + data = '\n -G' + data += '\n{}'.format('\n'.join(' -d "{}={}"'.format(k, v) for k, v in self.data.items())) + + headers = '' + if self.headers and self.headers != self.ANY: + headers = '\n{}'.format('\n'.join(' -H "{}: {}"'.format(k, v) + for k, v in self.headers.items())) + + return '{auth}{method} {url}{params}{data}{headers}'.format( + auth=auth, + method=self.method, + url=self.url, + params=params, + data=data, + headers=headers, + ) + + def __repr__(self): + return str(self) \ No newline at end of file diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index e3f6bfa5..f54a1cc7 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -150,26 +150,11 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} _headers.update(headers or {}) - resp = self.http_client.request("POST", self.configuration["token_endpoint"], + resp = (post or self.http_client.request)("POST", self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, **kwargs) - # resp = (post or self.session.post)( - # self.configuration["token_endpoint"], - # headers=_headers, params=params, data=_data, auth=auth, - # timeout=timeout or self.timeout, - # **kwargs) - if resp.status_code >= 500: - resp.raise_for_status() # TODO: Will probably retry here - try: - # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says - # even an error response will be a valid json structure, - # so we simply return it here, without needing to invent an exception. - return resp.content.json() - except ValueError: - self.logger.exception( - "Token response is not in json format: %s", resp.text) - raise + return resp.content def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs): # type: (str, Union[str, list, set, tuple]) -> dict diff --git a/tests/test_application.py b/tests/test_application.py index 4d7c2881..452acda8 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -4,6 +4,8 @@ import json import logging +from msal.oauth2cli.http import Response + try: from unittest.mock import * # Python 3 except: @@ -49,8 +51,9 @@ def test_extract_multiple_tag_enclosed_certs(self): class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): def setUp(self): + self.http_client = DefaultHttpClient() self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url) + self.authority = msal.authority.Authority(self.authority_url, http_client= self.http_client) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -67,7 +70,7 @@ def setUp(self): uid=self.uid, utid=self.utid, refresh_token=self.rt), }) # The add(...) helper populates correct home_account_id for future searching self.app = ClientApplication( - self.client_id, authority=self.authority_url, token_cache=self.cache) + self.client_id, authority=self.authority_url, token_cache=self.cache, http_client=self.http_client) def test_cache_empty_will_be_returned_as_None(self): self.assertEqual( @@ -77,30 +80,30 @@ def test_cache_empty_will_be_returned_as_None(self): def test_acquire_token_silent_will_suppress_error(self): error_response = {"error": "invalid_grant", "suberror": "xyz"} - def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + def tester(method, url, **kwargs): + return Response(400, error_response) self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): error_response = {"error": "invalid_grant", "error_description": "xyz"} - def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + def tester(method, url, **kwargs): + return Response(400, error_response) self.assertEqual(error_response, self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): error_response = {"error": "invalid_grant", "suberror": "basic_action"} - def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + def tester(method, url, **kwargs): + return Response(400, error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("basic_action", result.get("classification")) def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): error_response = {"error": "invalid_grant", "suberror": "client_mismatch"} - def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + def tester(method, url, **kwargs): + return Response(400, error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("", result.get("classification")) @@ -131,11 +134,11 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): app = ClientApplication( "unknown_orphan", authority=self.authority_url, token_cache=self.cache) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) - def tester(url, data=None, **kwargs): + def tester(method, url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Mock(status_code=400, json=Mock(return_value={ + return Response(400, { "error": "invalid_grant", - "error_description": "Was issued to another client"})) + "error_description": "Was issued to another client"}) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) self.assertNotEqual([], app.token_cache.find( @@ -154,20 +157,19 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): uid=self.uid, utid=self.utid, refresh_token=rt), }) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) - def tester(url, data=None, **kwargs): + def tester(method, url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") - return Mock(status_code=200, json=Mock(return_value={})) + return Response(200, {}) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) def test_unknown_family_app_will_attempt_frt_and_join_family(self): - def tester(url, data=None, **kwargs): + def tester(method, url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Mock( - status_code=200, - json=Mock(return_value=TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, foci="1", access_token="at"))) + return Response( + 200, TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, foci="1", access_token="at")) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( From 30f7d99bf06c9c040a0e85933d9e21de5d060f59 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 5 Mar 2020 18:42:09 -0800 Subject: [PATCH 04/33] Iteration 3 modifying some more failing tests --- msal/application.py | 2 +- msal/authority.py | 6 +++--- msal/oauth2cli/http/http_client.py | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/msal/application.py b/msal/application.py index ff0d41c0..6bece5e8 100644 --- a/msal/application.py +++ b/msal/application.py @@ -735,7 +735,7 @@ def acquire_token_by_username_password( CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID), } - if not self.authority.is_adfs: + if not self.authority.is_adfs and not self.authority.is_b2c: user_realm_result = self.authority.user_realm_discovery( username, correlation_id=headers[CLIENT_REQUEST_ID]) if user_realm_result.get("account_type") == "Federated": diff --git a/msal/authority.py b/msal/authority.py index eb7de754..694a1a05 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -50,9 +50,9 @@ def __init__(self, authority_url, validate_authority=True, self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxy=self.proxies) authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') - is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( + self.is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( len(parts) == 3 and parts[2].lower().startswith("b2c_")) - if (tenant != "adfs" and (not is_b2c) and validate_authority + if (tenant != "adfs" and (not self.is_b2c) and validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS): payload = self.instance_discovery( "https://{}{}/oauth2/v2.0/authorize".format( @@ -105,7 +105,7 @@ def instance_discovery(self, url, **kwargs): ), params={'authorization_endpoint': url, 'api-version': '1.0'}, **kwargs) - return resp.content.json() + return resp.content # return requests.get( # Note: This URL seemingly returns V1 endpoint only # 'https://{}/common/discovery/instance'.format( # WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too diff --git a/msal/oauth2cli/http/http_client.py b/msal/oauth2cli/http/http_client.py index bff90625..c935cf09 100644 --- a/msal/oauth2cli/http/http_client.py +++ b/msal/oauth2cli/http/http_client.py @@ -68,7 +68,6 @@ def request(self, method, url, params=None, data=None, headers=None, auth=None, raise elif method == "GET": response = self.session.get(url=url, headers=headers, params=params, timeout=timeout, data=data, auth=auth) - response.raise_for_status() content = response.json() response = Response(int(response.status_code), content) return response From 34e8e16f7e997c50335331b7fe1fcbc6a4e6d8f1 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 6 Mar 2020 17:12:28 -0800 Subject: [PATCH 05/33] Iteration 4 --- msal/application.py | 14 +++-- msal/authority.py | 31 +++------- msal/{oauth2cli => }/http/__init__.py | 2 +- msal/{oauth2cli => }/http/http_client.py | 76 +++++++++++------------- msal/http/http_response.py | 9 +++ msal/mex.py | 5 +- msal/oauth2cli/http/http_request.py | 76 ------------------------ msal/oauth2cli/http/http_response.py | 13 ---- msal/oauth2cli/oauth2.py | 23 ++++--- msal/wstrust_request.py | 12 +--- tests/test_application.py | 26 ++++---- 11 files changed, 87 insertions(+), 200 deletions(-) rename msal/{oauth2cli => }/http/__init__.py (54%) rename msal/{oauth2cli => }/http/http_client.py (53%) create mode 100644 msal/http/http_response.py delete mode 100644 msal/oauth2cli/http/http_request.py delete mode 100644 msal/oauth2cli/http/http_response.py diff --git a/msal/application.py b/msal/application.py index 6bece5e8..c6d25ef1 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,3 +1,4 @@ +import json import time try: # Python 2 from urlparse import urljoin @@ -8,15 +9,13 @@ import warnings import uuid -import requests - from .oauth2cli import Client, JwtAssertionCreator from .authority import Authority from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * from .token_cache import TokenCache -from msal.oauth2cli.http import DefaultHttpClient +from msal.http import DefaultHttpClient # The __init__.py will import this. Not the other way around. @@ -164,7 +163,7 @@ def __init__( self.client_claims = client_claims self.verify = verify self.proxies = proxies - self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxy=self.proxies) + self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxies=self.proxies) self.timeout = timeout self.app_name = app_name self.app_version = app_version @@ -268,7 +267,7 @@ def get_authorization_request_url( # Multi-tenant app can use new authority on demand the_authority = Authority( authority, - verify=self.verify, proxies=self.proxies, timeout=self.timeout, + verify=self.verify, proxies=self.proxies, timeout=self.timeout, http_client=self.http_client ) if authority else self.authority client = Client( @@ -373,8 +372,11 @@ def _get_authority_aliases(self, instance): resp = self.http_client.request("GET", "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, verify=self.verify, proxies=self.proxies, timeout=self.timeout) + if resp.status_code >= 400: + raise + resp = json.loads(resp.content) self.authority_groups = [ - set(group['aliases']) for group in resp.content['metadata']] + set(group['aliases']) for group in resp['metadata']] for group in self.authority_groups: if instance in group: return [alias for alias in group if alias != instance] diff --git a/msal/authority.py b/msal/authority.py index 694a1a05..843f8992 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -1,4 +1,6 @@ -from msal.oauth2cli.http import DefaultHttpClient +import json + +from msal.http import DefaultHttpClient try: from urllib.parse import urlparse @@ -25,6 +27,7 @@ "b2clogin.de", ] + class Authority(object): """This class represents an (already-validated) authority. @@ -47,7 +50,7 @@ def __init__(self, authority_url, validate_authority=True, self.verify = verify self.proxies = proxies self.timeout = timeout - self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxy=self.proxies) + self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxies=self.proxies) authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') self.is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( @@ -87,15 +90,11 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): # It will typically return a dict containing "ver", "account_type", # "federation_protocol", "cloud_audience_urn", # "federation_metadata_url", "federation_active_auth_url", etc. - if self.instance not in self.__class__._domains_without_user_realm_discovery: - resp = response or self.http_client.request("GET", + resp = response or self.http_client.request("GET", "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( netloc=self.instance, username=username), headers={'Accept':'application/json', 'client-request-id': correlation_id}, timeout= self.timeout) - if resp.status_code != 404: - return resp.content - self.__class__._domains_without_user_realm_discovery.add(self.instance) - return {} # This can guide the caller to fall back normal ROPC flow + return json.loads(resp.content) def instance_discovery(self, url, **kwargs): resp = self.http_client.request("GET", 'https://{}/common/discovery/instance'.format( @@ -104,25 +103,13 @@ def instance_discovery(self, url, **kwargs): # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 ), params={'authorization_endpoint': url, 'api-version': '1.0'}, **kwargs) - - return resp.content - # return requests.get( # Note: This URL seemingly returns V1 endpoint only - # 'https://{}/common/discovery/instance'.format( - # WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too - # # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 - # # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 - # ), - # params={'authorization_endpoint': url, 'api-version': '1.0'}, - # **kwargs).json() + return json.loads(resp.content) def tenant_discovery(self, tenant_discovery_endpoint, **kwargs): # Returns Openid Configuration resp = self.http_client.request("GET", tenant_discovery_endpoint, **kwargs) - payload = resp.content - - # resp = requests.get(tenant_discovery_endpoint, **kwargs) - # payload = resp.json() + payload = json.loads(resp.content) if 'authorization_endpoint' in payload and 'token_endpoint' in payload: return payload raise MsalServiceError(status_code=resp.status_code, **payload) diff --git a/msal/oauth2cli/http/__init__.py b/msal/http/__init__.py similarity index 54% rename from msal/oauth2cli/http/__init__.py rename to msal/http/__init__.py index 8624042a..d97cc32a 100644 --- a/msal/oauth2cli/http/__init__.py +++ b/msal/http/__init__.py @@ -1,2 +1,2 @@ from .http_client import DefaultHttpClient -from .http_response import Response \ No newline at end of file +from .http_response import Response diff --git a/msal/oauth2cli/http/http_client.py b/msal/http/http_client.py similarity index 53% rename from msal/oauth2cli/http/http_client.py rename to msal/http/http_client.py index c935cf09..b1303ace 100644 --- a/msal/oauth2cli/http/http_client.py +++ b/msal/http/http_client.py @@ -1,7 +1,10 @@ -from requests import Request, Session +import logging + +from requests import Session + from .http_response import Response +logger = logging.getLogger(__name__) -import requests class HttpClient(object): """ @@ -10,64 +13,53 @@ class HttpClient(object): def request(self, method, url, params=None, data=None, headers=None, auth=None, timeout=None, allow_redirects=False): """ - Make an HTTP request. + Makes an HTTP Request with parameters provided. + + :param str method: The HTTP method to use + :param str url: The URL to request + :param dict params: Query parameters to append to the URL + :param dict data: Parameters to go in the body of the HTTP request + :param dict headers: HTTP Headers to send with the request + :param tuple auth: Basic Auth arguments + :param float timeout: Socket/Read timeout for the request + :param boolean allow_redirects: Whether or not to allow redirects + See the requests documentation for explanation of all these parameters + + :return: An http response + :rtype: A :class:`Response ` object """ + class DefaultHttpClient(HttpClient): """ - General purpose HTTP Client for interacting with the Twilio API + Default HTTP Client """ - def __init__(self, default_headers={}, verify=None, proxy=None): + def __init__(self, verify=True, proxies=None): """ - Constructor for the TwilioHttpClient + Constructor for the DefaultHttpClient - :param bool pool_connections - :param request_hooks - :param int timeout: Timeout for the requests. - Timeout should never be zero (0) or less. - :param logger - :param dict proxy: Http proxy for the requests session + :param verify: (optional) + It will be passed to the + `verify parameter in the underlying requests library + `_ + :param proxies: (optional) + It will be passed to the + `proxies parameter in the underlying requests library + `_ """ self.session = Session() - self.session.headers.update(default_headers or {}) self.session.verify = verify - self.session.proxy = proxy + self.session.proxies = proxies def request(self, method, url, params=None, data=None, headers=None, auth=None, timeout=None, allow_redirects=False, **kwargs): - """ - Make an HTTP Request with parameters provided. - - :param str method: The HTTP method to use - :param str url: The URL to request - :param dict params: Query parameters to append to the URL - :param dict data: Parameters to go in the body of the HTTP request - :param dict headers: HTTP Headers to send with the request - :param tuple auth: Basic Auth arguments - :param float timeout: Socket/Read timeout for the request - :param boolean allow_redirects: Whether or not to allow redirects - See the requests documentation for explanation of all these parameters - :return: An http response - :rtype: A :class:`Response ` object - """ - content = None if method == "POST": response = self.session.post(url=url, headers=headers, params=params, data=data, auth=auth, timeout=timeout, **kwargs) - if response.status_code >=500: - response.raise_for_status() - try: - # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says - # even an error response will be a valid json structure, - # so we simply return it here, without needing to invent an exception. - content = response.json() - except ValueError: - self.logger.exception( - "Token response is not in json format: %s", response.text) - raise elif method == "GET": response = self.session.get(url=url, headers=headers, params=params, timeout=timeout, data=data, auth=auth) - content = response.json() + + content = response.text response = Response(int(response.status_code), content) return response diff --git a/msal/http/http_response.py b/msal/http/http_response.py new file mode 100644 index 00000000..3351eb38 --- /dev/null +++ b/msal/http/http_response.py @@ -0,0 +1,9 @@ +class Response(object): + + def __init__(self, status_code, content): + """HTTP Response object + :param int status_code: Status code from HTTP response + :param str text: HTTP response in string format + """ + self.status_code = status_code + self.content = content diff --git a/msal/mex.py b/msal/mex.py index c10f5617..5b6cf08c 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -34,8 +34,7 @@ except ImportError: from xml.etree import ElementTree as ET -from .oauth2cli.http import DefaultHttpClient -import requests +from msal.http import DefaultHttpClient def _xpath_of_root(route_to_leaf): @@ -46,7 +45,7 @@ def send_request(mex_endpoint, **kwargs): http_client = DefaultHttpClient() resp = http_client.request("GET", mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, **kwargs) - mex_document = resp.content.text + mex_document = resp.content # mex_document = requests.get( # mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, # **kwargs).text diff --git a/msal/oauth2cli/http/http_request.py b/msal/oauth2cli/http/http_request.py deleted file mode 100644 index 600de4ce..00000000 --- a/msal/oauth2cli/http/http_request.py +++ /dev/null @@ -1,76 +0,0 @@ - - -class Request(object): - """ - An HTTP request. - """ - ANY = '*' - - def __init__(self, - method=ANY, - url=ANY, - auth=ANY, - params=ANY, - data=ANY, - headers=ANY, - **kwargs): - self.method = method.upper() - self.url = url - self.auth = auth - self.params = params - self.data = data - self.headers = headers - - @classmethod - def attribute_equal(cls, lhs, rhs): - if lhs == cls.ANY or rhs == cls.ANY: - # ANY matches everything - return True - - lhs = lhs or None - rhs = rhs or None - - return lhs == rhs - - def __eq__(self, other): - if not isinstance(other, Request): - return False - - return self.attribute_equal(self.method, other.method) and \ - self.attribute_equal(self.url, other.url) and \ - self.attribute_equal(self.auth, other.auth) and \ - self.attribute_equal(self.params, other.params) and \ - self.attribute_equal(self.data, other.data) and \ - self.attribute_equal(self.headers, other.headers) - - def __str__(self): - auth = '' - if self.auth and self.auth != self.ANY: - auth = '{} '.format(self.auth) - - params = '' - if self.params and self.params != self.ANY: - params = '?{}'.format(urlencode(self.params, doseq=True)) - - data = '' - if self.data and self.data != self.ANY: - if self.method == 'GET': - data = '\n -G' - data += '\n{}'.format('\n'.join(' -d "{}={}"'.format(k, v) for k, v in self.data.items())) - - headers = '' - if self.headers and self.headers != self.ANY: - headers = '\n{}'.format('\n'.join(' -H "{}: {}"'.format(k, v) - for k, v in self.headers.items())) - - return '{auth}{method} {url}{params}{data}{headers}'.format( - auth=auth, - method=self.method, - url=self.url, - params=params, - data=data, - headers=headers, - ) - - def __repr__(self): - return str(self) \ No newline at end of file diff --git a/msal/oauth2cli/http/http_response.py b/msal/oauth2cli/http/http_response.py deleted file mode 100644 index 41770475..00000000 --- a/msal/oauth2cli/http/http_response.py +++ /dev/null @@ -1,13 +0,0 @@ -class Response(object): - """ - - """ - def __init__(self, status_code, text): - self.content = text - self.cached = False - self.status_code = status_code - self.ok = self.status_code < 400 - - @property - def text(self): - return self.content diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index f54a1cc7..faa9950c 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -1,5 +1,6 @@ """This OAuth2 client implementation aims to be spec-compliant, and generic.""" # OAuth2 spec https://tools.ietf.org/html/rfc6749 +import json try: from urllib.parse import urlencode, parse_qs @@ -12,9 +13,7 @@ import base64 import sys -from .http import DefaultHttpClient -from .http import Response - +from msal.http import DefaultHttpClient string_types = (str,) if sys.version_info[0] >= 3 else (basestring, ) @@ -83,14 +82,12 @@ def __init__( self.client_id = client_id self.client_secret = client_secret self.client_assertion = client_assertion + self.default_headers = default_headers or {} self.default_body = default_body or {} if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) - if not http_client: - self.http_client = DefaultHttpClient(verify=verify, proxy=proxies or {}, default_headers= {}) - else: - self.http_client = http_client + self.http_client = http_client if http_client else DefaultHttpClient(verify=verify, proxies=proxies) self.timeout = timeout def _build_auth_request_params(self, response_type, **kwargs): @@ -149,12 +146,16 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 if "token_endpoint" not in self.configuration: raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} + _headers.update(self.default_headers) _headers.update(headers or {}) resp = (post or self.http_client.request)("POST", self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, **kwargs) - return resp.content + if resp.status_code >= 500: + raise Exception + resp = json.loads(resp.content) + return resp def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs): # type: (str, Union[str, list, set, tuple]) -> dict @@ -213,11 +214,7 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, timeout=timeout or self.timeout, **kwargs) - flow = resp.content.json() - # flow = self.session.post(self.configuration[DAE], - # data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - # timeout=timeout or self.timeout, - # **kwargs).json() + flow = json.loads(resp.content) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) flow["expires_at"] = time.time() + flow["expires_in"] # We invent this diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 5b7f95f3..53d160ce 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -29,11 +29,9 @@ from datetime import datetime, timedelta import logging -import requests - from .mex import Mex from .wstrust_response import parse_response -from .oauth2cli.http import DefaultHttpClient +from msal.http import DefaultHttpClient logger = logging.getLogger(__name__) @@ -53,20 +51,16 @@ def send_request( data = _build_rst( username, password, cloud_audience_urn, endpoint_address, soap_action) http_client = DefaultHttpClient() - resp = http_client.request("POST", endpoint_address, + resp = http_client.request("POST", endpoint_address, data=data, headers={ 'Content-type':'application/soap+xml; charset=utf-8', 'SOAPAction': soap_action, }, **kwargs) - # resp = requests.post(endpoint_address, data=data, headers={ - # 'Content-type':'application/soap+xml; charset=utf-8', - # 'SOAPAction': soap_action, - # }, **kwargs) if resp.status_code >= 400: logger.debug("Unsuccessful WsTrust request receives: %s", resp.content.text) # It turns out ADFS uses 5xx status code even with client-side incorrect password error # resp.raise_for_status() - return parse_response(resp.content.text) + return parse_response(resp.content) def escape_password(password): return (password.replace('&', '&').replace('"', '"') diff --git a/tests/test_application.py b/tests/test_application.py index 452acda8..d7af2bcd 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,10 +1,7 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. -import os -import json -import logging -from msal.oauth2cli.http import Response +from msal.http import Response try: from unittest.mock import * # Python 3 @@ -79,21 +76,21 @@ def test_cache_empty_will_be_returned_as_None(self): None, self.app.acquire_token_silent_with_error(['cache_miss'], self.account)) def test_acquire_token_silent_will_suppress_error(self): - error_response = {"error": "invalid_grant", "suberror": "xyz"} + error_response = '{"error": "invalid_grant", "suberror": "xyz"}' def tester(method, url, **kwargs): return Response(400, error_response) self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): - error_response = {"error": "invalid_grant", "error_description": "xyz"} + error_response = '{"error": "invalid_grant", "error_description": "xyz"}' def tester(method, url, **kwargs): return Response(400, error_response) - self.assertEqual(error_response, self.app.acquire_token_silent_with_error( + self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): - error_response = {"error": "invalid_grant", "suberror": "basic_action"} + error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' def tester(method, url, **kwargs): return Response(400, error_response) result = self.app.acquire_token_silent_with_error( @@ -101,7 +98,7 @@ def tester(method, url, **kwargs): self.assertEqual("basic_action", result.get("classification")) def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): - error_response = {"error": "invalid_grant", "suberror": "client_mismatch"} + error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' def tester(method, url, **kwargs): return Response(400, error_response) result = self.app.acquire_token_silent_with_error( @@ -134,11 +131,10 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): app = ClientApplication( "unknown_orphan", authority=self.authority_url, token_cache=self.cache) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' def tester(method, url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Response(400, { - "error": "invalid_grant", - "error_description": "Was issued to another client"}) + return Response(400, error_response) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) self.assertNotEqual([], app.token_cache.find( @@ -159,7 +155,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(method, url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") - return Response(200, {}) + return Response(200, '{}') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) @@ -168,8 +164,8 @@ def tester(method, url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") return Response( - 200, TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, foci="1", access_token="at")) + 200, json.dumps(TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( From f60bbb5136a0ae08a0096cb2d37f9ade01774561 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 6 Mar 2020 17:16:25 -0800 Subject: [PATCH 06/33] Removing tests whose implementation was removed --- tests/test_authority.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/test_authority.py b/tests/test_authority.py index d1e75ef7..80023fdf 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -73,22 +73,3 @@ def test_canonicalize_rejects_tenantless_host_with_trailing_slash(self): canonicalize("https://no.tenant.example.com/") -@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") -class TestAuthorityInternalHelperUserRealmDiscovery(unittest.TestCase): - def test_memorize(self): - # We use a real authority so the constructor can finish tenant discovery - authority = "https://login.microsoftonline.com/common" - self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) - a = Authority(authority, validate_authority=False) - - # We now pretend this authority supports no User Realm Discovery - class MockResponse(object): - status_code = 404 - a.user_realm_discovery("john.doe@example.com", response=MockResponse()) - self.assertIn( - "login.microsoftonline.com", - Authority._domains_without_user_realm_discovery, - "user_realm_discovery() should memorize domains not supporting URD") - a.user_realm_discovery("john.doe@example.com", - response="This would cause exception if memorization did not work") - From d54fb8ba926431c19d7e172d7b23c2b784868a59 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 6 Mar 2020 17:29:40 -0800 Subject: [PATCH 07/33] Iteration 5 --- msal/application.py | 6 +++++- msal/authority.py | 6 ------ msal/mex.py | 3 --- msal/wstrust_request.py | 2 +- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/msal/application.py b/msal/application.py index c6d25ef1..9a0721c7 100644 --- a/msal/application.py +++ b/msal/application.py @@ -147,6 +147,9 @@ def __init__( It will be passed to the `proxies parameter in the underlying requests library `_ + :param http_client: (optional) + Your implementation of abstract class HttpClient + Defaults to default http client implementation which uses requests :param timeout: (optional) It will be passed to the `timeout parameter in the underlying requests library @@ -169,7 +172,8 @@ def __init__( self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - validate_authority, verify=verify, proxies=proxies, timeout=timeout, http_client=self.http_client) + validate_authority, verify=verify, proxies=proxies, timeout=timeout, + http_client=self.http_client) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) diff --git a/msal/authority.py b/msal/authority.py index 843f8992..5309601d 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -128,9 +128,3 @@ def canonicalize(authority_url): "or https://.b2clogin.com/.onmicrosoft.com/policy" % authority_url) return authority, authority.hostname, parts[1] - - - - - - diff --git a/msal/mex.py b/msal/mex.py index 5b6cf08c..39ef01d4 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -46,9 +46,6 @@ def send_request(mex_endpoint, **kwargs): resp = http_client.request("GET", mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, **kwargs) mex_document = resp.content - # mex_document = requests.get( - # mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, - # **kwargs).text return Mex(mex_document).get_wstrust_username_password_endpoint() diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 53d160ce..256ecc8c 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -57,7 +57,7 @@ def send_request( 'SOAPAction': soap_action, }, **kwargs) if resp.status_code >= 400: - logger.debug("Unsuccessful WsTrust request receives: %s", resp.content.text) + logger.debug("Unsuccessful WsTrust request receives: %s", resp.content) # It turns out ADFS uses 5xx status code even with client-side incorrect password error # resp.raise_for_status() return parse_response(resp.content) From be389d5ec21904a11d0f70e85278ca863c2c532f Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 9 Mar 2020 12:42:09 -0700 Subject: [PATCH 08/33] Replacing generic exception to specific Http error --- msal/application.py | 8 ++++++-- msal/oauth2cli/oauth2.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/msal/application.py b/msal/application.py index 9a0721c7..75e191b6 100644 --- a/msal/application.py +++ b/msal/application.py @@ -2,8 +2,10 @@ import time try: # Python 2 from urlparse import urljoin + from urllib2 import HTTPError except: # Python 3 from urllib.parse import urljoin + from urllib.error import HTTPError import logging import sys import warnings @@ -376,8 +378,10 @@ def _get_authority_aliases(self, instance): resp = self.http_client.request("GET", "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, verify=self.verify, proxies=self.proxies, timeout=self.timeout) - if resp.status_code >= 400: - raise + if resp.status_code >= 500: + raise HttpError("Internal server error %s" % resp.content) + elif resp.status_code >= 400: + raise HttpError("Client error %s" % resp.content) resp = json.loads(resp.content) self.authority_groups = [ set(group['aliases']) for group in resp['metadata']] diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index faa9950c..01ea86d7 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -4,9 +4,11 @@ try: from urllib.parse import urlencode, parse_qs + from urllib.error import HTTPError except ImportError: from urlparse import parse_qs from urllib import urlencode + from urllib2 import HTTPError import logging import warnings import time @@ -153,7 +155,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 timeout=timeout or self.timeout, **kwargs) if resp.status_code >= 500: - raise Exception + raise HttpError("Internal server error %s" % resp.content) resp = json.loads(resp.content) return resp From b3a4e09a3cb845cb0244143228a3bc60d405b3fd Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 19 Mar 2020 17:46:07 -0700 Subject: [PATCH 09/33] Refactor according to new interface --- msal/application.py | 54 +++++++++------------- msal/authority.py | 34 +++++++------- msal/http/__init__.py | 2 - msal/http/http_client.py | 65 --------------------------- msal/http/http_response.py | 9 ---- msal/mex.py | 6 +-- msal/oauth2cli/__init__.py | 3 +- msal/oauth2cli/default_http_client.py | 46 +++++++++++++++++++ msal/oauth2cli/oauth2.py | 23 ++++------ msal/wstrust_request.py | 12 ++--- tests/test_application.py | 19 ++++---- tests/test_authority.py | 8 ++-- tests/test_client.py | 5 ++- 13 files changed, 116 insertions(+), 170 deletions(-) delete mode 100644 msal/http/__init__.py delete mode 100644 msal/http/http_client.py delete mode 100644 msal/http/http_response.py create mode 100644 msal/oauth2cli/default_http_client.py diff --git a/msal/application.py b/msal/application.py index 75e191b6..29de032a 100644 --- a/msal/application.py +++ b/msal/application.py @@ -2,23 +2,19 @@ import time try: # Python 2 from urlparse import urljoin - from urllib2 import HTTPError except: # Python 3 from urllib.parse import urljoin - from urllib.error import HTTPError import logging import sys import warnings import uuid -from .oauth2cli import Client, JwtAssertionCreator +from .oauth2cli import Client, JwtAssertionCreator, DefaultHttpClient from .authority import Authority from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request from .wstrust_response import * from .token_cache import TokenCache -from msal.http import DefaultHttpClient - # The __init__.py will import this. Not the other way around. __version__ = "1.1.0" @@ -56,11 +52,11 @@ def decorate_scope( CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry' def _get_new_correlation_id(): - return str(uuid.uuid4()) + return str(uuid.uuid4()) def _build_current_telemetry_request_header(public_api_id, force_refresh=False): - return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0") + return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0") def extract_certs(public_cert_content): @@ -141,6 +137,9 @@ def __init__( :param TokenCache cache: Sets the token cache used by this ClientApplication instance. By default, an in-memory cache will be created and used. + :param http_client: (optional) + Your implementation of abstract class HttpClient + Defaults to default http client implementation which uses requests :param verify: (optional) It will be passed to the `verify parameter in the underlying requests library @@ -149,9 +148,6 @@ def __init__( It will be passed to the `proxies parameter in the underlying requests library `_ - :param http_client: (optional) - Your implementation of abstract class HttpClient - Defaults to default http client implementation which uses requests :param timeout: (optional) It will be passed to the `timeout parameter in the underlying requests library @@ -166,16 +162,13 @@ def __init__( self.client_id = client_id self.client_credential = client_credential self.client_claims = client_claims - self.verify = verify - self.proxies = proxies - self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxies=self.proxies) + self.http_client = http_client if http_client else DefaultHttpClient(verify=verify, proxies=proxies) self.timeout = timeout self.app_name = app_name self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - validate_authority, verify=verify, proxies=proxies, timeout=timeout, - http_client=self.http_client) + http_client=self.http_client, validate_authority=validate_authority, timeout=timeout) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) @@ -225,8 +218,7 @@ def _build_client(self, client_credential, authority): on_obtaining_tokens=self.token_cache.add, on_removing_rt=self.token_cache.remove_rt, on_updating_rt=self.token_cache.update_rt, - http_client=self.http_client, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + http_client=self.http_client, timeout=self.timeout) def get_authorization_request_url( self, @@ -272,8 +264,7 @@ def get_authorization_request_url( # The previous implementation is, it will use self.authority by default. # Multi-tenant app can use new authority on demand the_authority = Authority( - authority, - verify=self.verify, proxies=self.proxies, timeout=self.timeout, http_client=self.http_client + authority, http_client=self.http_client, timeout=self.timeout ) if authority else self.authority client = Client( @@ -375,14 +366,12 @@ def _find_msal_accounts(self, environment): def _get_authority_aliases(self, instance): if not self.authority_groups: - resp = self.http_client.request("GET", "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + resp = self.http_client.get("https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + timeout=self.timeout) if resp.status_code >= 500: - raise HttpError("Internal server error %s" % resp.content) - elif resp.status_code >= 400: - raise HttpError("Client error %s" % resp.content) - resp = json.loads(resp.content) + resp.raise_for_status() + resp = json.loads(resp.text) self.authority_groups = [ set(group['aliases']) for group in resp['metadata']] for group in self.authority_groups: @@ -502,8 +491,8 @@ def acquire_token_silent_with_error( if authority: warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( - # authority, - # verify=self.verify, proxies=self.proxies, timeout=self.timeout, + # authority, http_client=self.http_client, + # timeout=self.timeout # ) if authority else self.authority result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, self.authority, force_refresh=force_refresh, @@ -515,8 +504,8 @@ def acquire_token_silent_with_error( for alias in self._get_authority_aliases(self.authority.instance): the_authority = Authority( "https://" + alias + "/" + self.authority.tenant, - validate_authority=False, - verify=self.verify, proxies=self.proxies, timeout=self.timeout, http_client=self.http_client) + http_client=self.http_client, validate_authority=False, + timeout=self.timeout) result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, the_authority, force_refresh=force_refresh, correlation_id=correlation_id, @@ -759,13 +748,10 @@ def acquire_token_by_username_password( def _acquire_token_by_username_password_federated( self, user_realm_result, username, password, scopes=None, **kwargs): - verify = kwargs.pop("verify", self.verify) - proxies = kwargs.pop("proxies", self.proxies) wstrust_endpoint = {} if user_realm_result.get("federation_metadata_url"): wstrust_endpoint = mex_send_request( - user_realm_result["federation_metadata_url"], - verify=verify, proxies=proxies) + user_realm_result["federation_metadata_url"], http_client=self.http_client) if wstrust_endpoint is None: raise ValueError("Unable to find wstrust endpoint from MEX. " "This typically happens when attempting MSA accounts. " @@ -777,7 +763,7 @@ def _acquire_token_by_username_password_federated( wstrust_endpoint.get("address", # Fallback to an AAD supplied endpoint user_realm_result.get("federation_active_auth_url")), - wstrust_endpoint.get("action"), verify=verify, proxies=proxies) + wstrust_endpoint.get("action"), http_client=self.http_client) if not ("token" in wstrust_result and "type" in wstrust_result): raise RuntimeError("Unsuccessful RSTR. %s" % wstrust_result) GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer' diff --git a/msal/authority.py b/msal/authority.py index 5309601d..6a844c0f 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -1,6 +1,6 @@ import json -from msal.http import DefaultHttpClient +from .oauth2cli.default_http_client import DefaultHttpClient try: from urllib.parse import urlparse @@ -36,8 +36,8 @@ class Authority(object): """ _domains_without_user_realm_discovery = set([]) - def __init__(self, authority_url, validate_authority=True, - verify=True, proxies=None, timeout=None, http_client=None + def __init__(self, authority_url, http_client, validate_authority=True, + timeout=None ): """Creates an authority instance, and also validates it. @@ -47,10 +47,8 @@ def __init__(self, authority_url, validate_authority=True, This parameter only controls whether an instance discovery will be performed. """ - self.verify = verify - self.proxies = proxies + self.http_client = http_client self.timeout = timeout - self.http_client = http_client or DefaultHttpClient(verify=self.verify, proxies=self.proxies) authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') self.is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( @@ -60,7 +58,7 @@ def __init__(self, authority_url, validate_authority=True, payload = self.instance_discovery( "https://{}{}/oauth2/v2.0/authorize".format( self.instance, authority.path), - verify=verify, proxies=proxies, timeout=timeout) + timeout=timeout) if payload.get("error") == "invalid_instance": raise ValueError( "invalid_instance: " @@ -79,7 +77,7 @@ def __init__(self, authority_url, validate_authority=True, )) openid_config = self.tenant_discovery( tenant_discovery_endpoint, - verify=verify, proxies=proxies, timeout=timeout) + timeout=timeout) logger.debug("openid_config = %s", openid_config) self.authorization_endpoint = openid_config['authorization_endpoint'] self.token_endpoint = openid_config['token_endpoint'] @@ -90,32 +88,30 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): # It will typically return a dict containing "ver", "account_type", # "federation_protocol", "cloud_audience_urn", # "federation_metadata_url", "federation_active_auth_url", etc. - resp = response or self.http_client.request("GET", - "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( - netloc=self.instance, username=username), headers={'Accept':'application/json', - 'client-request-id': correlation_id}, timeout= self.timeout) - return json.loads(resp.content) + resp = response or self.http_client.get("https://{netloc}/common/userrealm/{username}?api-version=1.0".format( + netloc=self.instance, username=username), + headers={'Accept':'application/json', 'client-request-id': correlation_id}, + timeout=self.timeout) + return json.loads(resp.text) def instance_discovery(self, url, **kwargs): - resp = self.http_client.request("GET", 'https://{}/common/discovery/instance'.format( + resp = self.http_client.get('https://{}/common/discovery/instance'.format( WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 ), params={'authorization_endpoint': url, 'api-version': '1.0'}, **kwargs) - return json.loads(resp.content) + return json.loads(resp.text) def tenant_discovery(self, tenant_discovery_endpoint, **kwargs): # Returns Openid Configuration - resp = self.http_client.request("GET", tenant_discovery_endpoint, - **kwargs) - payload = json.loads(resp.content) + resp = self.http_client.get(tenant_discovery_endpoint, **kwargs) + payload = json.loads(resp.text) if 'authorization_endpoint' in payload and 'token_endpoint' in payload: return payload raise MsalServiceError(status_code=resp.status_code, **payload) - def canonicalize(authority_url): # Returns (url_parsed_result, hostname_in_lowercase, tenant) authority = urlparse(authority_url) diff --git a/msal/http/__init__.py b/msal/http/__init__.py deleted file mode 100644 index d97cc32a..00000000 --- a/msal/http/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .http_client import DefaultHttpClient -from .http_response import Response diff --git a/msal/http/http_client.py b/msal/http/http_client.py deleted file mode 100644 index b1303ace..00000000 --- a/msal/http/http_client.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging - -from requests import Session - -from .http_response import Response -logger = logging.getLogger(__name__) - - -class HttpClient(object): - """ - An abstract class representing an HTTP client. - """ - def request(self, method, url, params=None, data=None, headers=None, auth=None, - timeout=None, allow_redirects=False): - """ - Makes an HTTP Request with parameters provided. - - :param str method: The HTTP method to use - :param str url: The URL to request - :param dict params: Query parameters to append to the URL - :param dict data: Parameters to go in the body of the HTTP request - :param dict headers: HTTP Headers to send with the request - :param tuple auth: Basic Auth arguments - :param float timeout: Socket/Read timeout for the request - :param boolean allow_redirects: Whether or not to allow redirects - See the requests documentation for explanation of all these parameters - - :return: An http response - :rtype: A :class:`Response ` object - """ - - -class DefaultHttpClient(HttpClient): - """ - Default HTTP Client - """ - def __init__(self, verify=True, proxies=None): - """ - Constructor for the DefaultHttpClient - - :param verify: (optional) - It will be passed to the - `verify parameter in the underlying requests library - `_ - :param proxies: (optional) - It will be passed to the - `proxies parameter in the underlying requests library - `_ - """ - self.session = Session() - self.session.verify = verify - self.session.proxies = proxies - - def request(self, method, url, params=None, data=None, headers=None, auth=None, timeout=None, - allow_redirects=False, **kwargs): - - if method == "POST": - response = self.session.post(url=url, headers=headers, params=params, data=data, auth=auth, - timeout=timeout, **kwargs) - elif method == "GET": - response = self.session.get(url=url, headers=headers, params=params, timeout=timeout, data=data, auth=auth) - - content = response.text - response = Response(int(response.status_code), content) - return response diff --git a/msal/http/http_response.py b/msal/http/http_response.py deleted file mode 100644 index 3351eb38..00000000 --- a/msal/http/http_response.py +++ /dev/null @@ -1,9 +0,0 @@ -class Response(object): - - def __init__(self, status_code, content): - """HTTP Response object - :param int status_code: Status code from HTTP response - :param str text: HTTP response in string format - """ - self.status_code = status_code - self.content = content diff --git a/msal/mex.py b/msal/mex.py index 39ef01d4..239d6aa1 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -34,15 +34,13 @@ except ImportError: from xml.etree import ElementTree as ET -from msal.http import DefaultHttpClient - def _xpath_of_root(route_to_leaf): # Construct an xpath suitable to find a root node which has a specified leaf return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1)) -def send_request(mex_endpoint, **kwargs): - http_client = DefaultHttpClient() + +def send_request(mex_endpoint, http_client, **kwargs): resp = http_client.request("GET", mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, **kwargs) mex_document = resp.content diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index b8941361..1c8bdd1d 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -3,4 +3,5 @@ from .oidc import Client from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. - +from .http import HttpClient, Response +from .default_http_client import DefaultHttpClient diff --git a/msal/oauth2cli/default_http_client.py b/msal/oauth2cli/default_http_client.py new file mode 100644 index 00000000..63350fd0 --- /dev/null +++ b/msal/oauth2cli/default_http_client.py @@ -0,0 +1,46 @@ +import requests + +from .http import HttpClient, Response + + +class DefaultHttpClient(HttpClient): + """ + Default HTTP Client + """ + def __init__(self, verify=True, proxies=None, timeout=None): + """ + Constructor for the DefaultHttpClient + + verify=True, # type: Union[str, True, False, None] + proxies=None, # type: Optional[dict] + """ + self.session = requests.Session() + if verify: + self.session.verify = verify + if proxies: + self.session.proxies = proxies + if timeout: + self.session.timeout = timeout + + def post(self, url, params=None, data=None, headers=None, **kwargs): + + response = self.session.post(url=url, params=params, headers=headers, data=data, **kwargs) + return Response(response.status_code, response.text) + + def get(self, url, params=None, headers=None, **kwargs): + response = self.session.get(url=url, params=params, headers=headers, **kwargs) + return Response(response.status_code, response.text) + + +class Response(Response): + + def __init__(self, status_code, text): + """HTTP Response object + :param int status_code: Status code from HTTP response + :param str text: HTTP response in string format + """ + self.status_code = status_code + self.text = text + + def raise_for_status(self): + self.text.raise_for_status() \ No newline at end of file diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 01ea86d7..0fe46971 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -4,18 +4,15 @@ try: from urllib.parse import urlencode, parse_qs - from urllib.error import HTTPError except ImportError: from urlparse import parse_qs from urllib import urlencode - from urllib2 import HTTPError import logging import warnings import time import base64 import sys - -from msal.http import DefaultHttpClient +from .default_http_client import DefaultHttpClient string_types = (str,) if sys.version_info[0] >= 3 else (basestring, ) @@ -37,14 +34,12 @@ def __init__( self, server_configuration, # type: dict client_id, # type: str + http_client, # type: HttpClient client_secret=None, # type: Optional[str] client_assertion=None, # type: Union[bytes, callable, None] client_assertion_type=None, # type: Optional[str] default_headers=None, # type: Optional[dict] default_body=None, # type: Optional[dict] - http_client=None, - verify=True, # type: Union[str, True, False, None] - proxies=None, # type: Optional[dict] timeout=None, # type: Union[tuple, float, None] ): """Initialize a client object to talk all the OAuth2 grants to the server. @@ -89,7 +84,7 @@ def __init__( if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) - self.http_client = http_client if http_client else DefaultHttpClient(verify=verify, proxies=proxies) + self.http_client = http_client self.timeout = timeout def _build_auth_request_params(self, response_type, **kwargs): @@ -98,7 +93,6 @@ def _build_auth_request_params(self, response_type, **kwargs): # or it can be a space-delimited string as defined in # https://tools.ietf.org/html/rfc6749#section-8.4 response_type = self._stringify(response_type) - params = {'client_id': self.client_id, 'response_type': response_type} params.update(kwargs) # Note: None values will override params params = {k: v for k, v in params.items() if v is not None} # clean up @@ -150,14 +144,13 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _headers = {'Accept': 'application/json'} _headers.update(self.default_headers) _headers.update(headers or {}) - resp = (post or self.http_client.request)("POST", self.configuration["token_endpoint"], + resp = (post or self.http_client.post)(self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, auth=auth, timeout=timeout or self.timeout, **kwargs) if resp.status_code >= 500: - raise HttpError("Internal server error %s" % resp.content) - resp = json.loads(resp.content) - return resp + resp.raise_for_status() + return json.loads(resp.text) def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs): # type: (str, Union[str, list, set, tuple]) -> dict @@ -212,11 +205,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): DAE = "device_authorization_endpoint" if not self.configuration.get(DAE): raise ValueError("You need to provide device authorization endpoint") - resp = self.http_client.request("POST", self.configuration[DAE], + resp = self.http_client.post(self.configuration[DAE], data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, timeout=timeout or self.timeout, **kwargs) - flow = json.loads(resp.content) + flow = json.loads(resp.text) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) flow["expires_at"] = time.time() + flow["expires_in"] # We invent this diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 256ecc8c..14695ab5 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -31,13 +31,11 @@ from .mex import Mex from .wstrust_response import parse_response -from msal.http import DefaultHttpClient - logger = logging.getLogger(__name__) def send_request( - username, password, cloud_audience_urn, endpoint_address, soap_action, + username, password, cloud_audience_urn, endpoint_address, soap_action, http_client, **kwargs): if not endpoint_address: raise ValueError("WsTrust endpoint address can not be empty") @@ -50,23 +48,24 @@ def send_request( "Unsupported soap action: %s" % soap_action) data = _build_rst( username, password, cloud_audience_urn, endpoint_address, soap_action) - http_client = DefaultHttpClient() - resp = http_client.request("POST", endpoint_address, data=data, + resp = http_client.post(endpoint_address, data=data, headers={ 'Content-type':'application/soap+xml; charset=utf-8', 'SOAPAction': soap_action, }, **kwargs) if resp.status_code >= 400: - logger.debug("Unsuccessful WsTrust request receives: %s", resp.content) + logger.debug("Unsuccessful WsTrust request receives: %s", resp.text) # It turns out ADFS uses 5xx status code even with client-side incorrect password error # resp.raise_for_status() return parse_response(resp.content) + def escape_password(password): return (password.replace('&', '&').replace('"', '"') .replace("'", ''') # the only one not provided by cgi.escape(s, True) .replace('<', '<').replace('>', '>')) + def wsu_time_format(datetime_obj): # WsTrust (http://docs.oasis-open.org/ws-sx/ws-trust/v1.4/ws-trust.html) # does not seem to define timestamp format, but we see YYYY-mm-ddTHH:MM:SSZ @@ -75,6 +74,7 @@ def wsu_time_format(datetime_obj): # https://docs.python.org/2/library/datetime.html#datetime.datetime.isoformat return datetime_obj.strftime('%Y-%m-%dT%H:%M:%SZ') + def _build_rst(username, password, cloud_audience_urn, endpoint_address, soap_action): now = datetime.utcnow() return """ diff --git a/tests/test_application.py b/tests/test_application.py index d7af2bcd..dd556fd0 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,14 +1,13 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. -from msal.http import Response - try: from unittest.mock import * # Python 3 except: from mock import * # Need an external mock package from msal.application import * +from msal.oauth2cli.default_http_client import Response import msal from tests import unittest from tests.test_token_cache import TokenCacheTestCase @@ -77,21 +76,21 @@ def test_cache_empty_will_be_returned_as_None(self): def test_acquire_token_silent_will_suppress_error(self): error_response = '{"error": "invalid_grant", "suberror": "xyz"}' - def tester(method, url, **kwargs): + def tester(url, **kwargs): return Response(400, error_response) self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): error_response = '{"error": "invalid_grant", "error_description": "xyz"}' - def tester(method, url, **kwargs): + def tester(url, **kwargs): return Response(400, error_response) self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' - def tester(method, url, **kwargs): + def tester(url, **kwargs): return Response(400, error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) @@ -99,7 +98,7 @@ def tester(method, url, **kwargs): def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' - def tester(method, url, **kwargs): + def tester(url, **kwargs): return Response(400, error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) @@ -109,7 +108,7 @@ class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url) + self.authority = msal.authority.Authority(self.authority_url, DefaultHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -132,7 +131,7 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): "unknown_orphan", authority=self.authority_url, token_cache=self.cache) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' - def tester(method, url, data=None, **kwargs): + def tester(url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") return Response(400, error_response) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( @@ -153,14 +152,14 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): uid=self.uid, utid=self.utid, refresh_token=rt), }) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) - def tester(method, url, data=None, **kwargs): + def tester(url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") return Response(200, '{}') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) def test_unknown_family_app_will_attempt_frt_and_join_family(self): - def tester(method, url, data=None, **kwargs): + def tester(url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") return Response( diff --git a/tests/test_authority.py b/tests/test_authority.py index 80023fdf..9767dc34 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -1,8 +1,8 @@ import os from msal.authority import * -from msal.exceptions import MsalServiceError from tests import unittest +from msal.oauth2cli.default_http_client import DefaultHttpClient @unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") @@ -11,7 +11,7 @@ class TestAuthority(unittest.TestCase): def test_wellknown_host_and_tenant(self): # Assert all well known authority hosts are using their own "common" tenant for host in WELL_KNOWN_AUTHORITY_HOSTS: - a = Authority('https://{}/common'.format(host)) + a = Authority('https://{}/common'.format(host), DefaultHttpClient()) self.assertEqual( a.authorization_endpoint, 'https://%s/common/oauth2/v2.0/authorize' % host) @@ -24,14 +24,14 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): # It is probably not a strict API contract. I simply mention it here. less_known = 'login.windows.net' # less.known.host/ v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) - a = Authority('https://{}/common'.format(less_known)) + a = Authority('https://{}/common'.format(less_known), DefaultHttpClient()) self.assertEqual(a.token_endpoint, v1_token_endpoint) self.assertNotIn('v2.0', a.token_endpoint) def test_unknown_host_wont_pass_instance_discovery(self): _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack with _assert(ValueError, "invalid_instance"): - Authority('https://example.com/tenant_doesnt_matter_in_this_case') + Authority('https://example.com/tenant_doesnt_matter_in_this_case', DefaultHttpClient()) def test_invalid_host_skipping_validation_can_be_turned_off(self): try: diff --git a/tests/test_client.py b/tests/test_client.py index 87d2ecf6..6faf4969 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,6 +11,7 @@ from msal.oauth2cli import Client, JwtSigner from msal.oauth2cli.authcode import obtain_auth_code +from msal.oauth2cli.default_http_client import DefaultHttpClient from tests import unittest, Oauth2TestCase @@ -99,11 +100,13 @@ def setUpClass(cls): issuer=CONFIG["client_id"], ), client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT, + http_client=DefaultHttpClient() ) else: cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], - client_secret=CONFIG.get('client_secret')) + client_secret=CONFIG.get('client_secret'), + http_client=DefaultHttpClient()) @unittest.skipIf( "token_endpoint" not in CONFIG.get("openid_configuration", {}), From 96988f86d81cd6f9d9d3ab50bb2caed520dd4b74 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 19 Mar 2020 17:52:16 -0700 Subject: [PATCH 10/33] Changing one reference to new interface left in the previous one --- msal/mex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/msal/mex.py b/msal/mex.py index 239d6aa1..25971a5d 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -41,9 +41,9 @@ def _xpath_of_root(route_to_leaf): def send_request(mex_endpoint, http_client, **kwargs): - resp = http_client.request("GET", mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, + resp = http_client.get(mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, **kwargs) - mex_document = resp.content + mex_document = resp.text return Mex(mex_document).get_wstrust_username_password_endpoint() From 1d05615c1655c2e630bdc68fc8c85fdad7577731 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 19 Mar 2020 17:58:30 -0700 Subject: [PATCH 11/33] Modified one more missed change --- msal/wstrust_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 14695ab5..2b0fcd37 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -57,7 +57,7 @@ def send_request( logger.debug("Unsuccessful WsTrust request receives: %s", resp.text) # It turns out ADFS uses 5xx status code even with client-side incorrect password error # resp.raise_for_status() - return parse_response(resp.content) + return parse_response(resp.text) def escape_password(password): From ccafcf96a15f8d2e9d43dc7d27a3f12cca9125df Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 19 Mar 2020 18:31:59 -0700 Subject: [PATCH 12/33] Few more changes and refactor --- msal/application.py | 12 ++++++------ msal/oauth2cli/default_http_client.py | 22 ++++++++++------------ 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/msal/application.py b/msal/application.py index 29de032a..912345ec 100644 --- a/msal/application.py +++ b/msal/application.py @@ -89,8 +89,8 @@ class ClientApplication(object): def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, - token_cache=None, - verify=True, proxies=None, http_client=None, timeout=None, + token_cache=None, http_client=None, + verify=True, proxies=None, timeout=None, client_claims=None, app_name=None, app_version=None): """Create an instance of application. @@ -211,6 +211,7 @@ def _build_client(self, client_credential, authority): return Client( server_configuration, self.client_id, + http_client=self.http_client, default_headers=default_headers, default_body=default_body, client_assertion=client_assertion, @@ -218,7 +219,7 @@ def _build_client(self, client_credential, authority): on_obtaining_tokens=self.token_cache.add, on_removing_rt=self.token_cache.remove_rt, on_updating_rt=self.token_cache.update_rt, - http_client=self.http_client, timeout=self.timeout) + timeout=self.timeout) def get_authorization_request_url( self, @@ -269,7 +270,7 @@ def get_authorization_request_url( client = Client( {"authorization_endpoint": the_authority.authorization_endpoint}, - self.client_id) + self.client_id, self.http_client) return client.build_auth_request_uri( response_type=response_type, redirect_uri=redirect_uri, state=state, login_hint=login_hint, @@ -369,8 +370,7 @@ def _get_authority_aliases(self, instance): resp = self.http_client.get("https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, timeout=self.timeout) - if resp.status_code >= 500: - resp.raise_for_status() + resp.raise_for_status() resp = json.loads(resp.text) self.authority_groups = [ set(group['aliases']) for group in resp['metadata']] diff --git a/msal/oauth2cli/default_http_client.py b/msal/oauth2cli/default_http_client.py index 63350fd0..41d96fe4 100644 --- a/msal/oauth2cli/default_http_client.py +++ b/msal/oauth2cli/default_http_client.py @@ -7,7 +7,7 @@ class DefaultHttpClient(HttpClient): """ Default HTTP Client """ - def __init__(self, verify=True, proxies=None, timeout=None): + def __init__(self, verify=True, proxies=None): """ Constructor for the DefaultHttpClient @@ -19,28 +19,26 @@ def __init__(self, verify=True, proxies=None, timeout=None): self.session.verify = verify if proxies: self.session.proxies = proxies - if timeout: - self.session.timeout = timeout def post(self, url, params=None, data=None, headers=None, **kwargs): response = self.session.post(url=url, params=params, headers=headers, data=data, **kwargs) - return Response(response.status_code, response.text) + return Response(response) def get(self, url, params=None, headers=None, **kwargs): response = self.session.get(url=url, params=params, headers=headers, **kwargs) - return Response(response.status_code, response.text) + return Response(response) class Response(Response): - def __init__(self, status_code, text): - """HTTP Response object - :param int status_code: Status code from HTTP response - :param str text: HTTP response in string format + def __init__(self, response): + """Constructor for DefaultResponseObject + response: Raw http response from requests """ - self.status_code = status_code - self.text = text + self.status_code = response.status_code + self.text = response.text + self.response = response def raise_for_status(self): - self.text.raise_for_status() \ No newline at end of file + self.response.raise_for_status() From 2670e254511a58f6f6b104355a80a21672f4893c Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 19 Mar 2020 19:02:55 -0700 Subject: [PATCH 13/33] Adding raw response to response object --- msal/oauth2cli/default_http_client.py | 14 ++++++++------ tests/test_application.py | 14 +++++++------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/msal/oauth2cli/default_http_client.py b/msal/oauth2cli/default_http_client.py index 41d96fe4..571e2604 100644 --- a/msal/oauth2cli/default_http_client.py +++ b/msal/oauth2cli/default_http_client.py @@ -23,21 +23,23 @@ def __init__(self, verify=True, proxies=None): def post(self, url, params=None, data=None, headers=None, **kwargs): response = self.session.post(url=url, params=params, headers=headers, data=data, **kwargs) - return Response(response) + return Response(response.status_code, response.text, response) def get(self, url, params=None, headers=None, **kwargs): response = self.session.get(url=url, params=params, headers=headers, **kwargs) - return Response(response) + return Response(response.status_code, response.text, response) class Response(Response): - def __init__(self, response): + def __init__(self, status_code, text, response): """Constructor for DefaultResponseObject - response: Raw http response from requests + status, # type: int + text, # type: str response in string format + response, # type: Raw response from requests """ - self.status_code = response.status_code - self.text = response.text + self.status_code = status_code + self.text = text self.response = response def raise_for_status(self): diff --git a/tests/test_application.py b/tests/test_application.py index dd556fd0..5307869a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -77,21 +77,21 @@ def test_cache_empty_will_be_returned_as_None(self): def test_acquire_token_silent_will_suppress_error(self): error_response = '{"error": "invalid_grant", "suberror": "xyz"}' def tester(url, **kwargs): - return Response(400, error_response) + return Response(400, error_response, '') self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): error_response = '{"error": "invalid_grant", "error_description": "xyz"}' def tester(url, **kwargs): - return Response(400, error_response) + return Response(400, error_response, '') self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' def tester(url, **kwargs): - return Response(400, error_response) + return Response(400, error_response, '') result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("basic_action", result.get("classification")) @@ -99,7 +99,7 @@ def tester(url, **kwargs): def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' def tester(url, **kwargs): - return Response(400, error_response) + return Response(400, error_response, '') result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("", result.get("classification")) @@ -133,7 +133,7 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' def tester(url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Response(400, error_response) + return Response(400, error_response, '') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) self.assertNotEqual([], app.token_cache.find( @@ -154,7 +154,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") - return Response(200, '{}') + return Response(200, '{}', '') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) @@ -164,7 +164,7 @@ def tester(url, data=None, **kwargs): self.frt, data.get("refresh_token"), "Should attempt the FRT") return Response( 200, json.dumps(TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, foci="1", access_token="at"))) + uid=self.uid, utid=self.utid, foci="1", access_token="at")), '') app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( From cb70a37cb99f4ffbe3b0958cda0a2a9f689e013c Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 23 Mar 2020 13:49:32 -0700 Subject: [PATCH 14/33] Cleaning None values --- msal/oauth2cli/oauth2.py | 6 ++++- sample/httpx.py | 49 ++++++++++++++++++++++++++++++++++++++++ sample/httpx_client.py | 48 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 sample/httpx.py create mode 100644 sample/httpx_client.py diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 0fe46971..3f0ad8d4 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -123,7 +123,11 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails - # We don't have to clean up None values here, because requests lib will. + filtered = {k: v for k, v in _data.items() if v is not None} + _data.clear() + _data.update(filtered) + # We will have to clean up None values here, + # because we can have some libraries not supporting cleaning of None values. if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) diff --git a/sample/httpx.py b/sample/httpx.py new file mode 100644 index 00000000..83e95144 --- /dev/null +++ b/sample/httpx.py @@ -0,0 +1,49 @@ +import httpx + +from .oauth2cli import HttpClient, Response + + +class DefaultHttpClient(HttpClient): + """ + Default HTTP Client + """ + def __init__(self): + """ + Constructor for the DefaultHttpClient + + verify=True, # type: Union[str, True, False, None] + proxies=None, # type: Optional[dict] + """ + self.session = httpx.Client() + + def post(self, url, params=None, data=None, headers=None, **kwargs): + if params is None: + params = {} + params.update(**kwargs) + + response = self.session.post(url=url, params=params, headers=headers, data=data) + return Response(response.status_code, response.text, response) + + def get(self, url, params=None, headers=None, **kwargs): + if params is None: + params = {} + params.update(kwargs) + response = self.session.get(url=url, params=params, headers=headers) + return Response(response.status_code, response.text, response) + + +class Response(Response): + + def __init__(self, status_code, text, response): + """Constructor for DefaultResponseObject + status, # type: int + text, # type: str response in string format + response, # type: Raw response from requests + """ + self.status_code = status_code + self.text = text + self.response = response + + + def raise_for_status(self): + self.response.raise_for_status() diff --git a/sample/httpx_client.py b/sample/httpx_client.py new file mode 100644 index 00000000..38b33211 --- /dev/null +++ b/sample/httpx_client.py @@ -0,0 +1,48 @@ +import httpx +from httpx import * +from msal.oauth2cli import HttpClient, Response + + +class HttpxClient(HttpClient): + """ + Default HTTP Client + """ + def __init__(self): + """ + Constructor for the DefaultHttpClient + + verify=True, # type: Union[str, True, False, None] + proxies=None, # type: Optional[dict] + """ + + def post(self, url, params=None, data=None, headers=None, **kwargs): + if params is None: + params = {} + params.update(**kwargs) + + response = httpx.post(url=url, params=params, headers=headers, data=data) + return Response(response.status_code, response.text, response) + + def get(self, url, params=None, headers=None, **kwargs): + if params is None: + params = {} + params.update(kwargs) + response = httpx.get(url=url, params=params, headers=headers) + return Response(response.status_code, response.text, response) + + +class Response(Response): + + def __init__(self, status_code, text, response): + """Constructor for DefaultResponseObject + status, # type: int + text, # type: str response in string format + response, # type: Raw response from requests + """ + self.status_code = status_code + self.text = text + self.response = response + + + def raise_for_status(self): + self.response.raise_for_status() From 53520fd12fa4eb568c5bfa02f44fbefb484f998d Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Thu, 26 Mar 2020 00:21:29 -0700 Subject: [PATCH 15/33] PR review iteration --- msal/authority.py | 3 -- msal/oauth2cli/__init__.py | 2 +- msal/oauth2cli/default_http_client.py | 22 ++---------- msal/oauth2cli/oauth2.py | 33 +++++++++++------- sample/httpx.py | 49 --------------------------- sample/httpx_client.py | 48 -------------------------- tests/test_application.py | 18 +++++----- 7 files changed, 32 insertions(+), 143 deletions(-) delete mode 100644 sample/httpx.py delete mode 100644 sample/httpx_client.py diff --git a/msal/authority.py b/msal/authority.py index 6a844c0f..85d75cc7 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -1,7 +1,4 @@ import json - -from .oauth2cli.default_http_client import DefaultHttpClient - try: from urllib.parse import urlparse except ImportError: # Fall back to Python 2 diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index 1c8bdd1d..60a0aea7 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -3,5 +3,5 @@ from .oidc import Client from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. -from .http import HttpClient, Response +from .http import HttpClient from .default_http_client import DefaultHttpClient diff --git a/msal/oauth2cli/default_http_client.py b/msal/oauth2cli/default_http_client.py index 571e2604..a873b2fb 100644 --- a/msal/oauth2cli/default_http_client.py +++ b/msal/oauth2cli/default_http_client.py @@ -21,26 +21,8 @@ def __init__(self, verify=True, proxies=None): self.session.proxies = proxies def post(self, url, params=None, data=None, headers=None, **kwargs): - - response = self.session.post(url=url, params=params, headers=headers, data=data, **kwargs) - return Response(response.status_code, response.text, response) + return self.session.post(url=url, params=params, headers=headers, data=data, **kwargs) def get(self, url, params=None, headers=None, **kwargs): - response = self.session.get(url=url, params=params, headers=headers, **kwargs) - return Response(response.status_code, response.text, response) - - -class Response(Response): - - def __init__(self, status_code, text, response): - """Constructor for DefaultResponseObject - status, # type: int - text, # type: str response in string format - response, # type: Raw response from requests - """ - self.status_code = status_code - self.text = text - self.response = response + return self.session.get(url=url, params=params, headers=headers, **kwargs) - def raise_for_status(self): - self.response.raise_for_status() diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 3f0ad8d4..73e452fc 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -1,7 +1,7 @@ """This OAuth2 client implementation aims to be spec-compliant, and generic.""" # OAuth2 spec https://tools.ietf.org/html/rfc6749 -import json +import json try: from urllib.parse import urlencode, parse_qs except ImportError: @@ -12,7 +12,6 @@ import time import base64 import sys -from .default_http_client import DefaultHttpClient string_types = (str,) if sys.version_info[0] >= 3 else (basestring, ) @@ -132,6 +131,10 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) + _headers = {'Accept': 'application/json'} + _headers.update(self.default_headers) + _headers.update(headers or {}) + # Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1 # Clients in possession of a client password MAY use the HTTP Basic # authentication. @@ -139,22 +142,27 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 # the authorization server MAY support including the # client credentials in the request-body using the following # parameters: client_id, client_secret. - auth = None if self.client_secret and self.client_id: - auth = (self.client_id, self.client_secret) # for HTTP Basic Auth + _headers["Authorization"] = "Basic " + base64.b64encode( + "{}:{}".format(self.client_id, self.client_secret) + .encode("ascii")).decode("ascii") if "token_endpoint" not in self.configuration: raise ValueError("token_endpoint not found in configuration") - _headers = {'Accept': 'application/json'} - _headers.update(self.default_headers) - _headers.update(headers or {}) resp = (post or self.http_client.post)(self.configuration["token_endpoint"], - headers=_headers, params=params, data=_data, auth=auth, - timeout=timeout or self.timeout, + headers=_headers, params=params, data=_data, timeout=timeout or self.timeout, **kwargs) if resp.status_code >= 500: resp.raise_for_status() - return json.loads(resp.text) + try: + # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says + # even an error response will be a valid json structure, + # so we simply return it here, without needing to invent an exception. + return json.loads(resp.text) + except ValueError: + self.logger.exception( + "Token response is not in json format: %s", resp.text) + raise def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs): # type: (str, Union[str, list, set, tuple]) -> dict @@ -210,9 +218,8 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): if not self.configuration.get(DAE): raise ValueError("You need to provide device authorization endpoint") resp = self.http_client.post(self.configuration[DAE], - data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - timeout=timeout or self.timeout, - **kwargs) + data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, + timeout=timeout or self.timeout, **kwargs) flow = json.loads(resp.text) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) diff --git a/sample/httpx.py b/sample/httpx.py deleted file mode 100644 index 83e95144..00000000 --- a/sample/httpx.py +++ /dev/null @@ -1,49 +0,0 @@ -import httpx - -from .oauth2cli import HttpClient, Response - - -class DefaultHttpClient(HttpClient): - """ - Default HTTP Client - """ - def __init__(self): - """ - Constructor for the DefaultHttpClient - - verify=True, # type: Union[str, True, False, None] - proxies=None, # type: Optional[dict] - """ - self.session = httpx.Client() - - def post(self, url, params=None, data=None, headers=None, **kwargs): - if params is None: - params = {} - params.update(**kwargs) - - response = self.session.post(url=url, params=params, headers=headers, data=data) - return Response(response.status_code, response.text, response) - - def get(self, url, params=None, headers=None, **kwargs): - if params is None: - params = {} - params.update(kwargs) - response = self.session.get(url=url, params=params, headers=headers) - return Response(response.status_code, response.text, response) - - -class Response(Response): - - def __init__(self, status_code, text, response): - """Constructor for DefaultResponseObject - status, # type: int - text, # type: str response in string format - response, # type: Raw response from requests - """ - self.status_code = status_code - self.text = text - self.response = response - - - def raise_for_status(self): - self.response.raise_for_status() diff --git a/sample/httpx_client.py b/sample/httpx_client.py deleted file mode 100644 index 38b33211..00000000 --- a/sample/httpx_client.py +++ /dev/null @@ -1,48 +0,0 @@ -import httpx -from httpx import * -from msal.oauth2cli import HttpClient, Response - - -class HttpxClient(HttpClient): - """ - Default HTTP Client - """ - def __init__(self): - """ - Constructor for the DefaultHttpClient - - verify=True, # type: Union[str, True, False, None] - proxies=None, # type: Optional[dict] - """ - - def post(self, url, params=None, data=None, headers=None, **kwargs): - if params is None: - params = {} - params.update(**kwargs) - - response = httpx.post(url=url, params=params, headers=headers, data=data) - return Response(response.status_code, response.text, response) - - def get(self, url, params=None, headers=None, **kwargs): - if params is None: - params = {} - params.update(kwargs) - response = httpx.get(url=url, params=params, headers=headers) - return Response(response.status_code, response.text, response) - - -class Response(Response): - - def __init__(self, status_code, text, response): - """Constructor for DefaultResponseObject - status, # type: int - text, # type: str response in string format - response, # type: Raw response from requests - """ - self.status_code = status_code - self.text = text - self.response = response - - - def raise_for_status(self): - self.response.raise_for_status() diff --git a/tests/test_application.py b/tests/test_application.py index 5307869a..66de3ce8 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -77,21 +77,21 @@ def test_cache_empty_will_be_returned_as_None(self): def test_acquire_token_silent_will_suppress_error(self): error_response = '{"error": "invalid_grant", "suberror": "xyz"}' def tester(url, **kwargs): - return Response(400, error_response, '') + return Mock(status_code=400, text=error_response) self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): error_response = '{"error": "invalid_grant", "error_description": "xyz"}' def tester(url, **kwargs): - return Response(400, error_response, '') + return Mock(status_code=400, text=error_response) self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' def tester(url, **kwargs): - return Response(400, error_response, '') + return Mock(status_code=400, text=error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("basic_action", result.get("classification")) @@ -99,7 +99,7 @@ def tester(url, **kwargs): def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' def tester(url, **kwargs): - return Response(400, error_response, '') + return Mock(status_code=400, text=error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("", result.get("classification")) @@ -133,7 +133,7 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' def tester(url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Response(400, error_response, '') + return Mock(status_code=400, text=error_response) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) self.assertNotEqual([], app.token_cache.find( @@ -154,7 +154,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") - return Response(200, '{}', '') + return Mock(status_code=200, text='{}') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) @@ -162,9 +162,9 @@ def test_unknown_family_app_will_attempt_frt_and_join_family(self): def tester(url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Response( - 200, json.dumps(TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, foci="1", access_token="at")), '') + return Mock( + status_code=200, text=json.dumps(TokenCacheTestCase.build_response( + uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( From 293b081810f9dc52516e2e06e446aa2e670a7129 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 30 Mar 2020 12:59:31 -0700 Subject: [PATCH 16/33] Removing default http client --- msal/application.py | 10 ++++++++-- msal/oauth2cli/__init__.py | 1 - msal/oauth2cli/default_http_client.py | 28 --------------------------- tests/test_application.py | 10 ++++------ tests/test_authority.py | 8 ++++---- tests/test_client.py | 5 ++--- 6 files changed, 18 insertions(+), 44 deletions(-) delete mode 100644 msal/oauth2cli/default_http_client.py diff --git a/msal/application.py b/msal/application.py index 912345ec..a21bfae0 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,4 +1,5 @@ import json +import requests import time try: # Python 2 from urlparse import urljoin @@ -9,7 +10,7 @@ import warnings import uuid -from .oauth2cli import Client, JwtAssertionCreator, DefaultHttpClient +from .oauth2cli import Client, JwtAssertionCreator from .authority import Authority from .mex import send_request as mex_send_request from .wstrust_request import send_request as wst_send_request @@ -162,7 +163,12 @@ def __init__( self.client_id = client_id self.client_credential = client_credential self.client_claims = client_claims - self.http_client = http_client if http_client else DefaultHttpClient(verify=verify, proxies=proxies) + if http_client: + self.http_client = http_client + else: + self.http_client = requests.Session() + self.http_client.verify = verify + self.http_client.proxies = proxies self.timeout = timeout self.app_name = app_name self.app_version = app_version diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index 60a0aea7..d923eb68 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -4,4 +4,3 @@ from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. from .http import HttpClient -from .default_http_client import DefaultHttpClient diff --git a/msal/oauth2cli/default_http_client.py b/msal/oauth2cli/default_http_client.py deleted file mode 100644 index a873b2fb..00000000 --- a/msal/oauth2cli/default_http_client.py +++ /dev/null @@ -1,28 +0,0 @@ -import requests - -from .http import HttpClient, Response - - -class DefaultHttpClient(HttpClient): - """ - Default HTTP Client - """ - def __init__(self, verify=True, proxies=None): - """ - Constructor for the DefaultHttpClient - - verify=True, # type: Union[str, True, False, None] - proxies=None, # type: Optional[dict] - """ - self.session = requests.Session() - if verify: - self.session.verify = verify - if proxies: - self.session.proxies = proxies - - def post(self, url, params=None, data=None, headers=None, **kwargs): - return self.session.post(url=url, params=params, headers=headers, data=data, **kwargs) - - def get(self, url, params=None, headers=None, **kwargs): - return self.session.get(url=url, params=params, headers=headers, **kwargs) - diff --git a/tests/test_application.py b/tests/test_application.py index 66de3ce8..0e105c24 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -7,11 +7,10 @@ from mock import * # Need an external mock package from msal.application import * -from msal.oauth2cli.default_http_client import Response import msal from tests import unittest from tests.test_token_cache import TokenCacheTestCase - +import requests logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -47,9 +46,8 @@ def test_extract_multiple_tag_enclosed_certs(self): class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): def setUp(self): - self.http_client = DefaultHttpClient() self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url, http_client= self.http_client) + self.authority = msal.authority.Authority(self.authority_url, http_client= requests.Session()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -66,7 +64,7 @@ def setUp(self): uid=self.uid, utid=self.utid, refresh_token=self.rt), }) # The add(...) helper populates correct home_account_id for future searching self.app = ClientApplication( - self.client_id, authority=self.authority_url, token_cache=self.cache, http_client=self.http_client) + self.client_id, authority=self.authority_url, token_cache=self.cache) def test_cache_empty_will_be_returned_as_None(self): self.assertEqual( @@ -108,7 +106,7 @@ class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url, DefaultHttpClient()) + self.authority = msal.authority.Authority(self.authority_url, requests.Session()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" diff --git a/tests/test_authority.py b/tests/test_authority.py index 9767dc34..79e7bc0b 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -2,7 +2,7 @@ from msal.authority import * from tests import unittest -from msal.oauth2cli.default_http_client import DefaultHttpClient +import requests @unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") @@ -11,7 +11,7 @@ class TestAuthority(unittest.TestCase): def test_wellknown_host_and_tenant(self): # Assert all well known authority hosts are using their own "common" tenant for host in WELL_KNOWN_AUTHORITY_HOSTS: - a = Authority('https://{}/common'.format(host), DefaultHttpClient()) + a = Authority('https://{}/common'.format(host), requests.Session()) self.assertEqual( a.authorization_endpoint, 'https://%s/common/oauth2/v2.0/authorize' % host) @@ -24,14 +24,14 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): # It is probably not a strict API contract. I simply mention it here. less_known = 'login.windows.net' # less.known.host/ v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) - a = Authority('https://{}/common'.format(less_known), DefaultHttpClient()) + a = Authority('https://{}/common'.format(less_known), requests.Session()) self.assertEqual(a.token_endpoint, v1_token_endpoint) self.assertNotIn('v2.0', a.token_endpoint) def test_unknown_host_wont_pass_instance_discovery(self): _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack with _assert(ValueError, "invalid_instance"): - Authority('https://example.com/tenant_doesnt_matter_in_this_case', DefaultHttpClient()) + Authority('https://example.com/tenant_doesnt_matter_in_this_case', requests.Session()) def test_invalid_host_skipping_validation_can_be_turned_off(self): try: diff --git a/tests/test_client.py b/tests/test_client.py index 6faf4969..ac7833b6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,6 @@ from msal.oauth2cli import Client, JwtSigner from msal.oauth2cli.authcode import obtain_auth_code -from msal.oauth2cli.default_http_client import DefaultHttpClient from tests import unittest, Oauth2TestCase @@ -100,13 +99,13 @@ def setUpClass(cls): issuer=CONFIG["client_id"], ), client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT, - http_client=DefaultHttpClient() + http_client=requests.Session() ) else: cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], client_secret=CONFIG.get('client_secret'), - http_client=DefaultHttpClient()) + http_client=requests.Session()) @unittest.skipIf( "token_endpoint" not in CONFIG.get("openid_configuration", {}), From fcac05eda5355ba7cd6a21ab197172c3ba4a0bb4 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 30 Mar 2020 13:04:06 -0700 Subject: [PATCH 17/33] cleaning up --- msal/oauth2cli/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index d923eb68..63db107d 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -3,4 +3,3 @@ from .oidc import Client from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. -from .http import HttpClient From 229ad26a2ee4634f40bedc427199b233bc860856 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 30 Mar 2020 13:08:13 -0700 Subject: [PATCH 18/33] Adding deleted single empty line back --- msal/oauth2cli/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/msal/oauth2cli/__init__.py b/msal/oauth2cli/__init__.py index 63db107d..b8941361 100644 --- a/msal/oauth2cli/__init__.py +++ b/msal/oauth2cli/__init__.py @@ -3,3 +3,4 @@ from .oidc import Client from .assertion import JwtAssertionCreator from .assertion import JwtSigner # Obsolete. For backward compatibility. + From 340210ff588d97b6bbed912a7bbf64c6d049f133 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 30 Mar 2020 13:11:48 -0700 Subject: [PATCH 19/33] Updating filtering of non values from dictionary --- msal/oauth2cli/oauth2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 73e452fc..f7545f15 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -122,9 +122,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails - filtered = {k: v for k, v in _data.items() if v is not None} - _data.clear() - _data.update(filtered) + _data = {k: v for k, v in _data.items() if v} # We will have to clean up None values here, # because we can have some libraries not supporting cleaning of None values. @@ -153,7 +151,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 headers=_headers, params=params, data=_data, timeout=timeout or self.timeout, **kwargs) if resp.status_code >= 500: - resp.raise_for_status() + resp.raise_for_status() # TODO: Will probably retry here try: # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says # even an error response will be a valid json structure, From d993950cf0b7b66ff530902fdfc60e30dbdd97c9 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 13 Apr 2020 16:55:59 -0700 Subject: [PATCH 20/33] Capturing editorial changes --- msal/application.py | 14 +++++++++----- msal/mex.py | 6 +++--- msal/wstrust_request.py | 3 +-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/msal/application.py b/msal/application.py index 45d27df3..b6bbe3ef 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,5 +1,4 @@ import json -import requests import time try: # Python 2 from urlparse import urljoin @@ -10,6 +9,8 @@ import warnings import uuid +import requests + from .oauth2cli import Client, JwtAssertionCreator from .authority import Authority from .mex import send_request as mex_send_request @@ -90,7 +91,8 @@ class ClientApplication(object): def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, - token_cache=None, http_client=None, + token_cache=None, + http_client=None, verify=True, proxies=None, timeout=None, client_claims=None, app_name=None, app_version=None): """Create an instance of application. @@ -140,7 +142,7 @@ def __init__( By default, an in-memory cache will be created and used. :param http_client: (optional) Your implementation of abstract class HttpClient - Defaults to default http client implementation which uses requests + Defaults to a requests session instance :param verify: (optional) It will be passed to the `verify parameter in the underlying requests library @@ -385,7 +387,8 @@ def _find_msal_accounts(self, environment): def _get_authority_aliases(self, instance): if not self.authority_groups: - resp = self.http_client.get("https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", + resp = self.http_client.get( + "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}, timeout=self.timeout) resp.raise_for_status() @@ -769,7 +772,8 @@ def _acquire_token_by_username_password_federated( wstrust_endpoint = {} if user_realm_result.get("federation_metadata_url"): wstrust_endpoint = mex_send_request( - user_realm_result["federation_metadata_url"], http_client=self.http_client) + user_realm_result["federation_metadata_url"], + http_client=self.http_client) if wstrust_endpoint is None: raise ValueError("Unable to find wstrust endpoint from MEX. " "This typically happens when attempting MSA accounts. " diff --git a/msal/mex.py b/msal/mex.py index 25971a5d..684d50ed 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -41,9 +41,9 @@ def _xpath_of_root(route_to_leaf): def send_request(mex_endpoint, http_client, **kwargs): - resp = http_client.get(mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, - **kwargs) - mex_document = resp.text + mex_document = http_client.get( + mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, + **kwargs).text return Mex(mex_document).get_wstrust_username_password_endpoint() diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 2b0fcd37..b2898f76 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -48,8 +48,7 @@ def send_request( "Unsupported soap action: %s" % soap_action) data = _build_rst( username, password, cloud_audience_urn, endpoint_address, soap_action) - resp = http_client.post(endpoint_address, data=data, - headers={ + resp = http_client.post(endpoint_address, data=data, headers={ 'Content-type':'application/soap+xml; charset=utf-8', 'SOAPAction': soap_action, }, **kwargs) From e5ebd2867a8e2a31703f218cd2bebb63ba4c986b Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Wed, 15 Apr 2020 22:29:08 -0700 Subject: [PATCH 21/33] Review changes 1 --- msal/application.py | 24 ++++++++------- msal/authority.py | 63 +++++++++++++++++++++------------------- msal/oauth2cli/oauth2.py | 14 ++++----- tests/http_client.py | 19 ++++++++++++ tests/test_authority.py | 26 +++++++++++++++-- tests/test_e2e.py | 29 +++++++++++------- 6 files changed, 112 insertions(+), 63 deletions(-) create mode 100644 tests/http_client.py diff --git a/msal/application.py b/msal/application.py index b6bbe3ef..6416007e 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,3 +1,4 @@ +import functools import json import time try: # Python 2 @@ -146,14 +147,17 @@ def __init__( :param verify: (optional) It will be passed to the `verify parameter in the underlying requests library + This does not apply if you have chosen to pass your own Http client `_ :param proxies: (optional) It will be passed to the `proxies parameter in the underlying requests library + This does not apply if you have chosen to pass your own Http client `_ :param timeout: (optional) It will be passed to the `timeout parameter in the underlying requests library + This does not apply if you have chosen to pass your own Http client `_ :param app_name: (optional) You can provide your application name for Microsoft telemetry purposes. @@ -171,12 +175,15 @@ def __init__( self.http_client = requests.Session() self.http_client.verify = verify self.http_client.proxies = proxies - self.timeout = timeout + # Requests, does not support session - wide timeout + # But you can patch that (https://github.com/psf/requests/issues/3341): + self.http_client.request = functools.partial( + self.http_client.request, timeout=timeout) self.app_name = app_name self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - http_client=self.http_client, validate_authority=validate_authority, timeout=timeout) + http_client=self.http_client, validate_authority=validate_authority) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) @@ -226,8 +233,7 @@ def _build_client(self, client_credential, authority): client_assertion_type=client_assertion_type, on_obtaining_tokens=self.token_cache.add, on_removing_rt=self.token_cache.remove_rt, - on_updating_rt=self.token_cache.update_rt, - timeout=self.timeout) + on_updating_rt=self.token_cache.update_rt) def get_authorization_request_url( self, @@ -277,8 +283,7 @@ def get_authorization_request_url( # The previous implementation is, it will use self.authority by default. # Multi-tenant app can use new authority on demand the_authority = Authority( - authority, http_client=self.http_client, timeout=self.timeout - ) if authority else self.authority + authority, http_client=self.http_client) if authority else self.authority client = Client( {"authorization_endpoint": the_authority.authorization_endpoint}, @@ -389,8 +394,7 @@ def _get_authority_aliases(self, instance): if not self.authority_groups: resp = self.http_client.get( "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", - headers={'Accept': 'application/json'}, - timeout=self.timeout) + headers={'Accept': 'application/json'}) resp.raise_for_status() resp = json.loads(resp.text) self.authority_groups = [ @@ -513,7 +517,6 @@ def acquire_token_silent_with_error( warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( # authority, http_client=self.http_client, - # timeout=self.timeout # ) if authority else self.authority result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, self.authority, force_refresh=force_refresh, @@ -525,8 +528,7 @@ def acquire_token_silent_with_error( for alias in self._get_authority_aliases(self.authority.instance): the_authority = Authority( "https://" + alias + "/" + self.authority.tenant, - http_client=self.http_client, validate_authority=False, - timeout=self.timeout) + http_client=self.http_client, validate_authority=False) result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, the_authority, force_refresh=force_refresh, correlation_id=correlation_id, diff --git a/msal/authority.py b/msal/authority.py index 85d75cc7..b7452cca 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -33,9 +33,7 @@ class Authority(object): """ _domains_without_user_realm_discovery = set([]) - def __init__(self, authority_url, http_client, validate_authority=True, - timeout=None - ): + def __init__(self, authority_url, http_client, validate_authority=True): """Creates an authority instance, and also validates it. :param validate_authority: @@ -45,17 +43,16 @@ def __init__(self, authority_url, http_client, validate_authority=True, performed. """ self.http_client = http_client - self.timeout = timeout authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') self.is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( len(parts) == 3 and parts[2].lower().startswith("b2c_")) if (tenant != "adfs" and (not self.is_b2c) and validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS): - payload = self.instance_discovery( + payload = instance_discovery( "https://{}{}/oauth2/v2.0/authorize".format( self.instance, authority.path), - timeout=timeout) + self.http_client) if payload.get("error") == "invalid_instance": raise ValueError( "invalid_instance: " @@ -72,9 +69,9 @@ def __init__(self, authority_url, http_client, validate_authority=True, authority.path, # In B2C scenario, it is "/tenant/policy" "" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint )) - openid_config = self.tenant_discovery( + openid_config = tenant_discovery( tenant_discovery_endpoint, - timeout=timeout) + self.http_client) logger.debug("openid_config = %s", openid_config) self.authorization_endpoint = openid_config['authorization_endpoint'] self.token_endpoint = openid_config['token_endpoint'] @@ -85,28 +82,17 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): # It will typically return a dict containing "ver", "account_type", # "federation_protocol", "cloud_audience_urn", # "federation_metadata_url", "federation_active_auth_url", etc. - resp = response or self.http_client.get("https://{netloc}/common/userrealm/{username}?api-version=1.0".format( - netloc=self.instance, username=username), - headers={'Accept':'application/json', 'client-request-id': correlation_id}, - timeout=self.timeout) - return json.loads(resp.text) - - def instance_discovery(self, url, **kwargs): - resp = self.http_client.get('https://{}/common/discovery/instance'.format( - WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too - # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 - # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 - ), params={'authorization_endpoint': url, 'api-version': '1.0'}, - **kwargs) - return json.loads(resp.text) - - def tenant_discovery(self, tenant_discovery_endpoint, **kwargs): - # Returns Openid Configuration - resp = self.http_client.get(tenant_discovery_endpoint, **kwargs) - payload = json.loads(resp.text) - if 'authorization_endpoint' in payload and 'token_endpoint' in payload: - return payload - raise MsalServiceError(status_code=resp.status_code, **payload) + if self.instance not in self.__class__._domains_without_user_realm_discovery: + resp = response or self.http_client.get( + "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( + netloc=self.instance, username=username), + headers={'Accept': 'application/json', + 'client-request-id': correlation_id},) + if resp.status_code != 404: + resp.raise_for_status() + return json.loads(resp.text) + self.__class__._domains_without_user_realm_discovery.add(self.instance) + return {} # This can guide the caller to fall back normal ROPC flow def canonicalize(authority_url): @@ -121,3 +107,20 @@ def canonicalize(authority_url): "or https://.b2clogin.com/.onmicrosoft.com/policy" % authority_url) return authority, authority.hostname, parts[1] + +def instance_discovery(url, http_client, **kwargs): + resp = http_client.get('https://{}/common/discovery/instance'.format( + WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too + # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 + # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 + ), params={'authorization_endpoint': url, 'api-version': '1.0'}, + **kwargs) + return json.loads(resp.text) + +def tenant_discovery(tenant_discovery_endpoint, http_client, **kwargs): + # Returns Openid Configuration + resp = http_client.get(tenant_discovery_endpoint, **kwargs) + payload = json.loads(resp.text) + if 'authorization_endpoint' in payload and 'token_endpoint' in payload: + return payload + raise MsalServiceError(status_code=resp.status_code, **payload) diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index f7545f15..7536e0a2 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -39,7 +39,6 @@ def __init__( client_assertion_type=None, # type: Optional[str] default_headers=None, # type: Optional[dict] default_body=None, # type: Optional[dict] - timeout=None, # type: Union[tuple, float, None] ): """Initialize a client object to talk all the OAuth2 grants to the server. @@ -84,7 +83,6 @@ def __init__( self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) self.http_client = http_client - self.timeout = timeout def _build_auth_request_params(self, response_type, **kwargs): # response_type is a string defined in @@ -104,7 +102,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 params=None, # a dict to be sent as query string to the endpoint data=None, # All relevant data, which will go into the http body headers=None, # a dict to be sent as request headers - timeout=None, post=None, # A callable to replace requests.post(), for testing. # Such as: lambda url, **kwargs: # Mock(status_code=200, json=Mock(return_value={})) @@ -148,7 +145,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 if "token_endpoint" not in self.configuration: raise ValueError("token_endpoint not found in configuration") resp = (post or self.http_client.post)(self.configuration["token_endpoint"], - headers=_headers, params=params, data=_data, timeout=timeout or self.timeout, + headers=_headers, params=params, data=_data, **kwargs) if resp.status_code >= 500: resp.raise_for_status() # TODO: Will probably retry here @@ -197,7 +194,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion} - def initiate_device_flow(self, scope=None, timeout=None, **kwargs): + def initiate_device_flow(self, scope=None, **kwargs): # type: (list, **dict) -> dict # The naming of this method is following the wording of this specs # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1 @@ -217,7 +214,7 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): raise ValueError("You need to provide device authorization endpoint") resp = self.http_client.post(self.configuration[DAE], data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - timeout=timeout or self.timeout, **kwargs) + **kwargs) flow = json.loads(resp.text) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) @@ -378,12 +375,13 @@ class initialization. return self._obtain_token("client_credentials", data=data, **kwargs) def __init__(self, - server_configuration, client_id, + server_configuration, client_id, http_client, on_obtaining_tokens=lambda event: None, # event is defined in _obtain_token(...) on_removing_rt=lambda token_item: None, on_updating_rt=lambda token_item, new_rt: None, **kwargs): - super(Client, self).__init__(server_configuration, client_id, **kwargs) + super(Client, self).__init__( + server_configuration, client_id, http_client, **kwargs) self.on_obtaining_tokens = on_obtaining_tokens self.on_removing_rt = on_removing_rt self.on_updating_rt = on_updating_rt diff --git a/tests/http_client.py b/tests/http_client.py new file mode 100644 index 00000000..21890c21 --- /dev/null +++ b/tests/http_client.py @@ -0,0 +1,19 @@ +import requests + + +class MinimalHttpClient: + + def __init__(self, verify=True, proxies=None, timeout=None): + self.session = requests.Session() + self.session.verify = verify + self.session.proxies = proxies + self.timeout = timeout + + def post(self, url, params=None, data=None, headers=None, **kwargs): + return self.session.post( + url, params=params, data=data, headers=headers, + timeout=self.timeout) + + def get(self, url, params=None, headers=None, **kwargs): + return self.session.get( + url, params=params, headers=headers, timeout=self.timeout) diff --git a/tests/test_authority.py b/tests/test_authority.py index 79e7bc0b..561aaa5e 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -4,6 +4,8 @@ from tests import unittest import requests +from tests.http_client import MinimalHttpClient + @unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") class TestAuthority(unittest.TestCase): @@ -11,7 +13,7 @@ class TestAuthority(unittest.TestCase): def test_wellknown_host_and_tenant(self): # Assert all well known authority hosts are using their own "common" tenant for host in WELL_KNOWN_AUTHORITY_HOSTS: - a = Authority('https://{}/common'.format(host), requests.Session()) + a = Authority('https://{}/common'.format(host), MinimalHttpClient()) self.assertEqual( a.authorization_endpoint, 'https://%s/common/oauth2/v2.0/authorize' % host) @@ -24,14 +26,14 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): # It is probably not a strict API contract. I simply mention it here. less_known = 'login.windows.net' # less.known.host/ v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) - a = Authority('https://{}/common'.format(less_known), requests.Session()) + a = Authority('https://{}/common'.format(less_known), MinimalHttpClient()) self.assertEqual(a.token_endpoint, v1_token_endpoint) self.assertNotIn('v2.0', a.token_endpoint) def test_unknown_host_wont_pass_instance_discovery(self): _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack with _assert(ValueError, "invalid_instance"): - Authority('https://example.com/tenant_doesnt_matter_in_this_case', requests.Session()) + Authority('https://example.com/tenant_doesnt_matter_in_this_case', MinimalHttpClient()) def test_invalid_host_skipping_validation_can_be_turned_off(self): try: @@ -73,3 +75,21 @@ def test_canonicalize_rejects_tenantless_host_with_trailing_slash(self): canonicalize("https://no.tenant.example.com/") +@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") +class TestAuthorityInternalHelperUserRealmDiscovery(unittest.TestCase): + def test_memorize(self): + # We use a real authority so the constructor can finish tenant discovery + authority = "https://login.microsoftonline.com/common" + self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) + a = Authority(authority, validate_authority=False) + + # We now pretend this authority supports no User Realm Discovery + class MockResponse(object): + status_code = 404 + a.user_realm_discovery("john.doe@example.com", response=MockResponse()) + self.assertIn( + "login.microsoftonline.com", + Authority._domains_without_user_realm_discovery, + "user_realm_discovery() should memorize domains not supporting URD") + a.user_realm_discovery("john.doe@example.com", + response="This would cause exception if memorization did not work") \ No newline at end of file diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 0d74eb1d..28383cd6 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -7,7 +7,7 @@ import requests import msal - +from tests.http_client import MinimalHttpClient logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -21,7 +21,8 @@ def _get_app_and_auth_code( scopes=["https://graph.microsoft.com/.default"], # Microsoft Graph **kwargs): from msal.oauth2cli.authcode import obtain_auth_code - app = msal.ClientApplication(client_id, client_secret, authority=authority) + app = msal.ClientApplication( + client_id, client_secret, authority=authority, http_client=MinimalHttpClient()) redirect_uri = "http://localhost:%d" % port ac = obtain_auth_code(port, auth_uri=app.get_authorization_request_url( scopes, redirect_uri=redirect_uri, **kwargs)) @@ -92,7 +93,8 @@ def _test_username_password(self, authority=None, client_id=None, username=None, password=None, scope=None, **ignored): assert authority and client_id and username and password and scope - self.app = msal.PublicClientApplication(client_id, authority=authority) + self.app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) result = self.app.acquire_token_by_username_password( username, password, scopes=scope) self.assertLoosely(result) @@ -106,7 +108,7 @@ def _test_device_flow( self, client_id=None, authority=None, scope=None, **ignored): assert client_id and authority and scope self.app = msal.PublicClientApplication( - client_id, authority=authority) + client_id, authority=authority, http_client=MinimalHttpClient()) flow = self.app.initiate_device_flow(scopes=scope) assert "user_code" in flow, "DF does not seem to be provisioned: %s".format( json.dumps(flow, indent=4)) @@ -225,13 +227,13 @@ def test_ssh_cert(self): self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert") self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token']) - def test_client_secret(self): self.skipUnlessWithConfig(["client_id", "client_secret"]) self.app = msal.ConfidentialClientApplication( self.config["client_id"], client_credential=self.config.get("client_secret"), - authority=self.config.get("authority")) + authority=self.config.get("authority"), + http_client=MinimalHttpClient()) scope = self.config.get("scope", []) result = self.app.acquire_token_for_client(scope) self.assertIn('access_token', result) @@ -245,7 +247,8 @@ def test_client_certificate(self): private_key = f.read() # Should be in PEM format self.app = msal.ConfidentialClientApplication( self.config['client_id'], - {"private_key": private_key, "thumbprint": client_cert["thumbprint"]}) + {"private_key": private_key, "thumbprint": client_cert["thumbprint"]}, + http_client=MinimalHttpClient()) scope = self.config.get("scope", []) result = self.app.acquire_token_for_client(scope) self.assertIn('access_token', result) @@ -267,7 +270,8 @@ def test_subject_name_issuer_authentication(self): "private_key": private_key, "thumbprint": self.config["thumbprint"], "public_certificate": public_certificate, - }) + }, + http_client=MinimalHttpClient()) scope = self.config.get("scope", []) result = self.app.acquire_token_for_client(scope) self.assertIn('access_token', result) @@ -311,7 +315,7 @@ def get_lab_app( return msal.ConfidentialClientApplication(client_id, client_secret, authority="https://login.microsoftonline.com/" "72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID - ) + http_client=MinimalHttpClient()) def get_session(lab_app, scopes): # BTW, this infrastructure tests the confidential client flow logger.info("Creating session") @@ -398,7 +402,8 @@ def _test_acquire_token_by_auth_code( def _test_acquire_token_obo(self, config_pca, config_cca): # 1. An app obtains a token representing a user, for our mid-tier service pca = msal.PublicClientApplication( - config_pca["client_id"], authority=config_pca["authority"]) + config_pca["client_id"], authority=config_pca["authority"], + http_client=MinimalHttpClient()) pca_result = pca.acquire_token_by_username_password( config_pca["username"], config_pca["password"], @@ -413,6 +418,7 @@ def _test_acquire_token_obo(self, config_pca, config_cca): config_cca["client_id"], client_credential=config_cca["client_secret"], authority=config_cca["authority"], + http_client=MinimalHttpClient(), # token_cache= ..., # Default token cache is all-tokens-store-in-memory. # That's fine if OBO app uses short-lived msal instance per session. # Otherwise, the OBO app need to implement a one-cache-per-user setup. @@ -439,7 +445,8 @@ def _test_acquire_token_by_client_secret( **ignored): assert client_id and client_secret and authority and scope app = msal.ConfidentialClientApplication( - client_id, client_credential=client_secret, authority=authority) + client_id, client_credential=client_secret, authority=authority, + http_client=MinimalHttpClient()) result = app.acquire_token_for_client(scope) self.assertIsNotNone(result.get("access_token"), "Got %s instead" % result) From 8901269408c45ec57d72eabfa0d476c7c1e5394f Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Wed, 15 Apr 2020 22:34:31 -0700 Subject: [PATCH 22/33] Fixing broken test --- tests/test_authority.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_authority.py b/tests/test_authority.py index 561aaa5e..d8660987 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -81,7 +81,7 @@ def test_memorize(self): # We use a real authority so the constructor can finish tenant discovery authority = "https://login.microsoftonline.com/common" self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) - a = Authority(authority, validate_authority=False) + a = Authority(authority, MinimalHttpClient(), validate_authority=False) # We now pretend this authority supports no User Realm Discovery class MockResponse(object): From 7774be83e160fff78e27df89710c31de68f38201 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Wed, 15 Apr 2020 23:34:30 -0700 Subject: [PATCH 23/33] Cleaning up --- msal/application.py | 11 +++++++---- msal/authority.py | 4 ++-- tests/test_application.py | 9 ++++++--- tests/test_authority.py | 20 ++++++++++++-------- tests/test_client.py | 5 +++-- 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/msal/application.py b/msal/application.py index 6416007e..809cf0d4 100644 --- a/msal/application.py +++ b/msal/application.py @@ -19,6 +19,7 @@ from .wstrust_response import * from .token_cache import TokenCache + # The __init__.py will import this. Not the other way around. __version__ = "1.2.0" @@ -283,11 +284,12 @@ def get_authorization_request_url( # The previous implementation is, it will use self.authority by default. # Multi-tenant app can use new authority on demand the_authority = Authority( - authority, http_client=self.http_client) if authority else self.authority + authority, http_client=self.http_client + ) if authority else self.authority client = Client( {"authorization_endpoint": the_authority.authorization_endpoint}, - self.client_id, self.http_client) + self.client_id, http_client=self.http_client) return client.build_auth_request_uri( response_type=response_type, redirect_uri=redirect_uri, state=state, login_hint=login_hint, @@ -516,7 +518,8 @@ def acquire_token_silent_with_error( if authority: warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( - # authority, http_client=self.http_client, + # authority, + # http_client=self.http_client, # ) if authority else self.authority result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, self.authority, force_refresh=force_refresh, @@ -757,7 +760,7 @@ def acquire_token_by_username_password( CLIENT_CURRENT_TELEMETRY: _build_current_telemetry_request_header( self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID), } - if not self.authority.is_adfs and not self.authority.is_b2c: + if not self.authority.is_adfs: user_realm_result = self.authority.user_realm_discovery( username, correlation_id=headers[CLIENT_REQUEST_ID]) if user_realm_result.get("account_type") == "Federated": diff --git a/msal/authority.py b/msal/authority.py index b7452cca..c4ec979e 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -45,9 +45,9 @@ def __init__(self, authority_url, http_client, validate_authority=True): self.http_client = http_client authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') - self.is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( + is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( len(parts) == 3 and parts[2].lower().startswith("b2c_")) - if (tenant != "adfs" and (not self.is_b2c) and validate_authority + if (tenant != "adfs" and (not is_b2c) and validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS): payload = instance_discovery( "https://{}{}/oauth2/v2.0/authorize".format( diff --git a/tests/test_application.py b/tests/test_application.py index 0e105c24..95e7c0cf 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -10,7 +10,8 @@ import msal from tests import unittest from tests.test_token_cache import TokenCacheTestCase -import requests +from tests.http_client import MinimalHttpClient + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.DEBUG) @@ -47,7 +48,8 @@ class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url, http_client= requests.Session()) + self.authority = msal.authority.Authority( + self.authority_url, http_client=MinimalHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -106,7 +108,8 @@ class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url, requests.Session()) + self.authority = msal.authority.Authority( + self.authority_url, http_client=MinimalHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" diff --git a/tests/test_authority.py b/tests/test_authority.py index d8660987..3a2d3f74 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -2,8 +2,6 @@ from msal.authority import * from tests import unittest -import requests - from tests.http_client import MinimalHttpClient @@ -13,7 +11,8 @@ class TestAuthority(unittest.TestCase): def test_wellknown_host_and_tenant(self): # Assert all well known authority hosts are using their own "common" tenant for host in WELL_KNOWN_AUTHORITY_HOSTS: - a = Authority('https://{}/common'.format(host), MinimalHttpClient()) + a = Authority( + 'https://{}/common'.format(host), http_client=MinimalHttpClient()) self.assertEqual( a.authorization_endpoint, 'https://%s/common/oauth2/v2.0/authorize' % host) @@ -26,18 +25,22 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): # It is probably not a strict API contract. I simply mention it here. less_known = 'login.windows.net' # less.known.host/ v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) - a = Authority('https://{}/common'.format(less_known), MinimalHttpClient()) + a = Authority( + 'https://{}/common'.format(less_known), http_client=MinimalHttpClient()) self.assertEqual(a.token_endpoint, v1_token_endpoint) self.assertNotIn('v2.0', a.token_endpoint) def test_unknown_host_wont_pass_instance_discovery(self): _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack with _assert(ValueError, "invalid_instance"): - Authority('https://example.com/tenant_doesnt_matter_in_this_case', MinimalHttpClient()) + Authority('https://example.com/tenant_doesnt_matter_in_this_case', + http_client=MinimalHttpClient()) def test_invalid_host_skipping_validation_can_be_turned_off(self): try: - Authority('https://example.com/invalid', validate_authority=False) + Authority( + 'https://example.com/invalid', + http_client=MinimalHttpClient(), validate_authority=False) except ValueError as e: if "invalid_instance" in str(e): # Imprecise but good enough self.fail("validate_authority=False should turn off validation") @@ -81,7 +84,8 @@ def test_memorize(self): # We use a real authority so the constructor can finish tenant discovery authority = "https://login.microsoftonline.com/common" self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) - a = Authority(authority, MinimalHttpClient(), validate_authority=False) + a = Authority(authority, http_client=MinimalHttpClient(), + validate_authority=False) # We now pretend this authority supports no User Realm Discovery class MockResponse(object): @@ -92,4 +96,4 @@ class MockResponse(object): Authority._domains_without_user_realm_discovery, "user_realm_discovery() should memorize domains not supporting URD") a.user_realm_discovery("john.doe@example.com", - response="This would cause exception if memorization did not work") \ No newline at end of file + response="This would cause exception if memorization did not work") diff --git a/tests/test_client.py b/tests/test_client.py index 52fc650d..28734bff 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,6 +12,7 @@ from msal.oauth2cli import Client, JwtSigner from msal.oauth2cli.authcode import obtain_auth_code from tests import unittest, Oauth2TestCase +from tests.http_client import MinimalHttpClient logging.basicConfig(level=logging.DEBUG) @@ -99,13 +100,13 @@ def setUpClass(cls): issuer=CONFIG["client_id"], ), client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT, - http_client=requests.Session() + http_client=MinimalHttpClient() ) else: cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], client_secret=CONFIG.get('client_secret'), - http_client=requests.Session()) + http_client=MinimalHttpClient()) @unittest.skipIf( "token_endpoint" not in CONFIG.get("openid_configuration", {}), From 42ca7a2294804815ed8db35cb0673105d54debf2 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 17 Apr 2020 13:18:43 -0700 Subject: [PATCH 24/33] PR review changes part 1 --- msal/application.py | 20 +++++++++++--------- msal/authority.py | 8 +++++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/msal/application.py b/msal/application.py index 809cf0d4..69c9b2ac 100644 --- a/msal/application.py +++ b/msal/application.py @@ -184,7 +184,7 @@ def __init__( self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - http_client=self.http_client, validate_authority=validate_authority) + self.http_client, validate_authority=validate_authority) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) @@ -284,12 +284,14 @@ def get_authorization_request_url( # The previous implementation is, it will use self.authority by default. # Multi-tenant app can use new authority on demand the_authority = Authority( - authority, http_client=self.http_client + authority, + self.http_client ) if authority else self.authority client = Client( {"authorization_endpoint": the_authority.authorization_endpoint}, - self.client_id, http_client=self.http_client) + self.client_id, + self.http_client) return client.build_auth_request_uri( response_type=response_type, redirect_uri=redirect_uri, state=state, login_hint=login_hint, @@ -398,9 +400,8 @@ def _get_authority_aliases(self, instance): "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", headers={'Accept': 'application/json'}) resp.raise_for_status() - resp = json.loads(resp.text) self.authority_groups = [ - set(group['aliases']) for group in resp['metadata']] + set(group['aliases']) for group in json.loads(resp.text)['metadata']] for group in self.authority_groups: if instance in group: return [alias for alias in group if alias != instance] @@ -519,7 +520,7 @@ def acquire_token_silent_with_error( warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( # authority, - # http_client=self.http_client, + # self.http_client, # ) if authority else self.authority result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, self.authority, force_refresh=force_refresh, @@ -531,7 +532,8 @@ def acquire_token_silent_with_error( for alias in self._get_authority_aliases(self.authority.instance): the_authority = Authority( "https://" + alias + "/" + self.authority.tenant, - http_client=self.http_client, validate_authority=False) + self.http_client, + validate_authority=False) result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, the_authority, force_refresh=force_refresh, correlation_id=correlation_id, @@ -778,7 +780,7 @@ def _acquire_token_by_username_password_federated( if user_realm_result.get("federation_metadata_url"): wstrust_endpoint = mex_send_request( user_realm_result["federation_metadata_url"], - http_client=self.http_client) + self.http_client) if wstrust_endpoint is None: raise ValueError("Unable to find wstrust endpoint from MEX. " "This typically happens when attempting MSA accounts. " @@ -790,7 +792,7 @@ def _acquire_token_by_username_password_federated( wstrust_endpoint.get("address", # Fallback to an AAD supplied endpoint user_realm_result.get("federation_active_auth_url")), - wstrust_endpoint.get("action"), http_client=self.http_client) + wstrust_endpoint.get("action"), self.http_client) if not ("token" in wstrust_result and "type" in wstrust_result): raise RuntimeError("Unsuccessful RSTR. %s" % wstrust_result) GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer' diff --git a/msal/authority.py b/msal/authority.py index c4ec979e..79b2f6af 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -109,12 +109,14 @@ def canonicalize(authority_url): return authority, authority.hostname, parts[1] def instance_discovery(url, http_client, **kwargs): - resp = http_client.get('https://{}/common/discovery/instance'.format( + resp = http_client.get( # Note: This URL seemingly returns V1 endpoint only + 'https://{}/common/discovery/instance'.format( WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 - ), params={'authorization_endpoint': url, 'api-version': '1.0'}, - **kwargs) + ), + params={'authorization_endpoint': url, 'api-version': '1.0'}, + **kwargs) return json.loads(resp.text) def tenant_discovery(tenant_discovery_endpoint, http_client, **kwargs): From d38d38d302a9f5ad14aa337b6dd4e4355b6e9d29 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 17 Apr 2020 13:59:43 -0700 Subject: [PATCH 25/33] PR review changes part 2 --- tests/http_client.py | 11 +++++++++++ tests/test_application.py | 26 ++++++++++---------------- tests/test_authority.py | 11 +++++------ tests/test_client.py | 6 +++--- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/tests/http_client.py b/tests/http_client.py index 21890c21..82fad692 100644 --- a/tests/http_client.py +++ b/tests/http_client.py @@ -17,3 +17,14 @@ def post(self, url, params=None, data=None, headers=None, **kwargs): def get(self, url, params=None, headers=None, **kwargs): return self.session.get( url, params=params, headers=headers, timeout=self.timeout) + + +class MinimalResponse(object): # Not for production use + def __init__(self, status_code=None, text=None, requests_resp=None): + self.status_code = status_code or requests_resp.status_code + self.text = text or requests_resp.text + self._raw_resp = requests_resp + + def raise_for_status(self): + if self._raw_resp: + self._raw_resp.raise_for_status() diff --git a/tests/test_application.py b/tests/test_application.py index 95e7c0cf..ff09ac42 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,16 +1,10 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. - -try: - from unittest.mock import * # Python 3 -except: - from mock import * # Need an external mock package - from msal.application import * import msal from tests import unittest from tests.test_token_cache import TokenCacheTestCase -from tests.http_client import MinimalHttpClient +from tests.http_client import MinimalHttpClient, MinimalResponse logger = logging.getLogger(__name__) @@ -49,7 +43,7 @@ class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" self.authority = msal.authority.Authority( - self.authority_url, http_client=MinimalHttpClient()) + self.authority_url, MinimalHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -77,21 +71,21 @@ def test_cache_empty_will_be_returned_as_None(self): def test_acquire_token_silent_will_suppress_error(self): error_response = '{"error": "invalid_grant", "suberror": "xyz"}' def tester(url, **kwargs): - return Mock(status_code=400, text=error_response) + return MinimalResponse(status_code=400, text=error_response) self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): error_response = '{"error": "invalid_grant", "error_description": "xyz"}' def tester(url, **kwargs): - return Mock(status_code=400, text=error_response) + return MinimalResponse(status_code=400, text=error_response) self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' def tester(url, **kwargs): - return Mock(status_code=400, text=error_response) + return MinimalResponse(status_code=400, text=error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("basic_action", result.get("classification")) @@ -99,7 +93,7 @@ def tester(url, **kwargs): def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' def tester(url, **kwargs): - return Mock(status_code=400, text=error_response) + return MinimalResponse(status_code=400, text=error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("", result.get("classification")) @@ -109,7 +103,7 @@ class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" self.authority = msal.authority.Authority( - self.authority_url, http_client=MinimalHttpClient()) + self.authority_url, MinimalHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -134,7 +128,7 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' def tester(url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Mock(status_code=400, text=error_response) + return MinimalResponse(status_code=400, text=error_response) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) self.assertNotEqual([], app.token_cache.find( @@ -155,7 +149,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") - return Mock(status_code=200, text='{}') + return MinimalResponse(status_code=200, text='{}') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) @@ -163,7 +157,7 @@ def test_unknown_family_app_will_attempt_frt_and_join_family(self): def tester(url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Mock( + return MinimalResponse( status_code=200, text=json.dumps(TokenCacheTestCase.build_response( uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( diff --git a/tests/test_authority.py b/tests/test_authority.py index 3a2d3f74..15a0eb52 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -12,7 +12,7 @@ def test_wellknown_host_and_tenant(self): # Assert all well known authority hosts are using their own "common" tenant for host in WELL_KNOWN_AUTHORITY_HOSTS: a = Authority( - 'https://{}/common'.format(host), http_client=MinimalHttpClient()) + 'https://{}/common'.format(host), MinimalHttpClient()) self.assertEqual( a.authorization_endpoint, 'https://%s/common/oauth2/v2.0/authorize' % host) @@ -26,7 +26,7 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): less_known = 'login.windows.net' # less.known.host/ v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) a = Authority( - 'https://{}/common'.format(less_known), http_client=MinimalHttpClient()) + 'https://{}/common'.format(less_known), MinimalHttpClient()) self.assertEqual(a.token_endpoint, v1_token_endpoint) self.assertNotIn('v2.0', a.token_endpoint) @@ -34,13 +34,13 @@ def test_unknown_host_wont_pass_instance_discovery(self): _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack with _assert(ValueError, "invalid_instance"): Authority('https://example.com/tenant_doesnt_matter_in_this_case', - http_client=MinimalHttpClient()) + MinimalHttpClient()) def test_invalid_host_skipping_validation_can_be_turned_off(self): try: Authority( 'https://example.com/invalid', - http_client=MinimalHttpClient(), validate_authority=False) + MinimalHttpClient(), validate_authority=False) except ValueError as e: if "invalid_instance" in str(e): # Imprecise but good enough self.fail("validate_authority=False should turn off validation") @@ -84,8 +84,7 @@ def test_memorize(self): # We use a real authority so the constructor can finish tenant discovery authority = "https://login.microsoftonline.com/common" self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) - a = Authority(authority, http_client=MinimalHttpClient(), - validate_authority=False) + a = Authority(authority, MinimalHttpClient(), validate_authority=False) # We now pretend this authority supports no User Realm Discovery class MockResponse(object): diff --git a/tests/test_client.py b/tests/test_client.py index 28734bff..4851c79d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -91,6 +91,7 @@ def setUpClass(cls): cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], + MinimalHttpClient(), client_assertion=JwtSigner( private_key, algorithm="RS256", @@ -100,13 +101,12 @@ def setUpClass(cls): issuer=CONFIG["client_id"], ), client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT, - http_client=MinimalHttpClient() ) else: cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], - client_secret=CONFIG.get('client_secret'), - http_client=MinimalHttpClient()) + http_client=MinimalHttpClient(), + client_secret=CONFIG.get('client_secret')) @unittest.skipIf( "token_endpoint" not in CONFIG.get("openid_configuration", {}), From 3226bc451be54c216592e4bd60a9eb9ce86b7473 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 17 Apr 2020 14:35:09 -0700 Subject: [PATCH 26/33] PR review changes part 3 --- tests/http_client.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/http_client.py b/tests/http_client.py index 82fad692..a12f740b 100644 --- a/tests/http_client.py +++ b/tests/http_client.py @@ -10,17 +10,17 @@ def __init__(self, verify=True, proxies=None, timeout=None): self.timeout = timeout def post(self, url, params=None, data=None, headers=None, **kwargs): - return self.session.post( + return MinimalResponse(self.session.post( url, params=params, data=data, headers=headers, - timeout=self.timeout) + timeout=self.timeout)) def get(self, url, params=None, headers=None, **kwargs): - return self.session.get( - url, params=params, headers=headers, timeout=self.timeout) + return MinimalResponse(self.session.get( + url, params=params, headers=headers, timeout=self.timeout)) class MinimalResponse(object): # Not for production use - def __init__(self, status_code=None, text=None, requests_resp=None): + def __init__(self, requests_resp=None, status_code=None, text=None): self.status_code = status_code or requests_resp.status_code self.text = text or requests_resp.text self._raw_resp = requests_resp From 90afb941ccc409aab5e33eee52c8f4a556d1a617 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Fri, 17 Apr 2020 14:44:10 -0700 Subject: [PATCH 27/33] Minor indent change --- msal/authority.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/msal/authority.py b/msal/authority.py index 79b2f6af..183176a2 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -111,10 +111,10 @@ def canonicalize(authority_url): def instance_discovery(url, http_client, **kwargs): resp = http_client.get( # Note: This URL seemingly returns V1 endpoint only 'https://{}/common/discovery/instance'.format( - WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too - # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 - # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 - ), + WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too + # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 + # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 + ), params={'authorization_endpoint': url, 'api-version': '1.0'}, **kwargs) return json.loads(resp.text) From a1ce0d1b9cc5c22f8f6da9a76aa98f44073e3e54 Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 20 Apr 2020 10:58:04 -0700 Subject: [PATCH 28/33] PR review changes part 4 --- msal/oauth2cli/oauth2.py | 24 ++++++++++++++++++------ requirements.txt | 1 - tests/http_client.py | 4 ++-- tests/test_application.py | 2 +- tests/test_client.py | 6 +++--- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 7536e0a2..f515ad93 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -33,7 +33,7 @@ def __init__( self, server_configuration, # type: dict client_id, # type: str - http_client, # type: HttpClient + http_client, # type: http.HttpClient client_secret=None, # type: Optional[str] client_assertion=None, # type: Union[bytes, callable, None] client_assertion_type=None, # type: Optional[str] @@ -53,6 +53,8 @@ def __init__( or https://example.com/.../.well-known/openid-configuration client_id (str): The client's id, issued by the authorization server + http_client (object): + An http.HttpClient-like object, e.g. requests.Session client_secret (str): Triggers HTTP AUTH for Confidential Client client_assertion (bytes, callable): The client assertion to authenticate this client, per RFC 7521. @@ -72,6 +74,16 @@ def __init__( you could choose to set this as {"client_secret": "your secret"} if your authorization server wants it to be in the request body (rather than in the request header). + + There is no session-wide `timeout` parameter defined here. + The timeout behavior is determined by the actual http client you use. + If you happen to use Requests, it chose to not support session-wide timeout + (https://github.com/psf/requests/issues/3341), but you can patch that by: + + s = requests.Session() + s.request = functools.partial(s.request, timeout=3) + + and then feed that patched session instance to this class. """ self.configuration = server_configuration self.client_id = client_id @@ -90,6 +102,7 @@ def _build_auth_request_params(self, response_type, **kwargs): # or it can be a space-delimited string as defined in # https://tools.ietf.org/html/rfc6749#section-8.4 response_type = self._stringify(response_type) + params = {'client_id': self.client_id, 'response_type': response_type} params.update(kwargs) # Note: None values will override params params = {k: v for k, v in params.items() if v is not None} # clean up @@ -119,9 +132,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails - _data = {k: v for k, v in _data.items() if v} - # We will have to clean up None values here, - # because we can have some libraries not supporting cleaning of None values. + _data = {k: v for k, v in _data.items() if v} # Clean up None values if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) @@ -144,7 +155,8 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 if "token_endpoint" not in self.configuration: raise ValueError("token_endpoint not found in configuration") - resp = (post or self.http_client.post)(self.configuration["token_endpoint"], + resp = (post or self.http_client.post)( + self.configuration["token_endpoint"], headers=_headers, params=params, data=_data, **kwargs) if resp.status_code >= 500: @@ -214,6 +226,7 @@ def initiate_device_flow(self, scope=None, **kwargs): raise ValueError("You need to provide device authorization endpoint") resp = self.http_client.post(self.configuration[DAE], data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, + headers=dict(self.default_headers, **kwargs.pop("headers", {})), **kwargs) flow = json.loads(resp.text) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string @@ -464,4 +477,3 @@ def obtain_token_by_assertion( data = kwargs.pop("data", {}) data.update(scope=scope, assertion=encoder(assertion)) return self._obtain_token(grant_type, data=data, **kwargs) - diff --git a/requirements.txt b/requirements.txt index 61a6510d..9c558e35 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ . -mock; python_version < '3.3' diff --git a/tests/http_client.py b/tests/http_client.py index a12f740b..4bff9b45 100644 --- a/tests/http_client.py +++ b/tests/http_client.py @@ -10,12 +10,12 @@ def __init__(self, verify=True, proxies=None, timeout=None): self.timeout = timeout def post(self, url, params=None, data=None, headers=None, **kwargs): - return MinimalResponse(self.session.post( + return MinimalResponse(requests_resp=self.session.post( url, params=params, data=data, headers=headers, timeout=self.timeout)) def get(self, url, params=None, headers=None, **kwargs): - return MinimalResponse(self.session.get( + return MinimalResponse(requests_resp=self.session.get( url, params=params, headers=headers, timeout=self.timeout)) diff --git a/tests/test_application.py b/tests/test_application.py index ff09ac42..39becd5a 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -159,7 +159,7 @@ def tester(url, data=None, **kwargs): self.frt, data.get("refresh_token"), "Should attempt the FRT") return MinimalResponse( status_code=200, text=json.dumps(TokenCacheTestCase.build_response( - uid=self.uid, utid=self.utid, foci="1", access_token="at"))) + uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) at = app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( diff --git a/tests/test_client.py b/tests/test_client.py index 4851c79d..44236524 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -84,6 +84,7 @@ class TestClient(Oauth2TestCase): @classmethod def setUpClass(cls): + http_client = MinimalHttpClient() if "client_certificate" in CONFIG: private_key_path = CONFIG["client_certificate"]["private_key_path"] with open(os.path.join(THIS_FOLDER, private_key_path)) as f: @@ -91,7 +92,7 @@ def setUpClass(cls): cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], - MinimalHttpClient(), + http_client, client_assertion=JwtSigner( private_key, algorithm="RS256", @@ -104,8 +105,7 @@ def setUpClass(cls): ) else: cls.client = Client( - CONFIG["openid_configuration"], CONFIG['client_id'], - http_client=MinimalHttpClient(), + CONFIG["openid_configuration"], CONFIG['client_id'], http_client, client_secret=CONFIG.get('client_secret')) @unittest.skipIf( From ad245a70480a8fd86e7893a5706a1459eab0557e Mon Sep 17 00:00:00 2001 From: Abhidnya Date: Mon, 20 Apr 2020 11:07:11 -0700 Subject: [PATCH 29/33] Adding back accidentally deleted blank line --- msal/oauth2cli/oauth2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index f515ad93..ed13771f 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -477,3 +477,4 @@ def obtain_token_by_assertion( data = kwargs.pop("data", {}) data.update(scope=scope, assertion=encoder(assertion)) return self._obtain_token(grant_type, data=data, **kwargs) + From 34bb127fdca511084ec14ad63e29d75901127667 Mon Sep 17 00:00:00 2001 From: Abhidnya Date: Mon, 20 Apr 2020 11:25:56 -0700 Subject: [PATCH 30/33] Removing extra indent --- msal/oauth2cli/oauth2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index ed13771f..0cf4c280 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -477,4 +477,4 @@ def obtain_token_by_assertion( data = kwargs.pop("data", {}) data.update(scope=scope, assertion=encoder(assertion)) return self._obtain_token(grant_type, data=data, **kwargs) - + From a7459495e63c6659bf434466de158e7853fdcece Mon Sep 17 00:00:00 2001 From: Abhidnya Date: Mon, 20 Apr 2020 11:30:36 -0700 Subject: [PATCH 31/33] Minor line add in authority.py --- msal/authority.py | 1 + 1 file changed, 1 insertion(+) diff --git a/msal/authority.py b/msal/authority.py index 183176a2..94caaab4 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -126,3 +126,4 @@ def tenant_discovery(tenant_discovery_endpoint, http_client, **kwargs): if 'authorization_endpoint' in payload and 'token_endpoint' in payload: return payload raise MsalServiceError(status_code=resp.status_code, **payload) + From b636baac953e731774237e224b8a10778f5a3cfe Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 20 Apr 2020 16:52:45 -0700 Subject: [PATCH 32/33] Making changes for backward compatibility --- msal/oauth2cli/oauth2.py | 45 +++++++++++++++++++++++++++++++++------- tests/test_client.py | 5 +++-- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 0cf4c280..fac35f1b 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -12,6 +12,10 @@ import time import base64 import sys +import functools + +import requests + string_types = (str,) if sys.version_info[0] >= 3 else (basestring, ) @@ -33,12 +37,15 @@ def __init__( self, server_configuration, # type: dict client_id, # type: str - http_client, # type: http.HttpClient + http_client=None, # We insert it here to match the upcoming async API client_secret=None, # type: Optional[str] client_assertion=None, # type: Union[bytes, callable, None] client_assertion_type=None, # type: Optional[str] default_headers=None, # type: Optional[dict] default_body=None, # type: Optional[dict] + verify=True, # type: Union[str, True, False, None] + proxies=None, # type: Optional[dict] + timeout=None, # type: Union[tuple, float, None] ): """Initialize a client object to talk all the OAuth2 grants to the server. @@ -53,8 +60,9 @@ def __init__( or https://example.com/.../.well-known/openid-configuration client_id (str): The client's id, issued by the authorization server - http_client (object): - An http.HttpClient-like object, e.g. requests.Session + http_client (http.HttpClient): + Your implementation of abstract class :class:`http.HttpClient`. + Defaults to a requests session instance. client_secret (str): Triggers HTTP AUTH for Confidential Client client_assertion (bytes, callable): The client assertion to authenticate this client, per RFC 7521. @@ -75,6 +83,22 @@ def __init__( if your authorization server wants it to be in the request body (rather than in the request header). + verify (boolean): + It will be passed to the + `verify parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + proxies (dict): + It will be passed to the + `proxies parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + timeout (object): + It will be passed to the + `timeout parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + There is no session-wide `timeout` parameter defined here. The timeout behavior is determined by the actual http client you use. If you happen to use Requests, it chose to not support session-wide timeout @@ -94,7 +118,15 @@ def __init__( if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) - self.http_client = http_client + if http_client: + self.http_client = http_client + else: + self.http_client = requests.Session() + self.http_client.verify = verify + self.http_client.proxies = proxies + self.http_client.request = functools.partial( + # A workaround for requests not supporting session-wide timeout + self.http_client.request, timeout=timeout) def _build_auth_request_params(self, response_type, **kwargs): # response_type is a string defined in @@ -388,13 +420,12 @@ class initialization. return self._obtain_token("client_credentials", data=data, **kwargs) def __init__(self, - server_configuration, client_id, http_client, + server_configuration, client_id, on_obtaining_tokens=lambda event: None, # event is defined in _obtain_token(...) on_removing_rt=lambda token_item: None, on_updating_rt=lambda token_item, new_rt: None, **kwargs): - super(Client, self).__init__( - server_configuration, client_id, http_client, **kwargs) + super(Client, self).__init__(server_configuration, client_id, **kwargs) self.on_obtaining_tokens = on_obtaining_tokens self.on_removing_rt = on_removing_rt self.on_updating_rt = on_updating_rt diff --git a/tests/test_client.py b/tests/test_client.py index 44236524..75cdfc9c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -92,7 +92,7 @@ def setUpClass(cls): cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], - http_client, + http_client=http_client, client_assertion=JwtSigner( private_key, algorithm="RS256", @@ -105,7 +105,8 @@ def setUpClass(cls): ) else: cls.client = Client( - CONFIG["openid_configuration"], CONFIG['client_id'], http_client, + CONFIG["openid_configuration"], CONFIG['client_id'], + http_client=http_client, client_secret=CONFIG.get('client_secret')) @unittest.skipIf( From d24d575dfcc2558607e471705803e2ce888b42ab Mon Sep 17 00:00:00 2001 From: Abhidnya Patil Date: Mon, 20 Apr 2020 17:05:38 -0700 Subject: [PATCH 33/33] Some more editorial changes --- msal/application.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/msal/application.py b/msal/application.py index 69c9b2ac..63d24f17 100644 --- a/msal/application.py +++ b/msal/application.py @@ -148,18 +148,18 @@ def __init__( :param verify: (optional) It will be passed to the `verify parameter in the underlying requests library - This does not apply if you have chosen to pass your own Http client `_ + This does not apply if you have chosen to pass your own Http client :param proxies: (optional) It will be passed to the `proxies parameter in the underlying requests library - This does not apply if you have chosen to pass your own Http client `_ + This does not apply if you have chosen to pass your own Http client :param timeout: (optional) It will be passed to the `timeout parameter in the underlying requests library - This does not apply if you have chosen to pass your own Http client `_ + This does not apply if you have chosen to pass your own Http client :param app_name: (optional) You can provide your application name for Microsoft telemetry purposes. Default value is None, means it will not be passed to Microsoft. @@ -291,7 +291,7 @@ def get_authorization_request_url( client = Client( {"authorization_endpoint": the_authority.authorization_endpoint}, self.client_id, - self.http_client) + http_client=self.http_client) return client.build_auth_request_uri( response_type=response_type, redirect_uri=redirect_uri, state=state, login_hint=login_hint,