-
Notifications
You must be signed in to change notification settings - Fork 114
OAuth implementation #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 54 commits
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
5581ad6
Reformat changelog (#11)
68a1903
oauth implementation initial work
moderakh 20e888b
responde to code review comments
moderakh cff24d9
Update src/databricks/sql/auth/authenticators.py
moderakh d99af2d
Update src/databricks/sql/auth/authenticators.py
moderakh 09986ab
Update src/databricks/sql/auth/authenticators.py
moderakh a703f58
responded to review comments
moderakh 70975fe
added unit tests for legacy auth providers (PAT, User/Pass)
moderakh bccf869
added more tests
moderakh bedc274
replaced click with logging
moderakh 0bde8fe
responded to code review comments
moderakh a1ffcec
made use of quotes consitent on oauth.py
moderakh 681a64b
addressed f string related comments
moderakh f8bb7e9
removed client method
moderakh 774c882
responded to review comments
moderakh 4a2ca6f
added requests as an explicit dependency as that's required by oauth
moderakh 61f2ec0
responded to review comments
moderakh adce3f3
responded to review comments
moderakh dcec176
cleanup
moderakh 9642692
cleanup
moderakh 964ddd2
added support for persistence
moderakh b062674
Add e2e tests (#12)
susodapop 1f30744
Indicate that Python 3.10 is not supported (#27)
e851ea4
Add Developer Certificate of Origin requirement (#13)
63894d7
Retry attempts that fail due to a connection timeout (#24)
36c6f4d
Bump to v2.0.3 (#28)
78015e5
Bump version to 2.0.4-dev (#29)
cdef1a5
[PECO-197] Support Python 3.10 (#31)
dbaxa 39efd0a
Update changelog and bump to v2.0.4 (#34)
643425f
Bump to 2.0.5-dev on main (#35)
69b7b8c
On Pypi, display the "Project Links" sidebar. (#36)
23cd570
[ES-402013] Close cursors before closing connection (#38)
353f413
Bump version to 2.0.5 and improve CHANGELOG (#40)
737fd7e
minor fixes
moderakh 2274c3f
Merge remote-tracking branch 'databricks/main' into PECO-188
moderakh 9e40f18
fixed token refresh and persistent
moderakh f9cb9f9
updated comment
moderakh b6e1fde
added type annotation
moderakh 2a69715
addressed code review comment (use python3 api)
moderakh beb1f64
made http request handler an independent class, removed global arg, r…
moderakh 930cea8
added pull request ci trigger
moderakh 8af7b5f
restructured the code as class
moderakh dac3c1e
fixed test_thrift_backend.py tests
moderakh 0fc1684
cleanup
moderakh 4fe5a58
Update src/databricks/sql/experimental/oauth_persistence.py
moderakh 8ab4e3f
cleanup
moderakh fb012b0
cleanup
moderakh 6881226
cleanup
moderakh 3e537d5
cleanup
moderakh 00c403d
cleanup
moderakh e64df63
Update src/databricks/sql/auth/thrift_http_client.py
moderakh ef385e9
cleanup
moderakh 27fb3b5
cleanup
moderakh 0a6c455
cleanup
moderakh 16aa44e
moved access_token out of kwargs
moderakh be19297
added hostname to the persitence api
moderakh 367a3ee
responded to review comments
moderakh 8476673
fix lint
moderakh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from enum import Enum | ||
from typing import List | ||
|
||
from databricks.sql.auth.authenticators import ( | ||
CredentialsProvider, | ||
AccessTokenAuthProvider, | ||
BasicAuthProvider, | ||
DatabricksOAuthProvider, | ||
) | ||
from databricks.sql.experimental.oauth_persistence import OAuthPersistence | ||
|
||
|
||
class AuthType(Enum): | ||
DATABRICKS_OAUTH = "databricks-oauth" | ||
# other supported types (access_token, user/pass) can be inferred | ||
# we can add more types as needed later | ||
|
||
|
||
class ClientContext: | ||
def __init__( | ||
self, | ||
hostname: str, | ||
username: str = None, | ||
password: str = None, | ||
access_token: str = None, | ||
auth_type: str = None, | ||
oauth_scopes: List[str] = None, | ||
oauth_client_id: str = None, | ||
oauth_redirect_port_range: List[int] = None, | ||
use_cert_as_auth: str = None, | ||
tls_client_cert_file: str = None, | ||
oauth_persistence=None, | ||
): | ||
self.hostname = hostname | ||
self.username = username | ||
self.password = password | ||
self.access_token = access_token | ||
self.auth_type = auth_type | ||
self.oauth_scopes = oauth_scopes | ||
self.oauth_client_id = oauth_client_id | ||
self.oauth_redirect_port_range = oauth_redirect_port_range | ||
self.use_cert_as_auth = use_cert_as_auth | ||
self.tls_client_cert_file = tls_client_cert_file | ||
self.oauth_persistence = oauth_persistence | ||
|
||
|
||
def get_auth_provider(cfg: ClientContext): | ||
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: | ||
assert cfg.oauth_redirect_port_range is not None | ||
assert cfg.oauth_client_id is not None | ||
assert cfg.oauth_scopes is not None | ||
|
||
return DatabricksOAuthProvider( | ||
cfg.hostname, | ||
cfg.oauth_persistence, | ||
cfg.oauth_redirect_port_range, | ||
cfg.oauth_client_id, | ||
cfg.oauth_scopes, | ||
) | ||
elif cfg.access_token is not None: | ||
return AccessTokenAuthProvider(cfg.access_token) | ||
elif cfg.username is not None and cfg.password is not None: | ||
return BasicAuthProvider(cfg.username, cfg.password) | ||
elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: | ||
# no op authenticator. authentication is performed using ssl certificate outside of headers | ||
return CredentialsProvider() | ||
else: | ||
raise RuntimeError("No valid authentication settings!") | ||
|
||
|
||
PYSQL_OAUTH_SCOPES = ["sql", "offline_access"] | ||
PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python" | ||
PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025)) | ||
|
||
|
||
def normalize_host_name(hostname: str): | ||
moderakh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
maybe_scheme = "https://" if not hostname.startswith("https://") else "" | ||
maybe_trailing_slash = "/" if not hostname.endswith("/") else "" | ||
return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" | ||
|
||
|
||
def get_python_sql_connector_auth_provider(hostname: str, **kwargs): | ||
cfg = ClientContext( | ||
hostname=normalize_host_name(hostname), | ||
auth_type=kwargs.get("auth_type"), | ||
access_token=kwargs.get("access_token"), | ||
username=kwargs.get("_username"), | ||
password=kwargs.get("_password"), | ||
use_cert_as_auth=kwargs.get("_use_cert_as_auth"), | ||
tls_client_cert_file=kwargs.get("_tls_client_cert_file"), | ||
oauth_scopes=PYSQL_OAUTH_SCOPES, | ||
oauth_client_id=PYSQL_OAUTH_CLIENT_ID, | ||
oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE, | ||
oauth_persistence=kwargs.get("experimental_oauth_persistence"), | ||
) | ||
return get_auth_provider(cfg) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import base64 | ||
import logging | ||
from typing import Dict, List | ||
|
||
from databricks.sql.auth.oauth import OAuthManager | ||
|
||
# Private API: this is an evolving interface and it will change in the future. | ||
# Please must not depend on it in your applications. | ||
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence | ||
|
||
|
||
class CredentialsProvider: | ||
moderakh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def add_headers(self, request_headers: Dict[str, str]): | ||
pass | ||
moderakh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
# Private API: this is an evolving interface and it will change in the future. | ||
# Please must not depend on it in your applications. | ||
class AccessTokenAuthProvider(CredentialsProvider): | ||
def __init__(self, access_token: str): | ||
self.__authorization_header_value = "Bearer {}".format(access_token) | ||
|
||
def add_headers(self, request_headers: Dict[str, str]): | ||
request_headers["Authorization"] = self.__authorization_header_value | ||
|
||
|
||
# Private API: this is an evolving interface and it will change in the future. | ||
# Please must not depend on it in your applications. | ||
class BasicAuthProvider(CredentialsProvider): | ||
def __init__(self, username: str, password: str): | ||
auth_credentials = f"{username}:{password}".encode("UTF-8") | ||
auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode( | ||
"UTF-8" | ||
) | ||
|
||
self.__authorization_header_value = f"Basic {auth_credentials_base64}" | ||
|
||
def add_headers(self, request_headers: Dict[str, str]): | ||
request_headers["Authorization"] = self.__authorization_header_value | ||
|
||
|
||
# Private API: this is an evolving interface and it will change in the future. | ||
# Please must not depend on it in your applications. | ||
class DatabricksOAuthProvider(CredentialsProvider): | ||
SCOPE_DELIM = " " | ||
|
||
def __init__( | ||
self, | ||
hostname: str, | ||
oauth_persistence: OAuthPersistence, | ||
redirect_port_range: List[int], | ||
client_id: str, | ||
scopes: List[str], | ||
): | ||
try: | ||
self.oauth_manager = OAuthManager( | ||
port_range=redirect_port_range, client_id=client_id | ||
) | ||
self._hostname = hostname | ||
self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) | ||
self._oauth_persistence = oauth_persistence | ||
self._client_id = client_id | ||
self._access_token = None | ||
self._refresh_token = None | ||
self._initial_get_token() | ||
except Exception as e: | ||
logging.error(f"unexpected error", e, exc_info=True) | ||
raise e | ||
|
||
def add_headers(self, request_headers: Dict[str, str]): | ||
self._update_token_if_expired() | ||
request_headers["Authorization"] = f"Bearer {self._access_token}" | ||
|
||
def _initial_get_token(self): | ||
try: | ||
if self._access_token is None or self._refresh_token is None: | ||
if self._oauth_persistence: | ||
token = self._oauth_persistence.read() | ||
if token: | ||
self._access_token = token.access_token | ||
self._refresh_token = token.refresh_token | ||
|
||
if self._access_token and self._refresh_token: | ||
self._update_token_if_expired() | ||
else: | ||
(access_token, refresh_token) = self.oauth_manager.get_tokens( | ||
hostname=self._hostname, scope=self._scopes_as_str | ||
) | ||
self._access_token = access_token | ||
self._refresh_token = refresh_token | ||
self._oauth_persistence.persist(OAuthToken(access_token, refresh_token)) | ||
except Exception as e: | ||
logging.error(f"unexpected error in oauth initialization", e, exc_info=True) | ||
raise e | ||
|
||
def _update_token_if_expired(self): | ||
try: | ||
( | ||
fresh_access_token, | ||
fresh_refresh_token, | ||
is_refreshed, | ||
) = self.oauth_manager.check_and_refresh_access_token( | ||
hostname=self._hostname, | ||
access_token=self._access_token, | ||
refresh_token=self._refresh_token, | ||
) | ||
if not is_refreshed: | ||
return | ||
else: | ||
self._access_token = fresh_access_token | ||
self._refresh_token = fresh_refresh_token | ||
|
||
if self._oauth_persistence: | ||
token = OAuthToken(self._access_token, self._refresh_token) | ||
self._oauth_persistence.persist(token) | ||
except Exception as e: | ||
logging.error(f"unexpected error in oauth token update", e, exc_info=True) | ||
raise e |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.