diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 9a947390..404f75d3 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -16,6 +16,7 @@ string_types = (str,) if sys.version_info[0] >= 3 else (basestring, ) +DAE = "device_authorization_endpoint" class BaseClient(object): @@ -148,11 +149,8 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 raise ValueError("token_endpoint not found in configuration") _headers = {'Accept': 'application/json'} _headers.update(headers or {}) - resp = (post or self.session.post)( - self.configuration["token_endpoint"], - headers=_headers, params=params, data=_data, auth=auth, - timeout=timeout or self.timeout, - **kwargs) + resp = self.send_obtain_token(headers=_headers, params=params, data=_data, + auth=auth, timeout=timeout or self.timeout, post=post, **kwargs) if resp.status_code >= 500: resp.raise_for_status() # TODO: Will probably retry here try: @@ -165,6 +163,13 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 "Token response is not in json format: %s", resp.text) raise + def send_obtain_token(self, headers, params, data, auth, timeout, post, **kwargs): + return (post or self.session.post)( + self.configuration["token_endpoint"], + headers=headers, params=params, data=data, auth=auth, + timeout=timeout, **kwargs + ) + def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs): # type: (str, Union[str, list, set, tuple]) -> dict """Obtain an access token via a refresh token. @@ -215,18 +220,29 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): And possibly here https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.3.1 """ - DAE = "device_authorization_endpoint" if not self.configuration.get(DAE): raise ValueError("You need to provide device authorization endpoint") - flow = self.session.post(self.configuration[DAE], - data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - timeout=timeout or self.timeout, - **kwargs).json() + flow = self.get_dae_json( + self.send_dae_request( + data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, + timeout=timeout or self.timeout, + **kwargs + ) + ) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) flow["expires_at"] = time.time() + flow["expires_in"] # We invent this return flow + def send_dae_request(self, data, timeout, **kwargs): + return self.session.post(self.configuration[DAE], + data=data, + timeout=timeout, + **kwargs) + + def get_dae_json(self, resp): + return resp.json() + def _obtain_token_by_device_flow(self, flow, **kwargs): # type: (dict, **dict) -> dict # This method updates flow during each run. And it is non-blocking.