diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index 24e3f642..ec7ac93c 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -8,6 +8,8 @@ import logging import socket from string import Template +import threading +import time try: # Python 3 from http.server import HTTPServer, BaseHTTPRequestHandler @@ -149,11 +151,7 @@ def get_port(self): # https://docs.python.org/2.7/library/socketserver.html#SocketServer.BaseServer.server_address return self._server.server_address[1] - def get_auth_response(self, auth_uri=None, timeout=None, state=None, - welcome_template=None, success_template=None, error_template=None, - auth_uri_callback=None, - browser_name=None, - ): + def get_auth_response(self, timeout=None, **kwargs): """Wait and return the auth response. Raise RuntimeError when timeout. :param str auth_uri: @@ -192,6 +190,37 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, and https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse Returns None when the state was mismatched, or when timeout occurred. """ + # Historically, the _get_auth_response() uses HTTPServer.handle_request(), + # because its handle-and-retry logic is conceptually as easy as a while loop. + # Also, handle_request() honors server.timeout setting, and CTRL+C simply works. + # All those are true when running on Linux. + # + # However, the behaviors on Windows turns out to be different. + # A socket server waiting for request would freeze the current thread. + # Neither timeout nor CTRL+C would work. End user would have to do CTRL+BREAK. + # https://stackoverflow.com/questions/1364173/stopping-python-using-ctrlc + # + # The solution would need to somehow put the http server into its own thread. + # This could be done by the pattern of ``http.server.test()`` which internally + # use ``ThreadingHTTPServer.serve_forever()`` (only available in Python 3.7). + # Or create our own thread to wrap the HTTPServer.handle_request() inside. + result = {} # A mutable object to be filled with thread's return value + t = threading.Thread( + target=self._get_auth_response, args=(result,), kwargs=kwargs) + t.daemon = True # So that it won't prevent the main thread from exiting + t.start() + begin = time.time() + while (time.time() - begin < timeout) if timeout else True: + time.sleep(1) # Short detection interval to make happy path responsive + if not t.is_alive(): # Then the thread has finished its job and exited + break + return result or None + + def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None, + welcome_template=None, success_template=None, error_template=None, + auth_uri_callback=None, + browser_name=None, + ): welcome_uri = "http://localhost:{p}".format(p=self.get_port()) abort_uri = "{loc}?error=abort".format(loc=welcome_uri) logger.debug("Abort by visit %s", abort_uri) @@ -238,7 +267,7 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None, logger.debug("State mismatch. Ignoring this noise.") else: break - return self._server.auth_response + result.update(self._server.auth_response) # Return via writable result param def close(self): """Either call this eventually; or use the entire class as context manager""" diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 305061cf..8d337bb9 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -199,7 +199,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data["client_assertion"] = encoder( self.client_assertion() # Do lazy on-the-fly computation if callable(self.client_assertion) else self.client_assertion - ) # The type is bytes, which is preferrable. See also: + ) # The type is bytes, which is preferable. See also: # https://github.com/psf/requests/issues/4503#issuecomment-455001070 _data.update(self.default_body) # It may contain authen parameters