Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


string_types = (str,) if sys.version_info[0] >= 3 else (basestring, )
DAE = "device_authorization_endpoint"


class BaseClient(object):
Expand Down Expand Up @@ -148,11 +149,8 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
raise ValueError("token_endpoint not found in configuration")
_headers = {'Accept': 'application/json'}
_headers.update(headers or {})
resp = (post or self.session.post)(
self.configuration["token_endpoint"],
headers=_headers, params=params, data=_data, auth=auth,
timeout=timeout or self.timeout,
**kwargs)
resp = self.send_obtain_token(headers=_headers, params=params, data=_data,
auth=auth, timeout=timeout or self.timeout, post=post, **kwargs)
if resp.status_code >= 500:
resp.raise_for_status() # TODO: Will probably retry here
try:
Expand All @@ -165,6 +163,13 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
"Token response is not in json format: %s", resp.text)
raise

def send_obtain_token(self, headers, params, data, auth, timeout, post, **kwargs):
return (post or self.session.post)(
self.configuration["token_endpoint"],
headers=headers, params=params, data=data, auth=auth,
timeout=timeout, **kwargs
)

def obtain_token_by_refresh_token(self, refresh_token, scope=None, **kwargs):
# type: (str, Union[str, list, set, tuple]) -> dict
"""Obtain an access token via a refresh token.
Expand Down Expand Up @@ -215,18 +220,29 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs):
And possibly here
https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.3.1
"""
DAE = "device_authorization_endpoint"
if not self.configuration.get(DAE):
raise ValueError("You need to provide device authorization endpoint")
flow = self.session.post(self.configuration[DAE],
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
timeout=timeout or self.timeout,
**kwargs).json()
flow = self.get_dae_json(
self.send_dae_request(
data={"client_id": self.client_id, "scope": self._stringify(scope or [])},
timeout=timeout or self.timeout,
**kwargs
)
)
flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string
flow["expires_in"] = int(flow.get("expires_in", 1800))
flow["expires_at"] = time.time() + flow["expires_in"] # We invent this
return flow

def send_dae_request(self, data, timeout, **kwargs):
return self.session.post(self.configuration[DAE],
data=data,
timeout=timeout,
**kwargs)

def get_dae_json(self, resp):
return resp.json()

def _obtain_token_by_device_flow(self, flow, **kwargs):
# type: (dict, **dict) -> dict
# This method updates flow during each run. And it is non-blocking.
Expand Down