From 5581ad6c4ca61f893f81516d412bdf9fcd00a9c6 Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 29 Jun 2022 16:08:08 -0500 Subject: [PATCH 01/57] Reformat changelog (#11) Signed-off-by: Moe Derakhshani --- CHANGELOG.md | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd31de421..2730415a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,51 +1,59 @@ -v2.0.2 - May 4, 2022 +# Release History + +## 2.x.x (Unreleased) + +- Reorganised code to use Poetry for dependency management. +## 2.0.2 (2022-05-04) - Better exception handling in automatic connection close -v2.0.1 - April 21, 2022 +## 2.0.1 (2022-04-21) - Fixed Pandas dependency in setup.cfg to be >= 1.2.0 -v2.0.0 - April 19, 2022 +## 2.0.0 (2022-04-19) - Initial stable release of V2 - Added better support for complex types, so that in Databricks runtime 10.3+, Arrays, Maps and Structs will get deserialized as lists, lists of tuples and dicts, respectively. - Changed the name of the metadata arg to http_headers -v2.0.b2 - April 4, 2022 +## 2.0.b2 (2022-04-04) - Change import of collections.Iterable to collections.abc.Iterable to make the library compatible with Python 3.10 - Fixed bug with .tables method so that .tables works as expected with Unity-Catalog enabled endpoints -v2.0.0b1 - March 4, 2022 +## 2.0.0b1 (2022-03-04) - Fix packaging issue (dependencies were not being installed properly) - Fetching timestamp results will now return aware instead of naive timestamps - The client will now default to using simplified error messages -v2.0.0b - February 8, 2022 +## 2.0.0b (2022-02-08) - Initial beta release of V2. V2 is an internal re-write of large parts of the connector to use Databricks edge features. All public APIs from V1 remain. - Added Unity Catalog support (pass catalog and / or schema key word args to the .connect method to select initial schema and catalog) -**Note**: The code for versions prior to `v2.0.0b` is not contained in this repository. +--- + +**Note**: The code for versions prior to `v2.0.0b` is not contained in this repository. The below entries are included for reference only. -v1.0.0 - January 20, 2022 +--- +## 1.0.0 (2022-01-20) - Add operations for retrieving metadata - Add the ability to access columns by name on result rows - Add the ability to provide configuration settings on connect -v0.9.4 - January 10, 2022 +## 0.9.4 (2022-01-10) - Improved logging and error messages. -v0.9.3 - December 8, 2021 +## 0.9.3 (2021-12-08) - Add retries for 429 and 503 HTTP responses. -v0.9.2 - December 2, 2021 +## 0.9.2 (2021-12-02) - (Bug fix) Increased Thrift requirement from 0.10.0 to 0.13.0 as 0.10.0 was in fact incompatible -- (Bug fix) Fixed error message after query execution failed - SQLSTATE and Error message were misplaced +- (Bug fix) Fixed error message after query execution failed -SQLSTATE and Error message were misplaced -v0.9.1 - Sept 1, 2021 +## 0.9.1 (2021-09-01) - Public Preview release, Experimental tag removed - minor updates in internal build/packaging - no functional changes -v0.9.0 - Aug 4, 2021 +## 0.9.0 (2021-08-04) - initial (Experimental) release of pyhive-forked connector - Python DBAPI 2.0 (PEP-0249), thrift based - see docs for more info: https://docs.databricks.com/dev-tools/python-sql-connector.html From 68a19038b9f949fc09f049652125a15a452ddf93 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 12 Jul 2022 15:03:03 -0700 Subject: [PATCH 02/57] oauth implementation initial work Signed-off-by: Moe Derakhshani --- poetry.lock | 311 +++++------------- pyproject.toml | 3 + src/databricks/sql/__init__.py | 5 +- src/databricks/sql/auth/__init__.py | 22 ++ src/databricks/sql/auth/auth.py | 103 ++++++ src/databricks/sql/auth/authenticators.py | 78 +++++ src/databricks/sql/auth/oauth.py | 286 ++++++++++++++++ src/databricks/sql/auth/thrift_http_client.py | 48 +++ src/databricks/sql/client.py | 33 +- src/databricks/sql/thrift_backend.py | 14 +- 10 files changed, 634 insertions(+), 269 deletions(-) create mode 100644 src/databricks/sql/auth/__init__.py create mode 100644 src/databricks/sql/auth/auth.py create mode 100644 src/databricks/sql/auth/authenticators.py create mode 100644 src/databricks/sql/auth/oauth.py create mode 100644 src/databricks/sql/auth/thrift_http_client.py diff --git a/poetry.lock b/poetry.lock index 9bc3ae8ae..782e82735 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,6 +1,6 @@ [[package]] name = "atomicwrites" -version = "1.4.0" +version = "1.4.1" description = "Atomic file writes." category = "dev" optional = false @@ -22,7 +22,7 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (> [[package]] name = "black" -version = "22.3.0" +version = "22.6.0" description = "The uncompromising code formatter." category = "dev" optional = false @@ -33,7 +33,7 @@ click = ">=8.0.0" mypy-extensions = ">=0.4.3" pathspec = ">=0.9.0" platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} @@ -47,7 +47,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -57,17 +57,17 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} [[package]] name = "colorama" -version = "0.4.4" +version = "0.4.5" description = "Cross-platform colored terminal text." -category = "dev" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "importlib-metadata" -version = "4.11.3" +version = "4.12.0" description = "Read metadata from Python packages" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -78,7 +78,7 @@ zipp = ">=0.5" [package.extras] docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] perf = ["ipython"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] [[package]] name = "iniconfig" @@ -117,11 +117,32 @@ python-versions = "*" [[package]] name = "numpy" -version = "1.21.1" +version = "1.21.6" description = "NumPy is the fundamental package for array computing with Python." category = "main" optional = false -python-versions = ">=3.7" +python-versions = ">=3.7,<3.11" + +[[package]] +name = "numpy" +version = "1.23.0" +description = "NumPy is the fundamental package for array computing with Python." +category = "main" +optional = false +python-versions = ">=3.8" + +[[package]] +name = "oauthlib" +version = "3.2.0" +description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +rsa = ["cryptography (>=3.0.0)"] +signals = ["blinker (>=1.4.0)"] +signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] [[package]] name = "packaging" @@ -209,6 +230,20 @@ python-versions = ">=3.6" [package.dependencies] numpy = ">=1.16.6" +[[package]] +name = "pyjwt" +version = "2.4.0" +description = "JSON Web Token implementation in Python" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +crypto = ["cryptography (>=3.3.1)"] +dev = ["sphinx", "sphinx-rtd-theme", "zope.interface", "cryptography (>=3.3.1)", "pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)", "mypy", "pre-commit"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -295,7 +330,7 @@ python-versions = ">=3.7" [[package]] name = "typed-ast" -version = "1.5.3" +version = "1.5.4" description = "a fork of Python 2 and 3 ast modules with type comment support" category = "dev" optional = false @@ -303,9 +338,9 @@ python-versions = ">=3.6" [[package]] name = "typing-extensions" -version = "4.2.0" +version = "4.3.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -313,7 +348,7 @@ python-versions = ">=3.7" name = "zipp" version = "3.8.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -324,58 +359,16 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "9a8934a880c7e31bf7dc9673ee9a9eafe4111ec26ef98298cbe20aa2b7533b52" +content-hash = "677882d16ec4ae384857a2ca2e9e0f55add6cb9691fc0763647051006f692e7a" [metadata.files] -atomicwrites = [ - {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, - {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, -] -attrs = [ - {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, - {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, -] -black = [ - {file = "black-22.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:2497f9c2386572e28921fa8bec7be3e51de6801f7459dffd6e62492531c47e09"}, - {file = "black-22.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5795a0375eb87bfe902e80e0c8cfaedf8af4d49694d69161e5bd3206c18618bb"}, - {file = "black-22.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3556168e2e5c49629f7b0f377070240bd5511e45e25a4497bb0073d9dda776a"}, - {file = "black-22.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67c8301ec94e3bcc8906740fe071391bce40a862b7be0b86fb5382beefecd968"}, - {file = "black-22.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:fd57160949179ec517d32ac2ac898b5f20d68ed1a9c977346efbac9c2f1e779d"}, - {file = "black-22.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cc1e1de68c8e5444e8f94c3670bb48a2beef0e91dddfd4fcc29595ebd90bb9ce"}, - {file = "black-22.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2fc92002d44746d3e7db7cf9313cf4452f43e9ea77a2c939defce3b10b5c82"}, - {file = "black-22.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:a6342964b43a99dbc72f72812bf88cad8f0217ae9acb47c0d4f141a6416d2d7b"}, - {file = "black-22.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:328efc0cc70ccb23429d6be184a15ce613f676bdfc85e5fe8ea2a9354b4e9015"}, - {file = "black-22.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06f9d8846f2340dfac80ceb20200ea5d1b3f181dd0556b47af4e8e0b24fa0a6b"}, - {file = "black-22.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4efa5fad66b903b4a5f96d91461d90b9507a812b3c5de657d544215bb7877a"}, - {file = "black-22.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8477ec6bbfe0312c128e74644ac8a02ca06bcdb8982d4ee06f209be28cdf163"}, - {file = "black-22.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:637a4014c63fbf42a692d22b55d8ad6968a946b4a6ebc385c5505d9625b6a464"}, - {file = "black-22.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:863714200ada56cbc366dc9ae5291ceb936573155f8bf8e9de92aef51f3ad0f0"}, - {file = "black-22.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10dbe6e6d2988049b4655b2b739f98785a884d4d6b85bc35133a8fb9a2233176"}, - {file = "black-22.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:cee3e11161dde1b2a33a904b850b0899e0424cc331b7295f2a9698e79f9a69a0"}, - {file = "black-22.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5891ef8abc06576985de8fa88e95ab70641de6c1fca97e2a15820a9b69e51b20"}, - {file = "black-22.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:30d78ba6bf080eeaf0b7b875d924b15cd46fec5fd044ddfbad38c8ea9171043a"}, - {file = "black-22.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee8f1f7228cce7dffc2b464f07ce769f478968bfb3dd1254a4c2eeed84928aad"}, - {file = "black-22.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ee227b696ca60dd1c507be80a6bc849a5a6ab57ac7352aad1ffec9e8b805f21"}, - {file = "black-22.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:9b542ced1ec0ceeff5b37d69838106a6348e60db7b8fdd245294dc1d26136265"}, - {file = "black-22.3.0-py3-none-any.whl", hash = "sha256:bc58025940a896d7e5356952228b68f793cf5fcb342be703c3a2669a1488cb72"}, - {file = "black-22.3.0.tar.gz", hash = "sha256:35020b8886c022ced9282b51b5a875b6d1ab0c387b31a065b84db7c33085ca79"}, -] -click = [ - {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, - {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, -] -colorama = [ - {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, - {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, -] -importlib-metadata = [ - {file = "importlib_metadata-4.11.3-py3-none-any.whl", hash = "sha256:1208431ca90a8cca1a6b8af391bb53c1a2db74e5d1cef6ddced95d4b2062edc6"}, - {file = "importlib_metadata-4.11.3.tar.gz", hash = "sha256:ea4c597ebf37142f827b8f39299579e31685c31d3a438b59f469406afd0f2539"}, -] -iniconfig = [ - {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, - {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, -] +atomicwrites = [] +attrs = [] +black = [] +click = [] +colorama = [] +importlib-metadata = [] +iniconfig = [] mypy = [ {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, @@ -401,175 +394,27 @@ mypy = [ {file = "mypy-0.950-py3-none-any.whl", hash = "sha256:a4d9898f46446bfb6405383b57b96737dcfd0a7f25b748e78ef3e8c576bba3cb"}, {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, ] -mypy-extensions = [ - {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, - {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, -] -numpy = [ - {file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"}, - {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"}, - {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"}, - {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"}, - {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"}, - {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"}, - {file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"}, - {file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"}, - {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"}, - {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"}, - {file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"}, - {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"}, - {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"}, - {file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"}, - {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"}, - {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"}, - {file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"}, - {file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"}, - {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"}, - {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"}, - {file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"}, - {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"}, - {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"}, - {file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"}, - {file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"}, - {file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"}, - {file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"}, - {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, -] -packaging = [ - {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, - {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, -] -pandas = [ - {file = "pandas-1.3.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:62d5b5ce965bae78f12c1c0df0d387899dd4211ec0bdc52822373f13a3a022b9"}, - {file = "pandas-1.3.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:adfeb11be2d54f275142c8ba9bf67acee771b7186a5745249c7d5a06c670136b"}, - {file = "pandas-1.3.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a8c055d58873ad81cae290d974d13dd479b82cbb975c3e1fa2cf1920715296"}, - {file = "pandas-1.3.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd541ab09e1f80a2a1760032d665f6e032d8e44055d602d65eeea6e6e85498cb"}, - {file = "pandas-1.3.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2651d75b9a167cc8cc572cf787ab512d16e316ae00ba81874b560586fa1325e0"}, - {file = "pandas-1.3.5-cp310-cp310-win_amd64.whl", hash = "sha256:aaf183a615ad790801fa3cf2fa450e5b6d23a54684fe386f7e3208f8b9bfbef6"}, - {file = "pandas-1.3.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:344295811e67f8200de2390093aeb3c8309f5648951b684d8db7eee7d1c81fb7"}, - {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:552020bf83b7f9033b57cbae65589c01e7ef1544416122da0c79140c93288f56"}, - {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cce0c6bbeb266b0e39e35176ee615ce3585233092f685b6a82362523e59e5b4"}, - {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d28a3c65463fd0d0ba8bbb7696b23073efee0510783340a44b08f5e96ffce0c"}, - {file = "pandas-1.3.5-cp37-cp37m-win32.whl", hash = "sha256:a62949c626dd0ef7de11de34b44c6475db76995c2064e2d99c6498c3dba7fe58"}, - {file = "pandas-1.3.5-cp37-cp37m-win_amd64.whl", hash = "sha256:8025750767e138320b15ca16d70d5cdc1886e8f9cc56652d89735c016cd8aea6"}, - {file = "pandas-1.3.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fe95bae4e2d579812865db2212bb733144e34d0c6785c0685329e5b60fcb85dd"}, - {file = "pandas-1.3.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f261553a1e9c65b7a310302b9dbac31cf0049a51695c14ebe04e4bfd4a96f02"}, - {file = "pandas-1.3.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b6dbec5f3e6d5dc80dcfee250e0a2a652b3f28663492f7dab9a24416a48ac39"}, - {file = "pandas-1.3.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3bc49af96cd6285030a64779de5b3688633a07eb75c124b0747134a63f4c05f"}, - {file = "pandas-1.3.5-cp38-cp38-win32.whl", hash = "sha256:b6b87b2fb39e6383ca28e2829cddef1d9fc9e27e55ad91ca9c435572cdba51bf"}, - {file = "pandas-1.3.5-cp38-cp38-win_amd64.whl", hash = "sha256:a395692046fd8ce1edb4c6295c35184ae0c2bbe787ecbe384251da609e27edcb"}, - {file = "pandas-1.3.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bd971a3f08b745a75a86c00b97f3007c2ea175951286cdda6abe543e687e5f2f"}, - {file = "pandas-1.3.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37f06b59e5bc05711a518aa10beaec10942188dccb48918bb5ae602ccbc9f1a0"}, - {file = "pandas-1.3.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c21778a688d3712d35710501f8001cdbf96eb70a7c587a3d5613573299fdca6"}, - {file = "pandas-1.3.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3345343206546545bc26a05b4602b6a24385b5ec7c75cb6059599e3d56831da2"}, - {file = "pandas-1.3.5-cp39-cp39-win32.whl", hash = "sha256:c69406a2808ba6cf580c2255bcf260b3f214d2664a3a4197d0e640f573b46fd3"}, - {file = "pandas-1.3.5-cp39-cp39-win_amd64.whl", hash = "sha256:32e1a26d5ade11b547721a72f9bfc4bd113396947606e00d5b4a5b79b3dcb006"}, - {file = "pandas-1.3.5.tar.gz", hash = "sha256:1e4285f5de1012de20ca46b188ccf33521bff61ba5c5ebd78b4fb28e5416a9f1"}, -] -pathspec = [ - {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, - {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, -] -platformdirs = [ - {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"}, - {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"}, -] -pluggy = [ - {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, - {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, -] -py = [ - {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, - {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, -] -pyarrow = [ - {file = "pyarrow-5.0.0-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:e9ec80f4a77057498cf4c5965389e42e7f6a618b6859e6dd615e57505c9167a6"}, - {file = "pyarrow-5.0.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b1453c2411b5062ba6bf6832dbc4df211ad625f678c623a2ee177aee158f199b"}, - {file = "pyarrow-5.0.0-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:9e04d3621b9f2f23898eed0d044203f66c156d880f02c5534a7f9947ebb1a4af"}, - {file = "pyarrow-5.0.0-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:64f30aa6b28b666a925d11c239344741850eb97c29d3aa0f7187918cf82494f7"}, - {file = "pyarrow-5.0.0-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:99c8b0f7e2ce2541dd4c0c0101d9944bb8e592ae3295fe7a2f290ab99222666d"}, - {file = "pyarrow-5.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:456a4488ae810a0569d1adf87dbc522bcc9a0e4a8d1809b934ca28c163d8edce"}, - {file = "pyarrow-5.0.0-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:c5493d2414d0d690a738aac8dd6d38518d1f9b870e52e24f89d8d7eb3afd4161"}, - {file = "pyarrow-5.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1832709281efefa4f199c639e9f429678286329860188e53beeda71750775923"}, - {file = "pyarrow-5.0.0-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:b6387d2058d95fa48ccfedea810a768187affb62f4a3ef6595fa30bf9d1a65cf"}, - {file = "pyarrow-5.0.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bbe2e439bec2618c74a3bb259700c8a7353dc2ea0c5a62686b6cf04a50ab1e0d"}, - {file = "pyarrow-5.0.0-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:5c0d1b68e67bb334a5af0cecdf9b6a702aaa4cc259c5cbb71b25bbed40fcedaf"}, - {file = "pyarrow-5.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:6e937ce4a40ea0cc7896faff96adecadd4485beb53fbf510b46858e29b2e75ae"}, - {file = "pyarrow-5.0.0-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:7560332e5846f0e7830b377c14c93624e24a17f91c98f0b25dafb0ca1ea6ba02"}, - {file = "pyarrow-5.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:53e550dec60d1ab86cba3afa1719dc179a8bc9632a0e50d9fe91499cf0a7f2bc"}, - {file = "pyarrow-5.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2d26186ca9748a1fb89ae6c1fa04fb343a4279b53f118734ea8096f15d66c820"}, - {file = "pyarrow-5.0.0-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:7c4edd2bacee3eea6c8c28bddb02347f9d41a55ec9692c71c6de6e47c62a7f0d"}, - {file = "pyarrow-5.0.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:601b0aabd6fb066429e706282934d4d8d38f53bdb8d82da9576be49f07eedf5c"}, - {file = "pyarrow-5.0.0-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:ff21711f6ff3b0bc90abc8ca8169e676faeb2401ddc1a0bc1c7dc181708a3406"}, - {file = "pyarrow-5.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:ed135a99975380c27077f9d0e210aea8618ed9fadcec0e71f8a3190939557afe"}, - {file = "pyarrow-5.0.0-cp39-cp39-macosx_10_13_universal2.whl", hash = "sha256:6e1f0e4374061116f40e541408a8a170c170d0a070b788717e18165ebfdd2a54"}, - {file = "pyarrow-5.0.0-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:4341ac0f552dc04c450751e049976940c7f4f8f2dae03685cc465ebe0a61e231"}, - {file = "pyarrow-5.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c3fc856f107ca2fb3c9391d7ea33bbb33f3a1c2b4a0e2b41f7525c626214cc03"}, - {file = "pyarrow-5.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:357605665fbefb573d40939b13a684c2490b6ed1ab4a5de8dd246db4ab02e5a4"}, - {file = "pyarrow-5.0.0-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f4db312e9ba80e730cefcae0a05b63ea5befc7634c28df56682b628ad8e1c25c"}, - {file = "pyarrow-5.0.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:1d9485741e497ccc516cb0a0c8f56e22be55aea815be185c3f9a681323b0e614"}, - {file = "pyarrow-5.0.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:b3115df938b8d7a7372911a3cb3904196194bcea8bb48911b4b3eafee3ab8d90"}, - {file = "pyarrow-5.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:4d8adda1892ef4553c4804af7f67cce484f4d6371564e2d8374b8e2bc85293e2"}, - {file = "pyarrow-5.0.0.tar.gz", hash = "sha256:24e64ea33eed07441cc0e80c949e3a1b48211a1add8953268391d250f4d39922"}, -] -pyparsing = [ - {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, - {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, -] -pytest = [ - {file = "pytest-7.1.2-py3-none-any.whl", hash = "sha256:13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c"}, - {file = "pytest-7.1.2.tar.gz", hash = "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45"}, -] +mypy-extensions = [] +numpy = [] +oauthlib = [] +packaging = [] +pandas = [] +pathspec = [] +platformdirs = [] +pluggy = [] +py = [] +pyarrow = [] +pyjwt = [] +pyparsing = [] +pytest = [] python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] -pytz = [ - {file = "pytz-2022.1-py2.py3-none-any.whl", hash = "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c"}, - {file = "pytz-2022.1.tar.gz", hash = "sha256:1e760e2fe6a8163bc0b3d9a19c4f84342afa0a2affebfaa84b01b978a02ecaa7"}, -] -six = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, -] -thrift = [ - {file = "thrift-0.13.0.tar.gz", hash = "sha256:9af1c86bf73433afc6010ed376a6c6aca2b54099cc0d61895f640870a9ae7d89"}, -] -tomli = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, -] -typed-ast = [ - {file = "typed_ast-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ad3b48cf2b487be140072fb86feff36801487d4abb7382bb1929aaac80638ea"}, - {file = "typed_ast-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:542cd732351ba8235f20faa0fc7398946fe1a57f2cdb289e5497e1e7f48cfedb"}, - {file = "typed_ast-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc2c11ae59003d4a26dda637222d9ae924387f96acae9492df663843aefad55"}, - {file = "typed_ast-1.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fd5df1313915dbd70eaaa88c19030b441742e8b05e6103c631c83b75e0435ccc"}, - {file = "typed_ast-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:e34f9b9e61333ecb0f7d79c21c28aa5cd63bec15cb7e1310d7d3da6ce886bc9b"}, - {file = "typed_ast-1.5.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f818c5b81966d4728fec14caa338e30a70dfc3da577984d38f97816c4b3071ec"}, - {file = "typed_ast-1.5.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3042bfc9ca118712c9809201f55355479cfcdc17449f9f8db5e744e9625c6805"}, - {file = "typed_ast-1.5.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4fff9fdcce59dc61ec1b317bdb319f8f4e6b69ebbe61193ae0a60c5f9333dc49"}, - {file = "typed_ast-1.5.3-cp36-cp36m-win_amd64.whl", hash = "sha256:8e0b8528838ffd426fea8d18bde4c73bcb4167218998cc8b9ee0a0f2bfe678a6"}, - {file = "typed_ast-1.5.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8ef1d96ad05a291f5c36895d86d1375c0ee70595b90f6bb5f5fdbee749b146db"}, - {file = "typed_ast-1.5.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed44e81517364cb5ba367e4f68fca01fba42a7a4690d40c07886586ac267d9b9"}, - {file = "typed_ast-1.5.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f60d9de0d087454c91b3999a296d0c4558c1666771e3460621875021bf899af9"}, - {file = "typed_ast-1.5.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9e237e74fd321a55c90eee9bc5d44be976979ad38a29bbd734148295c1ce7617"}, - {file = "typed_ast-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ee852185964744987609b40aee1d2eb81502ae63ee8eef614558f96a56c1902d"}, - {file = "typed_ast-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:27e46cdd01d6c3a0dd8f728b6a938a6751f7bd324817501c15fb056307f918c6"}, - {file = "typed_ast-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d64dabc6336ddc10373922a146fa2256043b3b43e61f28961caec2a5207c56d5"}, - {file = "typed_ast-1.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8cdf91b0c466a6c43f36c1964772918a2c04cfa83df8001ff32a89e357f8eb06"}, - {file = "typed_ast-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:9cc9e1457e1feb06b075c8ef8aeb046a28ec351b1958b42c7c31c989c841403a"}, - {file = "typed_ast-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e20d196815eeffb3d76b75223e8ffed124e65ee62097e4e73afb5fec6b993e7a"}, - {file = "typed_ast-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:37e5349d1d5de2f4763d534ccb26809d1c24b180a477659a12c4bde9dd677d74"}, - {file = "typed_ast-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9f1a27592fac87daa4e3f16538713d705599b0a27dfe25518b80b6b017f0a6d"}, - {file = "typed_ast-1.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8831479695eadc8b5ffed06fdfb3e424adc37962a75925668deeb503f446c0a3"}, - {file = "typed_ast-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:20d5118e494478ef2d3a2702d964dae830aedd7b4d3b626d003eea526be18718"}, - {file = "typed_ast-1.5.3.tar.gz", hash = "sha256:27f25232e2dd0edfe1f019d6bfaaf11e86e657d9bdb7b0956db95f560cceb2b3"}, -] -typing-extensions = [ - {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"}, - {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, -] -zipp = [ - {file = "zipp-3.8.0-py3-none-any.whl", hash = "sha256:c4f6e5bbf48e74f7a38e7cc5b0480ff42b0ae5178957d564d18932525d5cf099"}, - {file = "zipp-3.8.0.tar.gz", hash = "sha256:56bf8aadb83c24db6c4b577e13de374ccfb67da2078beba1d037c17980bf43ad"}, -] +pytz = [] +six = [] +thrift = [] +tomli = [] +typed-ast = [] +typing-extensions = [] +zipp = [] diff --git a/pyproject.toml b/pyproject.toml index de9160deb..4a5e9e766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,9 @@ python = "^3.7.1" thrift = "^0.13.0" pyarrow = "^5.0.0" pandas = "^1.3.0" +click=">=7.0" +pyjwt=">=1.7.0" +oauthlib=">=3.1.0" [tool.poetry.dev-dependencies] pytest = "^7.1.2" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 8f67d4659..ae67264f7 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -44,7 +44,6 @@ def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) -def connect(server_hostname, http_path, access_token, **kwargs): +def connect(server_hostname, http_path, **kwargs): from .client import Connection - - return Connection(server_hostname, http_path, access_token, **kwargs) + return Connection(server_hostname, http_path, **kwargs) diff --git a/src/databricks/sql/auth/__init__.py b/src/databricks/sql/auth/__init__.py new file mode 100644 index 000000000..1ad9b7875 --- /dev/null +++ b/src/databricks/sql/auth/__init__.py @@ -0,0 +1,22 @@ +# Databricks CLI +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py new file mode 100644 index 000000000..9bd99ee1b --- /dev/null +++ b/src/databricks/sql/auth/auth.py @@ -0,0 +1,103 @@ +# Databricks CLI +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from databricks.sql.auth.authenticators import Authenticator, \ + AccessTokenAuthenticator, UserPassAuthenticator, OAuthAuthenticator + + +class AuthType(Enum): + DATABRICKS_OAUTH = "databricks-oauth" + # other supported types (access_token, user/pass) can be inferred + # we can add more types as needed later + + +class ClientContext: + def __init__(self, + hostname: str, + username: str = None, + password: str = None, + token: str = None, + auth_type: str = None, + oauth_scopes: str = None, + oauth_client_id: str = None, + use_cert_as_auth: str = None, + tls_client_cert_file: str = None): + self.hostname = hostname + self.username = username + self.password = password + self.token = token + self.auth_type = auth_type + self.oauth_scopes = oauth_scopes + self.oauth_client_id = oauth_client_id + self.use_cert_as_auth = use_cert_as_auth + self.tls_client_cert_file = tls_client_cert_file + + +class SqlConnectorClientContext(ClientContext): + def __init__(self, + hostname: str, + username: str = None, + password: str = None, + access_token: str = None, + auth_type: str = None, + use_cert_as_auth: str = None, + tls_client_cert_file: str = None): + super().__init__(oauth_scopes="sql offline_access", + # to be changed once registered on the service side + oauth_client_id="databricks-cli", + hostname=hostname, + username=username, + password=password, + token=access_token, + auth_type=auth_type, + use_cert_as_auth=use_cert_as_auth, + tls_client_cert_file=tls_client_cert_file) + + +def get_authenticator(cfg: ClientContext): + if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: + return OAuthAuthenticator(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes) + elif cfg.token is not None: + return AccessTokenAuthenticator(cfg.token) + elif cfg.username is not None and cfg.password is not None: + return UserPassAuthenticator(cfg.username, cfg.password) + elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: + # no op authenticator. authentication is performed using ssl certificate outside of headers + return Authenticator() + else: + raise RuntimeError("No valid authentication settings!") + + +def get_python_sql_connector_authenticator(hostname: str, **kwargs): + + cfg = SqlConnectorClientContext(hostname=hostname, + auth_type=kwargs.get("auth_type"), + access_token=kwargs.get("access_token"), + username=kwargs.get("_username"), + password=kwargs.get("_password"), + use_cert_as_auth=kwargs.get("_use_cert_as_auth"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file")) + return get_authenticator(cfg) + + diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py new file mode 100644 index 000000000..b47e46f9a --- /dev/null +++ b/src/databricks/sql/auth/authenticators.py @@ -0,0 +1,78 @@ +# Databricks CLI +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from databricks.sql.auth.oauth import get_tokens, check_and_refresh_access_token +import base64 + + +class Authenticator: + def add_auth_token(self, request_headers): + pass + + +class AccessTokenAuthenticator(Authenticator): + def __init__(self, access_token): + self.__authorization_header_value = "Bearer {}".format(access_token) + + def add_auth_token(self, request_headers): + request_headers['Authorization'] = self.__authorization_header_value + + +class UserPassAuthenticator(Authenticator): + def __init__(self, username, password): + auth_credentials = "{username}:{password}".format(username, password).encode("UTF-8") + auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") + + self.__authorization_header_value = "Basic {}".format(auth_credentials_base64) + + def add_auth_token(self, request_headers): + request_headers['Authorization'] = self.__authorization_header_value + + +class OAuthAuthenticator(Authenticator): + # TODO: moderakh the refresh_token is only kept in memory. not saved on disk + # hence if application restarts the user may need to re-authenticate + # I will add support for this outside of the scope of current PR. + def __init__(self, hostname, client_id, scopes): + self._hostname = self._normalize_host_name(hostname=hostname) + self._scope = scopes + access_token, refresh_token = get_tokens(hostname=self._hostname, client_id=client_id, scope=self._scope) + self._access_token = access_token + self._refresh_token = refresh_token + + def add_auth_token(self, request_headers): + check_and_refresh_access_token(hostname=self._hostname, + access_token=self._access_token, + refresh_token=self._refresh_token) + request_headers['Authorization'] = "Bearer {}".format(self._access_token) + + @staticmethod + def _normalize_host_name(hostname): + maybe_scheme = "https://" if not hostname.startswith("https://") else "" + maybe_trailing_slash = "/" if not hostname.endswith("/") else "" + return "{scheme}{host}{trailing}".format( + scheme=maybe_scheme, host=hostname, trailing=maybe_trailing_slash) + + + diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py new file mode 100644 index 000000000..4e474103b --- /dev/null +++ b/src/databricks/sql/auth/oauth.py @@ -0,0 +1,286 @@ +# Databricks CLI +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import base64 +import hashlib +import json +import os +import webbrowser + +from datetime import datetime, timedelta, tzinfo + +import click + +import jwt +from jwt import PyJWTError + +import oauthlib.oauth2 +from oauthlib.oauth2.rfc6749.errors import OAuth2Error + +import requests +from requests.exceptions import RequestException + +from databricks_cli.utils import error_and_quit + +try: + from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +except ImportError: + from http.server import BaseHTTPRequestHandler, HTTPServer + + +# This could use 'import secrets' in Python 3 +def token_urlsafe(nbytes=32): + tok = os.urandom(nbytes) + return base64.urlsafe_b64encode(tok).rstrip(b'=').decode('ascii') + + +# This could be datetime.timezone.utc in Python 3 +class UTCTimeZone(tzinfo): + """UTC""" + def utcoffset(self, dt): + #pylint: disable=unused-argument + return timedelta(0) + + def tzname(self, dt): + #pylint: disable=unused-argument + return "UTC" + + def dst(self, dt): + #pylint: disable=unused-argument + return timedelta(0) + + +# Some constant values +OIDC_REDIRECTOR_PATH = "oidc" +REDIRECT_PORT = 8020 +UTC = UTCTimeZone() + + +def get_client(client_id): + return oauthlib.oauth2.WebApplicationClient(client_id) + + +def get_redirect_url(port=REDIRECT_PORT): + return "http://localhost:{port}".format(port=port) + + +def fetch_well_known_config(idp_url): + known_config_url = "{idp_url}/.well-known/oauth-authorization-server".format(idp_url=idp_url) + try: + response = requests.request(method="GET", url=known_config_url) + except RequestException: + error_and_quit("Unable to fetch OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account.".format(idp_url=idp_url)) + + if response.status_code != 200: + error_and_quit("Received status {status} OAuth configuration from " + "{idp_url}.\n Verify it is a valid workspace URL and " + "that OAuth is enabled on this account." + .format(status=response.status_code, idp_url=idp_url)) + try: + return json.loads(response.text) + except json.decoder.JSONDecodeError: + error_and_quit("Unable to decode OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account.".format(idp_url=idp_url)) + + +def get_idp_url(host): + maybe_scheme = "https://" if not host.startswith("https://") else "" + maybe_trailing_slash = "/" if not host.endswith("/") else "" + return "{scheme}{host}{trailing}{path}".format( + scheme=maybe_scheme, host=host, trailing=maybe_trailing_slash, path=OIDC_REDIRECTOR_PATH) + + +def get_challenge(verifier_string=token_urlsafe(32)): + digest = hashlib.sha256(verifier_string.encode('UTF-8')).digest() + challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace('=', '') + return verifier_string, challenge_string + + +# This is a janky global that is used to store the path of the single request the HTTP server +# will receive. +global_request_path = None + + +def set_request_path(path): + global global_request_path + global_request_path = path + + +class SingleRequestHandler(BaseHTTPRequestHandler): + RESPONSE_BODY = """ + + Close this Tab + + + +

Please close this tab.

+

+ The Databricks Python Sql Connector received a response. You may close this tab. +

+ +""".encode("utf-8") + + def do_GET(self): # nopep8 + self.send_response(200, "Success") + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(self.RESPONSE_BODY) + set_request_path(self.path) + + def log_message(self, format, *args): + #pylint: disable=redefined-builtin + #pylint: disable=unused-argument + return + + +def get_authorization_code(client, auth_url, redirect_url, scope, state, challenge, port): + #pylint: disable=unused-variable + (auth_req_uri, headers, body) = client.prepare_authorization_request( + authorization_url=auth_url, + redirect_url=redirect_url, + scope=scope, + state=state, + code_challenge=challenge, + code_challenge_method="S256") + click.echo("Opening {uri}".format(uri=auth_req_uri)) + + with HTTPServer(("", port), SingleRequestHandler) as httpd: + webbrowser.open_new(auth_req_uri) + click.echo("Listening for OAuth authorization callback at {uri}" + .format(uri=redirect_url)) + httpd.handle_request() + + if not global_request_path: + error_and_quit("No path parameters were returned to the callback at {uri}" + .format(uri=redirect_url)) + # This is a kludge because the parsing library expects https callbacks + # We should probably set it up using https + full_redirect_url = "https://localhost:{port}/{path}".format( + port=port, path=global_request_path) + try: + authorization_code_response = \ + client.parse_request_uri_response(full_redirect_url, state=state) + except OAuth2Error as err: + error_and_quit("OAuth Token Request error {error}".format(error=err.description)) + return authorization_code_response + + +def send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier): + token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url) + data = "{body}&code_verifier={verifier}".format(body=token_request_body, verifier=verifier) + return send_token_request(token_request_url, data) + + +def send_token_request(token_request_url, data): + headers = { + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded" + } + response = requests.request(method="POST", url=token_request_url, data=data, headers=headers) + oauth_response = json.loads(response.text) + return oauth_response + + +def send_refresh_token_request(hostname, client_id, refresh_token): + idp_url = get_idp_url(hostname) + oauth_config = fetch_well_known_config(idp_url) + token_request_url = oauth_config['token_endpoint'] + client = get_client(client_id=client_id) + token_request_body = client.prepare_refresh_body( + refresh_token=refresh_token, client_id=client.client_id) + return send_token_request(token_request_url, token_request_body) + + +def get_tokens_from_response(oauth_response): + access_token = oauth_response['access_token'] + refresh_token = oauth_response['refresh_token'] if 'refresh_token' in oauth_response else None + return access_token, refresh_token + + +def check_and_refresh_access_token(hostname, access_token, refresh_token): + now = datetime.now(tz=UTC) + # If we can't decode an expiration time, this will be expired by default. + expiration_time = now + try: + # This token has already been verified and we are just parsing it. + # If it has been tampered with, it will be rejected on the server side. + # This avoids having to fetch the public key from the issuer and perform + # an unnecessary signature verification. + decoded = jwt.decode(access_token, options={"verify_signature": False}) + expiration_time = datetime.fromtimestamp(decoded['exp'], tz=UTC) + except PyJWTError as err: + error_and_quit(err) + + if expiration_time > now: + # The access token is fine. Just return it. + return access_token, refresh_token, False + + if not refresh_token: + error_and_quit("OAuth access token expired on {expiration_time}." + .format(expiration_time=expiration_time)) + + # Try to refresh using the refresh token + click.echo("Attempting to refresh OAuth access token that expired on {expiration_time}" + .format(expiration_time=expiration_time)) + oauth_response = send_refresh_token_request(hostname, refresh_token) + fresh_access_token, fresh_refresh_token = get_tokens_from_response(oauth_response) + return fresh_access_token, fresh_refresh_token, True + + +def get_tokens(hostname, client_id, scope=None): + idp_url = get_idp_url(hostname) + oauth_config = fetch_well_known_config(idp_url) + # We are going to override oauth_config["authorization_endpoint"] use the + # /oidc redirector on the hostname, which may inject additional parameters. + auth_url = "{}oidc/v1/authorize".format(hostname) + state = token_urlsafe(16) + (verifier, challenge) = get_challenge() + client = get_client(client_id) + redirect_url = get_redirect_url() + try: + auth_response = get_authorization_code( + client, + auth_url, + redirect_url, + scope, + state, + challenge, + REDIRECT_PORT) + except OAuth2Error as err: + error_and_quit("OAuth Authorization Error: {error}".format(error=err.description)) + + token_request_url = oauth_config["token_endpoint"] + code = auth_response['code'] + oauth_response = \ + send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier) + return get_tokens_from_response(oauth_response) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py new file mode 100644 index 000000000..f5d818b74 --- /dev/null +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -0,0 +1,48 @@ +# Databricks CLI +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import thrift.transport.THttpClient + +logger = logging.getLogger(__name__) + + +class THttpClient(thrift.transport.THttpClient.THttpClient): + + def __init__(self, authenticator, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): + super().__init__(uri_or_host, port, path, cafile, cert_file, key_file, ssl_context) + self.__authenticator = authenticator + + def setCustomHeaders(self, headers): + self._headers = headers + super().setCustomHeaders(headers) + + def flush(self): + # TODO retry behaviour + # TODO locking for multiple concurrent refresh token + req = {} + self.__authenticator.add_auth_token(req) + self._headers['Authorization'] = req['Authorization'] + self.setCustomHeaders(self._headers) + super().flush() \ No newline at end of file diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 717980af4..28f620243 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,19 +1,15 @@ -import base64 -import datetime -from decimal import Decimal -import logging -import re from typing import Dict, Tuple, List, Optional, Any, Union import pandas import pyarrow -from databricks.sql import USER_AGENT_NAME, __version__ +from databricks.sql import __version__ from databricks.sql import * from databricks.sql.exc import OperationalError from databricks.sql.thrift_backend import ThriftBackend from databricks.sql.utils import ExecuteResponse, ParamEscaper from databricks.sql.types import Row +from databricks.sql.auth.auth import get_python_sql_connector_authenticator logger = logging.getLogger(__name__) @@ -26,7 +22,6 @@ def __init__( self, server_hostname: str, http_path: str, - access_token: str, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Dict[str, Any] = None, catalog: Optional[str] = None, @@ -90,25 +85,7 @@ def __init__( self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) - authorization_header = [] - if kwargs.get("_username") and kwargs.get("_password"): - auth_credentials = "{username}:{password}".format( - username=kwargs.get("_username"), password=kwargs.get("_password") - ).encode("UTF-8") - auth_credentials_base64 = base64.standard_b64encode( - auth_credentials - ).decode("UTF-8") - authorization_header = [ - ("Authorization", "Basic {}".format(auth_credentials_base64)) - ] - elif access_token: - authorization_header = [("Authorization", "Bearer {}".format(access_token))] - elif not ( - kwargs.get("_use_cert_as_auth") and kwargs.get("_tls_client_cert_file") - ): - raise ValueError( - "No valid authentication settings. Please provide an access token." - ) + authenticator = get_python_sql_connector_authenticator(server_hostname, **kwargs) if not kwargs.get("_user_agent_entry"): useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) @@ -117,11 +94,13 @@ def __init__( USER_AGENT_NAME, __version__, kwargs.get("_user_agent_entry") ) - base_headers = [("User-Agent", useragent_header)] + authorization_header + base_headers = [("User-Agent", useragent_header)] + self.thrift_backend = ThriftBackend( self.host, self.port, http_path, + authenticator, (http_headers or []) + base_headers, **kwargs ) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index d812f93b9..db1ca548d 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -1,18 +1,17 @@ from decimal import Decimal -import logging import math import time import threading -from uuid import uuid4 -from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context +from ssl import CERT_NONE, CERT_REQUIRED, create_default_context import pyarrow import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket import thrift.transport.TTransport -from thrift.Thrift import TException +from databricks.sql.auth.thrift_http_client import THttpClient +from databricks.sql.auth.authenticators import Authenticator from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes from databricks.sql import * from databricks.sql.utils import ( @@ -48,7 +47,7 @@ class ThriftBackend: BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( - self, server_hostname: str, port, http_path: str, http_headers, **kwargs + self, server_hostname: str, port, http_path: str, authenticator: Authenticator, http_headers, **kwargs ): # Internal arguments in **kwargs: # _user_agent_entry @@ -126,7 +125,10 @@ def __init__( password=tls_client_cert_key_password, ) - self._transport = thrift.transport.THttpClient.THttpClient( + self._authenticator = authenticator + + self._transport = THttpClient( + authenticator=self._authenticator, uri_or_host=uri, ssl_context=ssl_context, ) From 20e888b0db112cadb4fd776327acbd80e21b7f1e Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 15:45:58 -0700 Subject: [PATCH 03/57] responde to code review comments Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/__init__.py | 1 - src/databricks/sql/auth/auth.py | 12 ++++++++---- src/databricks/sql/auth/authenticators.py | 5 ++--- src/databricks/sql/auth/oauth.py | 1 - src/databricks/sql/auth/thrift_http_client.py | 1 - 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/auth/__init__.py b/src/databricks/sql/auth/__init__.py index 1ad9b7875..2cb62b15f 100644 --- a/src/databricks/sql/auth/__init__.py +++ b/src/databricks/sql/auth/__init__.py @@ -1,4 +1,3 @@ -# Databricks CLI # Copyright 2022 Databricks, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"), except diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 9bd99ee1b..a93d1a27b 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -1,4 +1,3 @@ -# Databricks CLI # Copyright 2022 Databricks, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"), except @@ -21,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List from enum import Enum from databricks.sql.auth.authenticators import Authenticator, \ AccessTokenAuthenticator, UserPassAuthenticator, OAuthAuthenticator @@ -39,7 +39,7 @@ def __init__(self, password: str = None, token: str = None, auth_type: str = None, - oauth_scopes: str = None, + oauth_scopes: List[str] = None, oauth_client_id: str = None, use_cert_as_auth: str = None, tls_client_cert_file: str = None): @@ -55,6 +55,10 @@ def __init__(self, class SqlConnectorClientContext(ClientContext): + OAUTH_SCOPES = ["sql", "offline_access"] + # TODO: moderakh to be changed once registered on the service side + OAUTH_CLIENT_ID = "databricks-cli" + def __init__(self, hostname: str, username: str = None, @@ -63,9 +67,9 @@ def __init__(self, auth_type: str = None, use_cert_as_auth: str = None, tls_client_cert_file: str = None): - super().__init__(oauth_scopes="sql offline_access", + super().__init__(oauth_scopes=SqlConnectorClientContext.OAUTH_SCOPES, # to be changed once registered on the service side - oauth_client_id="databricks-cli", + oauth_client_id=SqlConnectorClientContext.OAUTH_CLIENT_ID, hostname=hostname, username=username, password=password, diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index b47e46f9a..9dc655237 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -1,4 +1,3 @@ -# Databricks CLI # Copyright 2022 Databricks, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"), except @@ -56,8 +55,8 @@ class OAuthAuthenticator(Authenticator): # I will add support for this outside of the scope of current PR. def __init__(self, hostname, client_id, scopes): self._hostname = self._normalize_host_name(hostname=hostname) - self._scope = scopes - access_token, refresh_token = get_tokens(hostname=self._hostname, client_id=client_id, scope=self._scope) + self._scopes_as_str = ''.join(scopes) + access_token, refresh_token = get_tokens(hostname=self._hostname, client_id=client_id, scope=self._scopes_as_str) self._access_token = access_token self._refresh_token = refresh_token diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 4e474103b..d12fbbef0 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -1,4 +1,3 @@ -# Databricks CLI # Copyright 2022 Databricks, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"), except diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f5d818b74..50aefe29b 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,4 +1,3 @@ -# Databricks CLI # Copyright 2022 Databricks, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"), except From cff24d9b28eab9d07996b3de12e860be0eb50551 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 16:12:41 -0700 Subject: [PATCH 04/57] Update src/databricks/sql/auth/authenticators.py Co-authored-by: Serge Smertin <259697+nfx@users.noreply.github.com> Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 9dc655237..fe1215ba3 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -30,6 +30,8 @@ def add_auth_token(self, request_headers): pass +# Private API: this is an evolving interface and it will change in the future. +# Please must not depend on it in your applications. class AccessTokenAuthenticator(Authenticator): def __init__(self, access_token): self.__authorization_header_value = "Bearer {}".format(access_token) From d99af2dd9ddc665c1c51f14e243bcfe772de3ad2 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 16:14:13 -0700 Subject: [PATCH 05/57] Update src/databricks/sql/auth/authenticators.py Co-authored-by: Serge Smertin <259697+nfx@users.noreply.github.com> Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index fe1215ba3..062f3d0f6 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -40,7 +40,9 @@ def add_auth_token(self, request_headers): request_headers['Authorization'] = self.__authorization_header_value -class UserPassAuthenticator(Authenticator): +# Private API: this is an evolving interface and it will change in the future. +# Please must not depend on it in your applications. +class BasicAuthenticator(Authenticator): def __init__(self, username, password): auth_credentials = "{username}:{password}".format(username, password).encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") From 09986ab4f1957736fbf242dba98333ee59f314b5 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 16:14:25 -0700 Subject: [PATCH 06/57] Update src/databricks/sql/auth/authenticators.py Co-authored-by: Serge Smertin <259697+nfx@users.noreply.github.com> Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 062f3d0f6..723254cb4 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -53,7 +53,9 @@ def add_auth_token(self, request_headers): request_headers['Authorization'] = self.__authorization_header_value -class OAuthAuthenticator(Authenticator): +# Private API: this is an evolving interface and it will change in the future. +# Please must not depend on it in your applications. +class DatabricksOAuthAuthenticator(Authenticator): # TODO: moderakh the refresh_token is only kept in memory. not saved on disk # hence if application restarts the user may need to re-authenticate # I will add support for this outside of the scope of current PR. From a703f5861ac6bef06f048b823de5693372927bcb Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 16:42:57 -0700 Subject: [PATCH 07/57] responded to review comments Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 12 ++++++------ src/databricks/sql/auth/authenticators.py | 10 ++++++---- src/databricks/sql/thrift_backend.py | 4 ++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a93d1a27b..bd76f63da 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -22,8 +22,8 @@ from typing import List from enum import Enum -from databricks.sql.auth.authenticators import Authenticator, \ - AccessTokenAuthenticator, UserPassAuthenticator, OAuthAuthenticator +from databricks.sql.auth.authenticators import CredentialsProvider, \ + AccessTokenAuthProvider, BasicAuthProvider, DatabricksOAuthProvider class AuthType(Enum): @@ -81,14 +81,14 @@ def __init__(self, def get_authenticator(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: - return OAuthAuthenticator(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes) + return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes) elif cfg.token is not None: - return AccessTokenAuthenticator(cfg.token) + return AccessTokenAuthProvider(cfg.token) elif cfg.username is not None and cfg.password is not None: - return UserPassAuthenticator(cfg.username, cfg.password) + return BasicAuthProvider(cfg.username, cfg.password) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: # no op authenticator. authentication is performed using ssl certificate outside of headers - return Authenticator() + return CredentialsProvider() else: raise RuntimeError("No valid authentication settings!") diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 723254cb4..3efc801b5 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -25,14 +25,16 @@ import base64 -class Authenticator: +# Private API: this is an evolving interface and it will change in the future. +# Please must not depend on it in your applications. +class CredentialsProvider: def add_auth_token(self, request_headers): pass # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class AccessTokenAuthenticator(Authenticator): +class AccessTokenAuthProvider(CredentialsProvider): def __init__(self, access_token): self.__authorization_header_value = "Bearer {}".format(access_token) @@ -42,7 +44,7 @@ def add_auth_token(self, request_headers): # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class BasicAuthenticator(Authenticator): +class BasicAuthProvider(CredentialsProvider): def __init__(self, username, password): auth_credentials = "{username}:{password}".format(username, password).encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") @@ -55,7 +57,7 @@ def add_auth_token(self, request_headers): # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class DatabricksOAuthAuthenticator(Authenticator): +class DatabricksOAuthProvider(CredentialsProvider): # TODO: moderakh the refresh_token is only kept in memory. not saved on disk # hence if application restarts the user may need to re-authenticate # I will add support for this outside of the scope of current PR. diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index db1ca548d..49e20ab9d 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -11,7 +11,7 @@ import thrift.transport.TTransport from databricks.sql.auth.thrift_http_client import THttpClient -from databricks.sql.auth.authenticators import Authenticator +from databricks.sql.auth.authenticators import CredentialsProvider from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes from databricks.sql import * from databricks.sql.utils import ( @@ -47,7 +47,7 @@ class ThriftBackend: BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( - self, server_hostname: str, port, http_path: str, authenticator: Authenticator, http_headers, **kwargs + self, server_hostname: str, port, http_path: str, authenticator: CredentialsProvider, http_headers, **kwargs ): # Internal arguments in **kwargs: # _user_agent_entry From 70975fece45073ef63313fb09ce8830638438ec3 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 17:13:04 -0700 Subject: [PATCH 08/57] added unit tests for legacy auth providers (PAT, User/Pass) Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 5 +-- tests/test_auth.py | 37 +++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 tests/test_auth.py diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 3efc801b5..9b98ad9e9 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -46,7 +46,7 @@ def add_auth_token(self, request_headers): # Please must not depend on it in your applications. class BasicAuthProvider(CredentialsProvider): def __init__(self, username, password): - auth_credentials = "{username}:{password}".format(username, password).encode("UTF-8") + auth_credentials = "{username}:{password}".format(username=username, password=password).encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") self.__authorization_header_value = "Basic {}".format(auth_credentials_base64) @@ -80,6 +80,3 @@ def _normalize_host_name(hostname): maybe_trailing_slash = "/" if not hostname.endswith("/") else "" return "{scheme}{host}{trailing}".format( scheme=maybe_scheme, host=hostname, trailing=maybe_trailing_slash) - - - diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 000000000..d9796d0bc --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,37 @@ +import unittest + +from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, CredentialsProvider + + +class Auth(unittest.TestCase): + + def test_access_token_provider(self): + access_token = "aBc2" + auth = AccessTokenAuthProvider(access_token=access_token) + + http_request = {'myKey': 'myVal'} + auth.add_auth_token(http_request) + self.assertEqual(http_request['Authorization'], 'Bearer aBc2') + self.assertEqual(len(http_request.keys()), 2) + self.assertEqual(http_request['myKey'], 'myVal') + + def test_basic_auth_provider(self): + username = "moderakh" + password = "Elevate Databricks 123!!!" + auth = BasicAuthProvider(username=username, password=password) + + http_request = {'myKey': 'myVal'} + auth.add_auth_token(http_request) + + self.assertEqual(http_request['Authorization'], 'Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==') + self.assertEqual(len(http_request.keys()), 2) + self.assertEqual(http_request['myKey'], 'myVal') + + def test_noop_auth_provider(self): + auth = CredentialsProvider() + + http_request = {'myKey': 'myVal'} + auth.add_auth_token(http_request) + + self.assertEqual(len(http_request.keys()), 1) + self.assertEqual(http_request['myKey'], 'myVal') \ No newline at end of file From bccf86910b9d74f5e308528dcc09f952e49bc609 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 13 Jul 2022 17:47:33 -0700 Subject: [PATCH 09/57] added more tests Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 6 ++-- src/databricks/sql/client.py | 4 +-- tests/test_auth.py | 55 ++++++++++++++++++++++++++++++++- 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index bd76f63da..87c32aabf 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -79,7 +79,7 @@ def __init__(self, tls_client_cert_file=tls_client_cert_file) -def get_authenticator(cfg: ClientContext): +def get_auth_provider(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes) elif cfg.token is not None: @@ -93,7 +93,7 @@ def get_authenticator(cfg: ClientContext): raise RuntimeError("No valid authentication settings!") -def get_python_sql_connector_authenticator(hostname: str, **kwargs): +def get_python_sql_connector_auth_provider(hostname: str, **kwargs): cfg = SqlConnectorClientContext(hostname=hostname, auth_type=kwargs.get("auth_type"), @@ -102,6 +102,6 @@ def get_python_sql_connector_authenticator(hostname: str, **kwargs): password=kwargs.get("_password"), use_cert_as_auth=kwargs.get("_use_cert_as_auth"), tls_client_cert_file=kwargs.get("_tls_client_cert_file")) - return get_authenticator(cfg) + return get_auth_provider(cfg) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 28f620243..83a2a10ec 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -9,7 +9,7 @@ from databricks.sql.thrift_backend import ThriftBackend from databricks.sql.utils import ExecuteResponse, ParamEscaper from databricks.sql.types import Row -from databricks.sql.auth.auth import get_python_sql_connector_authenticator +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider logger = logging.getLogger(__name__) @@ -85,7 +85,7 @@ def __init__( self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) - authenticator = get_python_sql_connector_authenticator(server_hostname, **kwargs) + authenticator = get_python_sql_connector_auth_provider(server_hostname, **kwargs) if not kwargs.get("_user_agent_entry"): useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) diff --git a/tests/test_auth.py b/tests/test_auth.py index d9796d0bc..b0d90f98b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,29 @@ +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, CredentialsProvider +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider class Auth(unittest.TestCase): @@ -34,4 +57,34 @@ def test_noop_auth_provider(self): auth.add_auth_token(http_request) self.assertEqual(len(http_request.keys()), 1) - self.assertEqual(http_request['myKey'], 'myVal') \ No newline at end of file + self.assertEqual(http_request['myKey'], 'myVal') + + def test_get_python_sql_connector_auth_provider_access_token(self): + hostname = "moderakh-test.cloud.databricks.com" + kwargs = {'access_token': 'dpi123'} + auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") + + headers = {} + auth_provider.add_auth_token(headers) + self.assertEqual(headers['Authorization'], 'Bearer dpi123') + + def test_get_python_sql_connector_auth_provider_username_password(self): + username = "moderakh" + password = "Elevate Databricks 123!!!" + hostname = "moderakh-test.cloud.databricks.com" + kwargs = {'_username': username, '_password': password} + auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + self.assertTrue(type(auth_provider).__name__, "BasicAuthProvider") + + headers = {} + auth_provider.add_auth_token(headers) + self.assertEqual(headers['Authorization'], 'Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==') + + def test_get_python_sql_connector_auth_provider_noop(self): + tls_client_cert_file = "fake.cert" + use_cert_as_auth = "abc" + hostname = "moderakh-test.cloud.databricks.com" + kwargs = {'_tls_client_cert_file': tls_client_cert_file, '_use_cert_as_auth': use_cert_as_auth } + auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + self.assertTrue(type(auth_provider).__name__, "CredentialProvider") From bedc274f4a5e2d3e3298a4aa5f302ca448eac85d Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 08:39:59 -0700 Subject: [PATCH 10/57] replaced click with logging Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 3 +- src/databricks/sql/auth/oauth.py | 67 ++++++++++++++--------- 2 files changed, 42 insertions(+), 28 deletions(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 9b98ad9e9..8345a11c7 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -58,12 +58,13 @@ def add_auth_token(self, request_headers): # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class DatabricksOAuthProvider(CredentialsProvider): + SCOPE_DELIM = ' ' # TODO: moderakh the refresh_token is only kept in memory. not saved on disk # hence if application restarts the user may need to re-authenticate # I will add support for this outside of the scope of current PR. def __init__(self, hostname, client_id, scopes): self._hostname = self._normalize_host_name(hostname=hostname) - self._scopes_as_str = ''.join(scopes) + self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) access_token, refresh_token = get_tokens(hostname=self._hostname, client_id=client_id, scope=self._scopes_as_str) self._access_token = access_token self._refresh_token = refresh_token diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index d12fbbef0..ccf2c7ec3 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -29,7 +29,7 @@ from datetime import datetime, timedelta, tzinfo -import click +import logging import jwt from jwt import PyJWTError @@ -40,13 +40,14 @@ import requests from requests.exceptions import RequestException -from databricks_cli.utils import error_and_quit try: from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer except ImportError: from http.server import BaseHTTPRequestHandler, HTTPServer +logger = logging.getLogger(__name__) + # This could use 'import secrets' in Python 3 def token_urlsafe(nbytes=32): @@ -88,22 +89,26 @@ def fetch_well_known_config(idp_url): known_config_url = "{idp_url}/.well-known/oauth-authorization-server".format(idp_url=idp_url) try: response = requests.request(method="GET", url=known_config_url) - except RequestException: - error_and_quit("Unable to fetch OAuth configuration from {idp_url}.\n" - "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.".format(idp_url=idp_url)) + except RequestException as e: + logger.error("Unable to fetch OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account.".format(idp_url=idp_url)) + raise e if response.status_code != 200: - error_and_quit("Received status {status} OAuth configuration from " - "{idp_url}.\n Verify it is a valid workspace URL and " - "that OAuth is enabled on this account." - .format(status=response.status_code, idp_url=idp_url)) + msg = ("Received status {status} OAuth configuration from " + "{idp_url}.\n Verify it is a valid workspace URL and " + "that OAuth is enabled on this account." + .format(status=response.status_code, idp_url=idp_url)) + logger.error(msg) + raise RuntimeError(msg) try: return json.loads(response.text) - except json.decoder.JSONDecodeError: - error_and_quit("Unable to decode OAuth configuration from {idp_url}.\n" - "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.".format(idp_url=idp_url)) + except json.decoder.JSONDecodeError as e: + logger.error("Unable to decode OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account.".format(idp_url=idp_url)) + raise e def get_idp_url(host): @@ -171,17 +176,19 @@ def get_authorization_code(client, auth_url, redirect_url, scope, state, challen state=state, code_challenge=challenge, code_challenge_method="S256") - click.echo("Opening {uri}".format(uri=auth_req_uri)) + logger.info("Opening {uri}".format(uri=auth_req_uri)) with HTTPServer(("", port), SingleRequestHandler) as httpd: webbrowser.open_new(auth_req_uri) - click.echo("Listening for OAuth authorization callback at {uri}" + logger.info("Listening for OAuth authorization callback at {uri}" .format(uri=redirect_url)) httpd.handle_request() if not global_request_path: - error_and_quit("No path parameters were returned to the callback at {uri}" - .format(uri=redirect_url)) + msg = ("No path parameters were returned to the callback at {uri}" + .format(uri=redirect_url)) + logger.error(msg) + raise RuntimeError(msg) # This is a kludge because the parsing library expects https callbacks # We should probably set it up using https full_redirect_url = "https://localhost:{port}/{path}".format( @@ -189,8 +196,9 @@ def get_authorization_code(client, auth_url, redirect_url, scope, state, challen try: authorization_code_response = \ client.parse_request_uri_response(full_redirect_url, state=state) - except OAuth2Error as err: - error_and_quit("OAuth Token Request error {error}".format(error=err.description)) + except OAuth2Error as e: + logger.error("OAuth Token Request error {error}".format(error=err.description)) + raise e return authorization_code_response @@ -237,19 +245,22 @@ def check_and_refresh_access_token(hostname, access_token, refresh_token): # an unnecessary signature verification. decoded = jwt.decode(access_token, options={"verify_signature": False}) expiration_time = datetime.fromtimestamp(decoded['exp'], tz=UTC) - except PyJWTError as err: - error_and_quit(err) + except PyJWTError as e: + logger.error(e) + raise e if expiration_time > now: # The access token is fine. Just return it. return access_token, refresh_token, False if not refresh_token: - error_and_quit("OAuth access token expired on {expiration_time}." - .format(expiration_time=expiration_time)) + msg = ("OAuth access token expired on {expiration_time}." + .format(expiration_time=expiration_time)) + logger.error(msg) + raise RuntimeError(msg) # Try to refresh using the refresh token - click.echo("Attempting to refresh OAuth access token that expired on {expiration_time}" + logger.debug("Attempting to refresh OAuth access token that expired on {expiration_time}" .format(expiration_time=expiration_time)) oauth_response = send_refresh_token_request(hostname, refresh_token) fresh_access_token, fresh_refresh_token = get_tokens_from_response(oauth_response) @@ -275,8 +286,10 @@ def get_tokens(hostname, client_id, scope=None): state, challenge, REDIRECT_PORT) - except OAuth2Error as err: - error_and_quit("OAuth Authorization Error: {error}".format(error=err.description)) + except OAuth2Error as e: + msg = "OAuth Authorization Error: {error}".format(error=e.description) + logger.error(msg) + raise e token_request_url = oauth_config["token_endpoint"] code = auth_response['code'] From 0bde8fe2df4b8546daa8f8aa2bf9711b4c4dac33 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 10:09:52 -0700 Subject: [PATCH 11/57] responded to code review comments Signed-off-by: Moe Derakhshani --- poetry.lock | 20 ++++++------ pyproject.toml | 1 - src/databricks/sql/auth/auth.py | 55 +++++++++++---------------------- 3 files changed, 28 insertions(+), 48 deletions(-) diff --git a/poetry.lock b/poetry.lock index 782e82735..f04374cbd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -47,7 +47,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -59,7 +59,7 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} name = "colorama" version = "0.4.5" description = "Cross-platform colored terminal text." -category = "main" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" @@ -67,7 +67,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" name = "importlib-metadata" version = "4.12.0" description = "Read metadata from Python packages" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" @@ -125,7 +125,7 @@ python-versions = ">=3.7,<3.11" [[package]] name = "numpy" -version = "1.23.0" +version = "1.23.1" description = "NumPy is the fundamental package for array computing with Python." category = "main" optional = false @@ -340,26 +340,26 @@ python-versions = ">=3.6" name = "typing-extensions" version = "4.3.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" [[package]] name = "zipp" -version = "3.8.0" +version = "3.8.1" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" +category = "dev" optional = false python-versions = ">=3.7" [package.extras] -docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"] +docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "jaraco.tidelift (>=1.4)"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "677882d16ec4ae384857a2ca2e9e0f55add6cb9691fc0763647051006f692e7a" +content-hash = "eecff275015f002e123d81343378de753ae80f33e2a3bd34ea2ad11dbbffeec3" [metadata.files] atomicwrites = [] diff --git a/pyproject.toml b/pyproject.toml index 4a5e9e766..4226d5723 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ python = "^3.7.1" thrift = "^0.13.0" pyarrow = "^5.0.0" pandas = "^1.3.0" -click=">=7.0" pyjwt=">=1.7.0" oauthlib=">=3.1.0" diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 87c32aabf..0b0f8c5f9 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -37,7 +37,7 @@ def __init__(self, hostname: str, username: str = None, password: str = None, - token: str = None, + access_token: str = None, auth_type: str = None, oauth_scopes: List[str] = None, oauth_client_id: str = None, @@ -46,7 +46,7 @@ def __init__(self, self.hostname = hostname self.username = username self.password = password - self.token = token + self.access_token = access_token self.auth_type = auth_type self.oauth_scopes = oauth_scopes self.oauth_client_id = oauth_client_id @@ -54,36 +54,11 @@ def __init__(self, self.tls_client_cert_file = tls_client_cert_file -class SqlConnectorClientContext(ClientContext): - OAUTH_SCOPES = ["sql", "offline_access"] - # TODO: moderakh to be changed once registered on the service side - OAUTH_CLIENT_ID = "databricks-cli" - - def __init__(self, - hostname: str, - username: str = None, - password: str = None, - access_token: str = None, - auth_type: str = None, - use_cert_as_auth: str = None, - tls_client_cert_file: str = None): - super().__init__(oauth_scopes=SqlConnectorClientContext.OAUTH_SCOPES, - # to be changed once registered on the service side - oauth_client_id=SqlConnectorClientContext.OAUTH_CLIENT_ID, - hostname=hostname, - username=username, - password=password, - token=access_token, - auth_type=auth_type, - use_cert_as_auth=use_cert_as_auth, - tls_client_cert_file=tls_client_cert_file) - - def get_auth_provider(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes) - elif cfg.token is not None: - return AccessTokenAuthProvider(cfg.token) + elif cfg.access_token is not None: + return AccessTokenAuthProvider(cfg.access_token) elif cfg.username is not None and cfg.password is not None: return BasicAuthProvider(cfg.username, cfg.password) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: @@ -93,15 +68,21 @@ def get_auth_provider(cfg: ClientContext): raise RuntimeError("No valid authentication settings!") -def get_python_sql_connector_auth_provider(hostname: str, **kwargs): +OAUTH_SCOPES = ["sql", "offline_access"] +# TODO: moderakh to be changed once registered on the service side +OAUTH_CLIENT_ID = "databricks-cli" + - cfg = SqlConnectorClientContext(hostname=hostname, - auth_type=kwargs.get("auth_type"), - access_token=kwargs.get("access_token"), - username=kwargs.get("_username"), - password=kwargs.get("_password"), - use_cert_as_auth=kwargs.get("_use_cert_as_auth"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file")) +def get_python_sql_connector_auth_provider(hostname: str, **kwargs): + cfg = ClientContext(hostname=hostname, + auth_type=kwargs.get("auth_type"), + access_token=kwargs.get("access_token"), + username=kwargs.get("_username"), + password=kwargs.get("_password"), + use_cert_as_auth=kwargs.get("_use_cert_as_auth"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + oauth_scopes=OAUTH_SCOPES, + oauth_client_id=OAUTH_CLIENT_ID) return get_auth_provider(cfg) From a1ffcec71be1d53e1af382601f05e01e145d0025 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 12:08:24 -0700 Subject: [PATCH 12/57] made use of quotes consitent on oauth.py Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/oauth.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index ccf2c7ec3..260bac66c 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -52,7 +52,7 @@ # This could use 'import secrets' in Python 3 def token_urlsafe(nbytes=32): tok = os.urandom(nbytes) - return base64.urlsafe_b64encode(tok).rstrip(b'=').decode('ascii') + return base64.urlsafe_b64encode(tok).rstrip(b"=").decode("ascii") # This could be datetime.timezone.utc in Python 3 @@ -119,8 +119,8 @@ def get_idp_url(host): def get_challenge(verifier_string=token_urlsafe(32)): - digest = hashlib.sha256(verifier_string.encode('UTF-8')).digest() - challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace('=', '') + digest = hashlib.sha256(verifier_string.encode("UTF-8")).digest() + challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") return verifier_string, challenge_string @@ -221,7 +221,7 @@ def send_token_request(token_request_url, data): def send_refresh_token_request(hostname, client_id, refresh_token): idp_url = get_idp_url(hostname) oauth_config = fetch_well_known_config(idp_url) - token_request_url = oauth_config['token_endpoint'] + token_request_url = oauth_config["token_endpoint"] client = get_client(client_id=client_id) token_request_body = client.prepare_refresh_body( refresh_token=refresh_token, client_id=client.client_id) @@ -229,8 +229,8 @@ def send_refresh_token_request(hostname, client_id, refresh_token): def get_tokens_from_response(oauth_response): - access_token = oauth_response['access_token'] - refresh_token = oauth_response['refresh_token'] if 'refresh_token' in oauth_response else None + access_token = oauth_response["access_token"] + refresh_token = oauth_response["refresh_token"] if "refresh_token" in oauth_response else None return access_token, refresh_token @@ -244,7 +244,7 @@ def check_and_refresh_access_token(hostname, access_token, refresh_token): # This avoids having to fetch the public key from the issuer and perform # an unnecessary signature verification. decoded = jwt.decode(access_token, options={"verify_signature": False}) - expiration_time = datetime.fromtimestamp(decoded['exp'], tz=UTC) + expiration_time = datetime.fromtimestamp(decoded["exp"], tz=UTC) except PyJWTError as e: logger.error(e) raise e @@ -292,7 +292,7 @@ def get_tokens(hostname, client_id, scope=None): raise e token_request_url = oauth_config["token_endpoint"] - code = auth_response['code'] + code = auth_response["code"] oauth_response = \ send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier) return get_tokens_from_response(oauth_response) From 681a64bb18e223d43d58b34b4505786077a6b122 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 12:29:14 -0700 Subject: [PATCH 13/57] addressed f string related comments Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/oauth.py | 50 ++++++++++++++------------------ 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 260bac66c..4e947ed6f 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -82,40 +82,39 @@ def get_client(client_id): def get_redirect_url(port=REDIRECT_PORT): - return "http://localhost:{port}".format(port=port) + return f"http://localhost:{port}" def fetch_well_known_config(idp_url): - known_config_url = "{idp_url}/.well-known/oauth-authorization-server".format(idp_url=idp_url) + known_config_url = f"{idp_url}/.well-known/oauth-authorization-server" try: - response = requests.request(method="GET", url=known_config_url) + response = requests.get(url=known_config_url) except RequestException as e: - logger.error("Unable to fetch OAuth configuration from {idp_url}.\n" + logger.error(f"Unable to fetch OAuth configuration from {idp_url}.\n" "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.".format(idp_url=idp_url)) + "enabled on this account.") raise e if response.status_code != 200: - msg = ("Received status {status} OAuth configuration from " - "{idp_url}.\n Verify it is a valid workspace URL and " + msg = (f"Received status {response.status_code} OAuth configuration from " + f"{idp_url}.\n Verify it is a valid workspace URL and " "that OAuth is enabled on this account." - .format(status=response.status_code, idp_url=idp_url)) + ) logger.error(msg) raise RuntimeError(msg) try: return json.loads(response.text) except json.decoder.JSONDecodeError as e: - logger.error("Unable to decode OAuth configuration from {idp_url}.\n" + logger.error(f"Unable to decode OAuth configuration from {idp_url}.\n" "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.".format(idp_url=idp_url)) + "enabled on this account.") raise e def get_idp_url(host): maybe_scheme = "https://" if not host.startswith("https://") else "" maybe_trailing_slash = "/" if not host.endswith("/") else "" - return "{scheme}{host}{trailing}{path}".format( - scheme=maybe_scheme, host=host, trailing=maybe_trailing_slash, path=OIDC_REDIRECTOR_PATH) + return f"{maybe_scheme}{host}{maybe_trailing_slash}{OIDC_REDIRECTOR_PATH}" def get_challenge(verifier_string=token_urlsafe(32)): @@ -176,35 +175,32 @@ def get_authorization_code(client, auth_url, redirect_url, scope, state, challen state=state, code_challenge=challenge, code_challenge_method="S256") - logger.info("Opening {uri}".format(uri=auth_req_uri)) + logger.info(f"Opening {auth_req_uri}") with HTTPServer(("", port), SingleRequestHandler) as httpd: webbrowser.open_new(auth_req_uri) - logger.info("Listening for OAuth authorization callback at {uri}" - .format(uri=redirect_url)) + logger.info(f"Listening for OAuth authorization callback at {redirect_url}") httpd.handle_request() if not global_request_path: - msg = ("No path parameters were returned to the callback at {uri}" - .format(uri=redirect_url)) + msg = f"No path parameters were returned to the callback at {redirect_url}" logger.error(msg) raise RuntimeError(msg) # This is a kludge because the parsing library expects https callbacks # We should probably set it up using https - full_redirect_url = "https://localhost:{port}/{path}".format( - port=port, path=global_request_path) + full_redirect_url = f"https://localhost:{port}/{global_request_path}" try: authorization_code_response = \ client.parse_request_uri_response(full_redirect_url, state=state) except OAuth2Error as e: - logger.error("OAuth Token Request error {error}".format(error=err.description)) + logger.error(f"OAuth Token Request error {e.description}") raise e return authorization_code_response def send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier): token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url) - data = "{body}&code_verifier={verifier}".format(body=token_request_body, verifier=verifier) + data = f"{token_request_body}&code_verifier={verifier}" return send_token_request(token_request_url, data) @@ -213,7 +209,7 @@ def send_token_request(token_request_url, data): "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded" } - response = requests.request(method="POST", url=token_request_url, data=data, headers=headers) + response = requests.post(url=token_request_url, data=data, headers=headers) oauth_response = json.loads(response.text) return oauth_response @@ -254,14 +250,12 @@ def check_and_refresh_access_token(hostname, access_token, refresh_token): return access_token, refresh_token, False if not refresh_token: - msg = ("OAuth access token expired on {expiration_time}." - .format(expiration_time=expiration_time)) + msg = f"OAuth access token expired on {expiration_time}." logger.error(msg) raise RuntimeError(msg) # Try to refresh using the refresh token - logger.debug("Attempting to refresh OAuth access token that expired on {expiration_time}" - .format(expiration_time=expiration_time)) + logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}") oauth_response = send_refresh_token_request(hostname, refresh_token) fresh_access_token, fresh_refresh_token = get_tokens_from_response(oauth_response) return fresh_access_token, fresh_refresh_token, True @@ -272,7 +266,7 @@ def get_tokens(hostname, client_id, scope=None): oauth_config = fetch_well_known_config(idp_url) # We are going to override oauth_config["authorization_endpoint"] use the # /oidc redirector on the hostname, which may inject additional parameters. - auth_url = "{}oidc/v1/authorize".format(hostname) + auth_url = f"{hostname}oidc/v1/authorize" state = token_urlsafe(16) (verifier, challenge) = get_challenge() client = get_client(client_id) @@ -287,7 +281,7 @@ def get_tokens(hostname, client_id, scope=None): challenge, REDIRECT_PORT) except OAuth2Error as e: - msg = "OAuth Authorization Error: {error}".format(error=e.description) + msg = f"OAuth Authorization Error: {e.description}" logger.error(msg) raise e From f8bb7e971754a5d2c1e4c2dba376b91cc2ec5ad9 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 12:39:10 -0700 Subject: [PATCH 14/57] removed client method Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/oauth.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 4e947ed6f..04f34db65 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -23,7 +23,6 @@ import base64 import hashlib -import json import os import webbrowser @@ -77,10 +76,6 @@ def dst(self, dt): UTC = UTCTimeZone() -def get_client(client_id): - return oauthlib.oauth2.WebApplicationClient(client_id) - - def get_redirect_url(port=REDIRECT_PORT): return f"http://localhost:{port}" @@ -103,8 +98,8 @@ def fetch_well_known_config(idp_url): logger.error(msg) raise RuntimeError(msg) try: - return json.loads(response.text) - except json.decoder.JSONDecodeError as e: + return response.json() + except requests.exceptions.JSONDecodeError as e: logger.error(f"Unable to decode OAuth configuration from {idp_url}.\n" "Verify it is a valid workspace URL and that OAuth is " "enabled on this account.") @@ -210,15 +205,14 @@ def send_token_request(token_request_url, data): "Content-Type": "application/x-www-form-urlencoded" } response = requests.post(url=token_request_url, data=data, headers=headers) - oauth_response = json.loads(response.text) - return oauth_response + return response.json() def send_refresh_token_request(hostname, client_id, refresh_token): idp_url = get_idp_url(hostname) oauth_config = fetch_well_known_config(idp_url) token_request_url = oauth_config["token_endpoint"] - client = get_client(client_id=client_id) + client = oauthlib.oauth2.WebApplicationClient(client_id) token_request_body = client.prepare_refresh_body( refresh_token=refresh_token, client_id=client.client_id) return send_token_request(token_request_url, token_request_body) @@ -269,7 +263,7 @@ def get_tokens(hostname, client_id, scope=None): auth_url = f"{hostname}oidc/v1/authorize" state = token_urlsafe(16) (verifier, challenge) = get_challenge() - client = get_client(client_id) + client = oauthlib.oauth2.WebApplicationClient(client_id) redirect_url = get_redirect_url() try: auth_response = get_authorization_code( From 774c882daed5168643e5c6caf7c3403831cb8a2a Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 12:51:41 -0700 Subject: [PATCH 15/57] responded to review comments Signed-off-by: Moe Derakhshani --- poetry.lock | 17 +---------------- pyproject.toml | 1 - src/databricks/sql/auth/oauth.py | 9 +++------ src/databricks/sql/auth/thrift_http_client.py | 10 +++++----- src/databricks/sql/thrift_backend.py | 6 +++--- 5 files changed, 12 insertions(+), 31 deletions(-) diff --git a/poetry.lock b/poetry.lock index f04374cbd..5211b2edb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -230,20 +230,6 @@ python-versions = ">=3.6" [package.dependencies] numpy = ">=1.16.6" -[[package]] -name = "pyjwt" -version = "2.4.0" -description = "JSON Web Token implementation in Python" -category = "main" -optional = false -python-versions = ">=3.6" - -[package.extras] -crypto = ["cryptography (>=3.3.1)"] -dev = ["sphinx", "sphinx-rtd-theme", "zope.interface", "cryptography (>=3.3.1)", "pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)", "mypy", "pre-commit"] -docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] -tests = ["pytest (>=6.0.0,<7.0.0)", "coverage[toml] (==5.0.4)"] - [[package]] name = "pyparsing" version = "3.0.9" @@ -359,7 +345,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "eecff275015f002e123d81343378de753ae80f33e2a3bd34ea2ad11dbbffeec3" +content-hash = "282be1d6681fbcbe624d24eff5a4375ad05deeb80b86ac06e99dd5cfd8e71ee7" [metadata.files] atomicwrites = [] @@ -404,7 +390,6 @@ platformdirs = [] pluggy = [] py = [] pyarrow = [] -pyjwt = [] pyparsing = [] pytest = [] python-dateutil = [ diff --git a/pyproject.toml b/pyproject.toml index 4226d5723..d93cc8f93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ python = "^3.7.1" thrift = "^0.13.0" pyarrow = "^5.0.0" pandas = "^1.3.0" -pyjwt=">=1.7.0" oauthlib=">=3.1.0" [tool.poetry.dev-dependencies] diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 04f34db65..9d16d1c8a 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -25,14 +25,11 @@ import hashlib import os import webbrowser - +import json from datetime import datetime, timedelta, tzinfo import logging -import jwt -from jwt import PyJWTError - import oauthlib.oauth2 from oauthlib.oauth2.rfc6749.errors import OAuth2Error @@ -233,9 +230,9 @@ def check_and_refresh_access_token(hostname, access_token, refresh_token): # If it has been tampered with, it will be rejected on the server side. # This avoids having to fetch the public key from the issuer and perform # an unnecessary signature verification. - decoded = jwt.decode(access_token, options={"verify_signature": False}) + decoded = json.loads(base64.standard_b64decode(access_token.split(".")[1])) expiration_time = datetime.fromtimestamp(decoded["exp"], tz=UTC) - except PyJWTError as e: + except Exception as e: logger.error(e) raise e diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 50aefe29b..9e6f4ae1b 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -29,9 +29,9 @@ class THttpClient(thrift.transport.THttpClient.THttpClient): - def __init__(self, authenticator, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): + def __init__(self, auth_provider, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): super().__init__(uri_or_host, port, path, cafile, cert_file, key_file, ssl_context) - self.__authenticator = authenticator + self.__auth_provider = auth_provider def setCustomHeaders(self, headers): self._headers = headers @@ -40,8 +40,8 @@ def setCustomHeaders(self, headers): def flush(self): # TODO retry behaviour # TODO locking for multiple concurrent refresh token - req = {} - self.__authenticator.add_auth_token(req) - self._headers['Authorization'] = req['Authorization'] + headers = dict(self._headers) + self.__auth_provider.add_auth_token(headers) + self._headers = headers self.setCustomHeaders(self._headers) super().flush() \ No newline at end of file diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 49e20ab9d..dfadc0716 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -47,7 +47,7 @@ class ThriftBackend: BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( - self, server_hostname: str, port, http_path: str, authenticator: CredentialsProvider, http_headers, **kwargs + self, server_hostname: str, port, http_path: str, auth_provider: CredentialsProvider, http_headers, **kwargs ): # Internal arguments in **kwargs: # _user_agent_entry @@ -125,10 +125,10 @@ def __init__( password=tls_client_cert_key_password, ) - self._authenticator = authenticator + self._auth_provider = auth_provider self._transport = THttpClient( - authenticator=self._authenticator, + auth_provider=self._auth_provider, uri_or_host=uri, ssl_context=ssl_context, ) From 4a2ca6f98bf0485fa5c70369832bb0cb36eabea7 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 13:50:11 -0700 Subject: [PATCH 16/57] added requests as an explicit dependency as that's required by oauth Signed-off-by: Moe Derakhshani --- poetry.lock | 80 +++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 5211b2edb..231d71a73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -43,6 +43,25 @@ d = ["aiohttp (>=3.7.4)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "certifi" +version = "2022.6.15" +description = "Python package for providing Mozilla's CA Bundle." +category = "main" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "charset-normalizer" +version = "2.1.0" +description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" +optional = false +python-versions = ">=3.6.0" + +[package.extras] +unicode_backport = ["unicodedata2"] + [[package]] name = "click" version = "8.1.3" @@ -63,6 +82,14 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "idna" +version = "3.3" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" +optional = false +python-versions = ">=3.5" + [[package]] name = "importlib-metadata" version = "4.12.0" @@ -282,6 +309,24 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "requests" +version = "2.28.1" +description = "Python HTTP for Humans." +category = "main" +optional = false +python-versions = ">=3.7, <4" + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<3" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<1.27" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] + [[package]] name = "six" version = "1.16.0" @@ -330,6 +375,19 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "urllib3" +version = "1.26.10" +description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, <4" + +[package.extras] +brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] +secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] + [[package]] name = "zipp" version = "3.8.1" @@ -345,14 +403,26 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "282be1d6681fbcbe624d24eff5a4375ad05deeb80b86ac06e99dd5cfd8e71ee7" +content-hash = "56ae2d2b29d59b5eb54bf00fd1259d2d62e61892f237cae399ad608c627ab228" [metadata.files] atomicwrites = [] attrs = [] black = [] +certifi = [ + {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"}, + {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"}, +] +charset-normalizer = [ + {file = "charset-normalizer-2.1.0.tar.gz", hash = "sha256:575e708016ff3a5e3681541cb9d79312c416835686d054a23accb873b254f413"}, + {file = "charset_normalizer-2.1.0-py3-none-any.whl", hash = "sha256:5189b6f22b01957427f35b6a08d9a0bc45b46d3788ef5a92e978433c7a35f8a5"}, +] click = [] colorama = [] +idna = [ + {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, + {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, +] importlib-metadata = [] iniconfig = [] mypy = [ @@ -397,9 +467,17 @@ python-dateutil = [ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] pytz = [] +requests = [ + {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, + {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, +] six = [] thrift = [] tomli = [] typed-ast = [] typing-extensions = [] +urllib3 = [ + {file = "urllib3-1.26.10-py2.py3-none-any.whl", hash = "sha256:8298d6d56d39be0e3bc13c1c97d133f9b45d797169a0e11cdd0e0489d786f7ec"}, + {file = "urllib3-1.26.10.tar.gz", hash = "sha256:879ba4d1e89654d9769ce13121e0f94310ea32e8d2f8cf587b77c08bbcdb30d6"}, +] zipp = [] diff --git a/pyproject.toml b/pyproject.toml index d93cc8f93..548ea8b2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ python = "^3.7.1" thrift = "^0.13.0" pyarrow = "^5.0.0" pandas = "^1.3.0" +requests=">2.18.1" oauthlib=">=3.1.0" [tool.poetry.dev-dependencies] From 61f2ec00a869da20c41264e1d9c53451e02fe976 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 13:53:33 -0700 Subject: [PATCH 17/57] responded to review comments Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 8 ++++---- src/databricks/sql/auth/thrift_http_client.py | 2 +- tests/test_auth.py | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 8345a11c7..763b2ef1a 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -28,7 +28,7 @@ # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class CredentialsProvider: - def add_auth_token(self, request_headers): + def add_headers(self, request_headers): pass @@ -38,7 +38,7 @@ class AccessTokenAuthProvider(CredentialsProvider): def __init__(self, access_token): self.__authorization_header_value = "Bearer {}".format(access_token) - def add_auth_token(self, request_headers): + def add_headers(self, request_headers): request_headers['Authorization'] = self.__authorization_header_value @@ -51,7 +51,7 @@ def __init__(self, username, password): self.__authorization_header_value = "Basic {}".format(auth_credentials_base64) - def add_auth_token(self, request_headers): + def add_headers(self, request_headers): request_headers['Authorization'] = self.__authorization_header_value @@ -69,7 +69,7 @@ def __init__(self, hostname, client_id, scopes): self._access_token = access_token self._refresh_token = refresh_token - def add_auth_token(self, request_headers): + def add_headers(self, request_headers): check_and_refresh_access_token(hostname=self._hostname, access_token=self._access_token, refresh_token=self._refresh_token) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 9e6f4ae1b..c5a6859e1 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -41,7 +41,7 @@ def flush(self): # TODO retry behaviour # TODO locking for multiple concurrent refresh token headers = dict(self._headers) - self.__auth_provider.add_auth_token(headers) + self.__auth_provider.add_headers(headers) self._headers = headers self.setCustomHeaders(self._headers) super().flush() \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py index b0d90f98b..a39e97995 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -33,7 +33,7 @@ def test_access_token_provider(self): auth = AccessTokenAuthProvider(access_token=access_token) http_request = {'myKey': 'myVal'} - auth.add_auth_token(http_request) + auth.add_headers(http_request) self.assertEqual(http_request['Authorization'], 'Bearer aBc2') self.assertEqual(len(http_request.keys()), 2) self.assertEqual(http_request['myKey'], 'myVal') @@ -44,7 +44,7 @@ def test_basic_auth_provider(self): auth = BasicAuthProvider(username=username, password=password) http_request = {'myKey': 'myVal'} - auth.add_auth_token(http_request) + auth.add_headers(http_request) self.assertEqual(http_request['Authorization'], 'Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==') self.assertEqual(len(http_request.keys()), 2) @@ -54,7 +54,7 @@ def test_noop_auth_provider(self): auth = CredentialsProvider() http_request = {'myKey': 'myVal'} - auth.add_auth_token(http_request) + auth.add_headers(http_request) self.assertEqual(len(http_request.keys()), 1) self.assertEqual(http_request['myKey'], 'myVal') @@ -66,7 +66,7 @@ def test_get_python_sql_connector_auth_provider_access_token(self): self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} - auth_provider.add_auth_token(headers) + auth_provider.add_headers(headers) self.assertEqual(headers['Authorization'], 'Bearer dpi123') def test_get_python_sql_connector_auth_provider_username_password(self): @@ -78,7 +78,7 @@ def test_get_python_sql_connector_auth_provider_username_password(self): self.assertTrue(type(auth_provider).__name__, "BasicAuthProvider") headers = {} - auth_provider.add_auth_token(headers) + auth_provider.add_headers(headers) self.assertEqual(headers['Authorization'], 'Basic bW9kZXJha2g6RWxldmF0ZSBEYXRhYnJpY2tzIDEyMyEhIQ==') def test_get_python_sql_connector_auth_provider_noop(self): From adce3f3821e21758c3afd5229c4c670ae3986240 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 13:55:19 -0700 Subject: [PATCH 18/57] responded to review comments Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/thrift_http_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index c5a6859e1..daf0fec49 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -39,7 +39,6 @@ def setCustomHeaders(self, headers): def flush(self): # TODO retry behaviour - # TODO locking for multiple concurrent refresh token headers = dict(self._headers) self.__auth_provider.add_headers(headers) self._headers = headers From dcec176f74e407a77a68e62a96ea13636863a8ac Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 15:00:26 -0700 Subject: [PATCH 19/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/oauth.py | 3 +-- src/databricks/sql/client.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 9d16d1c8a..d2803c41c 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -159,8 +159,7 @@ def log_message(self, format, *args): def get_authorization_code(client, auth_url, redirect_url, scope, state, challenge, port): - #pylint: disable=unused-variable - (auth_req_uri, headers, body) = client.prepare_authorization_request( + (auth_req_uri, _, _) = client.prepare_authorization_request( authorization_url=auth_url, redirect_url=redirect_url, scope=scope, diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 83a2a10ec..99d48ea26 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -85,7 +85,7 @@ def __init__( self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) - authenticator = get_python_sql_connector_auth_provider(server_hostname, **kwargs) + auth_provider = get_python_sql_connector_auth_provider(server_hostname, **kwargs) if not kwargs.get("_user_agent_entry"): useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) @@ -100,7 +100,7 @@ def __init__( self.host, self.port, http_path, - authenticator, + auth_provider, (http_headers or []) + base_headers, **kwargs ) From 964269275a0c267a669b1c5ac2d158e8b628ae50 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 14 Jul 2022 22:47:41 -0700 Subject: [PATCH 20/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 763b2ef1a..b9db7415a 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -46,10 +46,10 @@ def add_headers(self, request_headers): # Please must not depend on it in your applications. class BasicAuthProvider(CredentialsProvider): def __init__(self, username, password): - auth_credentials = "{username}:{password}".format(username=username, password=password).encode("UTF-8") + auth_credentials = f"{username}:{password}".encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") - self.__authorization_header_value = "Basic {}".format(auth_credentials_base64) + self.__authorization_header_value = f"Basic {auth_credentials_base64}" def add_headers(self, request_headers): request_headers['Authorization'] = self.__authorization_header_value @@ -73,11 +73,10 @@ def add_headers(self, request_headers): check_and_refresh_access_token(hostname=self._hostname, access_token=self._access_token, refresh_token=self._refresh_token) - request_headers['Authorization'] = "Bearer {}".format(self._access_token) + request_headers['Authorization'] = f"Bearer {self._access_token}" @staticmethod def _normalize_host_name(hostname): maybe_scheme = "https://" if not hostname.startswith("https://") else "" maybe_trailing_slash = "/" if not hostname.endswith("/") else "" - return "{scheme}{host}{trailing}".format( - scheme=maybe_scheme, host=hostname, trailing=maybe_trailing_slash) + return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" From 964ddd2a6f85c59eb790a29fe8c90a67a6df307d Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 24 Aug 2022 12:41:11 -0700 Subject: [PATCH 21/57] added support for persistence Signed-off-by: Moe Derakhshani --- src/databricks/sql/__init__.py | 4 +- src/databricks/sql/auth/auth.py | 13 +++-- src/databricks/sql/auth/authenticators.py | 48 +++++++++++++--- src/databricks/sql/auth/oauth.py | 4 +- src/databricks/sql/client.py | 5 +- src/databricks/sql/experimental/__init__.py | 0 .../sql/experimental/oauth_persistence.py | 56 +++++++++++++++++++ 7 files changed, 113 insertions(+), 17 deletions(-) create mode 100644 src/databricks/sql/experimental/__init__.py create mode 100644 src/databricks/sql/experimental/oauth_persistence.py diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index ae67264f7..c94cbc31e 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -44,6 +44,6 @@ def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) -def connect(server_hostname, http_path, **kwargs): +def connect(server_hostname, http_path, experimental_oauth_persistence=None, **kwargs): from .client import Connection - return Connection(server_hostname, http_path, **kwargs) + return Connection(server_hostname, http_path, experimental_oauth_persistence, **kwargs) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 0b0f8c5f9..0c0e5938f 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -24,6 +24,7 @@ from enum import Enum from databricks.sql.auth.authenticators import CredentialsProvider, \ AccessTokenAuthProvider, BasicAuthProvider, DatabricksOAuthProvider +from databricks.sql.experimental.oauth_persistence import OAuthPersistence class AuthType(Enum): @@ -42,7 +43,9 @@ def __init__(self, oauth_scopes: List[str] = None, oauth_client_id: str = None, use_cert_as_auth: str = None, - tls_client_cert_file: str = None): + tls_client_cert_file: str = None, + oauth_persistence=None + ): self.hostname = hostname self.username = username self.password = password @@ -52,11 +55,12 @@ def __init__(self, self.oauth_client_id = oauth_client_id self.use_cert_as_auth = use_cert_as_auth self.tls_client_cert_file = tls_client_cert_file + self.oauth_persistence = oauth_persistence def get_auth_provider(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: - return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_client_id, cfg.oauth_scopes) + return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_persistence, cfg.oauth_client_id, cfg.oauth_scopes) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) elif cfg.username is not None and cfg.password is not None: @@ -73,7 +77,7 @@ def get_auth_provider(cfg: ClientContext): OAUTH_CLIENT_ID = "databricks-cli" -def get_python_sql_connector_auth_provider(hostname: str, **kwargs): +def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence, **kwargs): cfg = ClientContext(hostname=hostname, auth_type=kwargs.get("auth_type"), access_token=kwargs.get("access_token"), @@ -82,7 +86,8 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): use_cert_as_auth=kwargs.get("_use_cert_as_auth"), tls_client_cert_file=kwargs.get("_tls_client_cert_file"), oauth_scopes=OAUTH_SCOPES, - oauth_client_id=OAUTH_CLIENT_ID) + oauth_client_id=OAUTH_CLIENT_ID, + oauth_persistence=oauth_persistence) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index b9db7415a..09fde1b9e 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -27,6 +27,9 @@ # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. +from databricks.sql.experimental.oauth_persistence import OAuthToken + + class CredentialsProvider: def add_headers(self, request_headers): pass @@ -59,18 +62,17 @@ def add_headers(self, request_headers): # Please must not depend on it in your applications. class DatabricksOAuthProvider(CredentialsProvider): SCOPE_DELIM = ' ' - # TODO: moderakh the refresh_token is only kept in memory. not saved on disk - # hence if application restarts the user may need to re-authenticate - # I will add support for this outside of the scope of current PR. - def __init__(self, hostname, client_id, scopes): + + def __init__(self, hostname, oauth_persistence, client_id, scopes): self._hostname = self._normalize_host_name(hostname=hostname) self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) - access_token, refresh_token = get_tokens(hostname=self._hostname, client_id=client_id, scope=self._scopes_as_str) - self._access_token = access_token - self._refresh_token = refresh_token + self._oauth_persistence = oauth_persistence + self._client_id = client_id + self._get_tokens() def add_headers(self, request_headers): check_and_refresh_access_token(hostname=self._hostname, + client_id=self._client_id, access_token=self._access_token, refresh_token=self._refresh_token) request_headers['Authorization'] = f"Bearer {self._access_token}" @@ -80,3 +82,35 @@ def _normalize_host_name(hostname): maybe_scheme = "https://" if not hostname.startswith("https://") else "" maybe_trailing_slash = "/" if not hostname.endswith("/") else "" return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" + + def _get_tokens(self): + if self._oauth_persistence: + token = self._oauth_persistence.read() + if token: + self._access_token = token.get_access_token() + self._refresh_token = token.get_refresh_token() + self._update_token_if_expired() + else: + (access_token, refresh_token) = get_tokens(hostname=self._hostname, + client_id=self._client_id, + scope=self._scopes_as_str) + self._access_token = access_token + self._refresh_token = refresh_token + self._oauth_persistence.persist(OAuthToken(access_token, refresh_token)) + + def _update_token_if_expired(self): + (fresh_access_token, fresh_refresh_token, is_refreshed) = check_and_refresh_access_token( + hostname=self._hostname, + client_id=self._client_id, + access_token=self._access_token, + refresh_token=self._refresh_token) + + if not is_refreshed: + return + else: + self._access_token = fresh_access_token + self._refresh_token = fresh_refresh_token + + if self._oauth_persistence: + token = OAuthToken(self._access_token, self._refresh_token) + self._oauth_persistence.persist(token) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index d2803c41c..f89fcdbce 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -220,7 +220,7 @@ def get_tokens_from_response(oauth_response): return access_token, refresh_token -def check_and_refresh_access_token(hostname, access_token, refresh_token): +def check_and_refresh_access_token(hostname, client_id, access_token, refresh_token): now = datetime.now(tz=UTC) # If we can't decode an expiration time, this will be expired by default. expiration_time = now @@ -246,7 +246,7 @@ def check_and_refresh_access_token(hostname, access_token, refresh_token): # Try to refresh using the refresh token logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}") - oauth_response = send_refresh_token_request(hostname, refresh_token) + oauth_response = send_refresh_token_request(hostname, client_id, refresh_token) fresh_access_token, fresh_refresh_token = get_tokens_from_response(oauth_response) return fresh_access_token, fresh_refresh_token, True diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 99d48ea26..a07cd3c50 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -10,7 +10,7 @@ from databricks.sql.utils import ExecuteResponse, ParamEscaper from databricks.sql.types import Row from databricks.sql.auth.auth import get_python_sql_connector_auth_provider - +from databricks.sql.experimental.oauth_persistence import OAuthPersistence logger = logging.getLogger(__name__) DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760 @@ -22,6 +22,7 @@ def __init__( self, server_hostname: str, http_path: str, + oauth_persistence: OAuthPersistence = None, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Dict[str, Any] = None, catalog: Optional[str] = None, @@ -85,7 +86,7 @@ def __init__( self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) - auth_provider = get_python_sql_connector_auth_provider(server_hostname, **kwargs) + auth_provider = get_python_sql_connector_auth_provider(server_hostname, oauth_persistence, **kwargs) if not kwargs.get("_user_agent_entry"): useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) diff --git a/src/databricks/sql/experimental/__init__.py b/src/databricks/sql/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py new file mode 100644 index 000000000..ceccdde11 --- /dev/null +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -0,0 +1,56 @@ +import logging +import json +logger = logging.getLogger(__name__) + + +class OAuthToken: + def __init__(self, access_token, refresh_token): + self._access_token = access_token + self._refresh_token = refresh_token + + def get_access_token(self) -> str: + return self._access_token + + def get_refresh_token(self) -> str: + return self._refresh_token + + +class OAuthPersistence: + def persist(self, oauth_token: OAuthToken): + pass + + def read(self) -> OAuthToken: + pass + + +# Note this is only intended to be used for development +class DevOnlyFilePersistence(OAuthPersistence): + + def __init__(self, file_path): + self._file_path = file_path + + def persist(self, token: OAuthToken): + logger.info(f"persisting token in {self._file_path}") + + # Data to be written + dictionary = { + "refresh_token": token.get_refresh_token(), + "access_token": token.get_access_token() + } + + # Serializing json + json_object = json.dumps(dictionary, indent=4) + + with open(self._file_path, "w") as outfile: + outfile.write(json_object) + + def read(self) -> OAuthToken: + # TODO: validate the + try: + with open(self._file_path, "r") as infile: + json_as_string = infile.read() + + token_as_json = json.loads(json_as_string) + return OAuthToken(token_as_json['access_token'], token_as_json['refresh_token']) + except Exception as e: + return None From b062674039146b17a9ce6a45669a79e0d80440f9 Mon Sep 17 00:00:00 2001 From: Jesse Date: Fri, 15 Jul 2022 14:30:09 -0500 Subject: [PATCH 22/57] Add e2e tests (#12) Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- .github/workflows/code-quality-checks.yml | 4 +- CONTRIBUTING.md | 46 +- tests/__init__.py | 0 tests/e2e/common/__init__.py | 0 tests/e2e/common/core_tests.py | 131 +++++ tests/e2e/common/decimal_tests.py | 48 ++ tests/e2e/common/large_queries_mixin.py | 100 ++++ tests/e2e/common/predicates.py | 101 ++++ tests/e2e/common/retry_test_mixins.py | 38 ++ tests/e2e/common/timestamp_tests.py | 74 +++ tests/e2e/driver_tests.py | 589 ++++++++++++++++++++++ tests/{ => unit}/test_arrow_queue.py | 0 tests/{ => unit}/test_fetches.py | 0 tests/{ => unit}/test_fetches_bench.py | 0 tests/{ => unit}/test_thrift_backend.py | 0 tests/{ => unit}/tests.py | 0 16 files changed, 1121 insertions(+), 10 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/e2e/common/__init__.py create mode 100644 tests/e2e/common/core_tests.py create mode 100644 tests/e2e/common/decimal_tests.py create mode 100644 tests/e2e/common/large_queries_mixin.py create mode 100644 tests/e2e/common/predicates.py create mode 100644 tests/e2e/common/retry_test_mixins.py create mode 100644 tests/e2e/common/timestamp_tests.py create mode 100644 tests/e2e/driver_tests.py rename tests/{ => unit}/test_arrow_queue.py (100%) rename tests/{ => unit}/test_fetches.py (100%) rename tests/{ => unit}/test_fetches_bench.py (100%) rename tests/{ => unit}/test_thrift_backend.py (100%) rename tests/{ => unit}/tests.py (100%) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index b48a2faf2..a6f44144c 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -1,7 +1,7 @@ name: Code Quality Checks on: [push] jobs: - run-tests: + run-unit-tests: runs-on: ubuntu-latest steps: #---------------------------------------------- @@ -48,7 +48,7 @@ jobs: # run test suite #---------------------------------------------- - name: Run tests - run: poetry run pytest tests/ + run: poetry run python -m pytest tests/unit check-linting: runs-on: ubuntu-latest steps: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bd4886df2..cfc34a320 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,34 +9,64 @@ This project uses [Poetry](https://python-poetry.org/) for dependency management 1. Clone this respository 2. Run `poetry install` -### Unit Tests +### Run tests -We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run pytest`, all other arguments are passed directly to `pytest`. +We use [Pytest](https://docs.pytest.org/en/7.1.x/) as our test runner. Invoke it with `poetry run python -m pytest`, all other arguments are passed directly to `pytest`. + +#### Unit tests + +Unit tests do not require a Databricks account. -#### All tests ```bash -poetry run pytest tests +poetry run python -m pytest tests/unit ``` - #### Only a specific test file ```bash -poetry run pytest tests/tests.py +poetry run python -m pytest tests/unit/tests.py ``` #### Only a specific method ```bash -poetry run pytest tests/tests.py::ClientTestSuite::test_closing_connection_closes_commands +poetry run python -m pytest tests/unit/tests.py::ClientTestSuite::test_closing_connection_closes_commands +``` + +#### e2e Tests + +End-to-end tests require a Databricks account. Before you can run them, you must set connection details for a Databricks SQL endpoint in your environment: + +```bash +export host="" +export http_path="" +export access_token="" ``` +There are several e2e test suites available: +- `PySQLCoreTestSuite` +- `PySQLLargeQueriesSuite` +- `PySQLRetryTestSuite.HTTP503Suite` **[not documented]** +- `PySQLRetryTestSuite.HTTP429Suite` **[not documented]** +- `PySQLUnityCatalogTestSuite` **[not documented]** + +To execute the core test suite: + +```bash +poetry run python -m pytest tests/e2e/driver_tests.py::PySQLCoreTestSuite +``` + +The suites marked `[not documented]` require additional configuration which will be documented at a later time. ### Code formatting This project uses [Black](https://pypi.org/project/black/). ``` -poetry run black src +poetry run python3 -m black src --check ``` + +Remove the `--check` flag to write reformatted files to disk. + +To simplify reviews you can format your changes in a separate commit. ## Pull Request Process 1. Update the [CHANGELOG.md](README.md) or similar documentation with details of changes you wish to make, if applicable. diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/common/__init__.py b/tests/e2e/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py new file mode 100644 index 000000000..cd325e8d0 --- /dev/null +++ b/tests/e2e/common/core_tests.py @@ -0,0 +1,131 @@ +import decimal +import datetime +from collections import namedtuple + +TypeFailure = namedtuple( + "TypeFailure", "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf") +ResultFailure = namedtuple( + "ResultFailure", "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf") +ExecFailure = namedtuple( + "ExecFailure", "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf,error") + + +class SmokeTestMixin: + def test_smoke_test(self): + with self.cursor() as cursor: + cursor.execute("select 0") + rows = cursor.fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], 0) + + +class CoreTestMixin: + """ + This mixin expects to be mixed with a CursorTest-like class with the following extra attributes: + validate_row_value_type: bool + validate_result: bool + """ + + # A list of (subquery, column_type, python_type, expected_result) + # To be executed as "SELECT {} FROM RANGE(...)" and "SELECT {}" + range_queries = [ + ("TRUE", 'boolean', bool, True), + ("cast(1 AS TINYINT)", 'byte', int, 1), + ("cast(1000 AS SMALLINT)", 'short', int, 1000), + ("cast(100000 AS INTEGER)", 'integer', int, 100000), + ("cast(10000000000000 AS BIGINT)", 'long', int, 10000000000000), + ("cast(100.001 AS DECIMAL(6, 3))", 'decimal', decimal.Decimal, 100.001), + ("date '2020-02-20'", 'date', datetime.date, datetime.date(2020, 2, 20)), + ("unhex('f000')", 'binary', bytes, b'\xf0\x00'), # pyodbc internal mismatch + ("'foo'", 'string', str, 'foo'), + # SPARK-32130: 6.x: "4 weeks 2 days" vs 7.x: "30 days" + # ("interval 30 days", str, str, "interval 4 weeks 2 days"), + # ("interval 3 days", str, str, "interval 3 days"), + ("CAST(NULL AS DOUBLE)", 'double', type(None), None), + ] + + # Full queries, only the first column of the first row is checked + queries = [("NULL UNION (SELECT 1) order by 1", 'integer', type(None), None)] + + def run_tests_on_queries(self, default_conf): + failures = [] + for (query, columnType, rowValueType, answer) in self.range_queries: + with self.cursor(default_conf) as cursor: + failures.extend( + self.run_query(cursor, query, columnType, rowValueType, answer, default_conf)) + failures.extend( + self.run_range_query(cursor, query, columnType, rowValueType, answer, + default_conf)) + + for (query, columnType, rowValueType, answer) in self.queries: + with self.cursor(default_conf) as cursor: + failures.extend( + self.run_query(cursor, query, columnType, rowValueType, answer, default_conf)) + + if failures: + self.fail("Failed testing result set with Arrow. " + "Failed queries: {}".format("\n\n".join([str(f) for f in failures]))) + + def run_query(self, cursor, query, columnType, rowValueType, answer, conf): + full_query = "SELECT {}".format(query) + expected_column_types = self.expected_column_types(columnType) + try: + cursor.execute(full_query) + (result, ) = cursor.fetchone() + if not all(cursor.description[0][1] == type for type in expected_column_types): + return [ + TypeFailure(full_query, expected_column_types, rowValueType, answer, result, + type(result), cursor.description, conf) + ] + if self.validate_row_value_type and type(result) is not rowValueType: + return [ + TypeFailure(full_query, expected_column_types, rowValueType, answer, result, + type(result), cursor.description, conf) + ] + if self.validate_result and str(answer) != str(result): + return [ + ResultFailure(full_query, query, expected_column_types, rowValueType, answer, + result, type(result), cursor.description, conf) + ] + return [] + except Exception as e: + return [ + ExecFailure(full_query, columnType, rowValueType, None, None, None, + cursor.description, conf, e) + ] + + def run_range_query(self, cursor, query, columnType, rowValueType, expected, conf): + full_query = "SELECT {}, id FROM RANGE({})".format(query, 5000) + expected_column_types = self.expected_column_types(columnType) + try: + cursor.execute(full_query) + while True: + rows = cursor.fetchmany(1000) + if len(rows) <= 0: + break + for index, (result, id) in enumerate(rows): + if not all(cursor.description[0][1] == type for type in expected_column_types): + return [ + TypeFailure(full_query, expected_column_types, rowValueType, expected, + result, type(result), cursor.description, conf) + ] + if self.validate_row_value_type and type(result) \ + is not rowValueType: + return [ + TypeFailure(full_query, expected_column_types, rowValueType, expected, + result, type(result), cursor.description, conf) + ] + if self.validate_result and str(expected) != str(result): + return [ + ResultFailure(full_query, expected_column_types, rowValueType, expected, + result, type(result), cursor.description, conf) + ] + return [] + except Exception as e: + return [ + ExecFailure(full_query, columnType, rowValueType, None, None, None, + cursor.description, conf, e) + ] diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py new file mode 100644 index 000000000..8051d2a18 --- /dev/null +++ b/tests/e2e/common/decimal_tests.py @@ -0,0 +1,48 @@ +from decimal import Decimal + +import pyarrow + + +class DecimalTestsMixin: + decimal_and_expected_results = [ + ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)), + ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)), + ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)), + # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False + #("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)), + ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)), + ("1 AS DECIMAL(1, 0)", Decimal("1"), pyarrow.decimal128(1, 0)), + ("0.00000 AS DECIMAL(5, 3)", Decimal("0.000"), pyarrow.decimal128(5, 3)), + ("1e-3 AS DECIMAL(38, 3)", Decimal("0.001"), pyarrow.decimal128(38, 3)), + ] + + multi_decimals_and_expected_results = [ + (["1 AS DECIMAL(6, 3)", "100.001 AS DECIMAL(6, 3)", "NULL AS DECIMAL(6, 3)"], + [Decimal("1.00"), Decimal("100.001"), None], pyarrow.decimal128(6, 3)), + (["1 AS DECIMAL(6, 3)", "2 AS DECIMAL(5, 2)"], [Decimal('1.000'), + Decimal('2.000')], pyarrow.decimal128(6, + 3)), + ] + + def test_decimals(self): + with self.cursor({}) as cursor: + for (decimal, expected_value, expected_type) in self.decimal_and_expected_results: + query = "SELECT CAST ({})".format(decimal) + with self.subTest(query=query): + cursor.execute(query) + table = cursor.fetchmany_arrow(1) + self.assertEqual(table.field(0).type, expected_type) + self.assertEqual(table.to_pydict().popitem()[1][0], expected_value) + + def test_multi_decimals(self): + with self.cursor({}) as cursor: + for (decimals, expected_values, + expected_type) in self.multi_decimals_and_expected_results: + union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) + query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str) + + with self.subTest(query=query): + cursor.execute(query) + table = cursor.fetchall_arrow() + self.assertEqual(table.field(0).type, expected_type) + self.assertEqual(table.to_pydict().popitem()[1], expected_values) diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py new file mode 100644 index 000000000..d59e0a9fe --- /dev/null +++ b/tests/e2e/common/large_queries_mixin.py @@ -0,0 +1,100 @@ +import logging +import math +import time + +log = logging.getLogger(__name__) + + +class LargeQueriesMixin: + """ + This mixin expects to be mixed with a CursorTest-like class + """ + + def fetch_rows(self, cursor, row_count, fetchmany_size): + """ + A generator for rows. Fetches until the end or up to 5 minutes. + """ + # TODO: Remove fetchmany_size when we have fixed the performance issues with fetchone + # in the Python client + max_fetch_time = 5 * 60 # Fetch for at most 5 minutes + + rows = self.get_some_rows(cursor, fetchmany_size) + start_time = time.time() + n = 0 + while rows: + for row in rows: + n += 1 + yield row + if time.time() - start_time >= max_fetch_time: + log.warning("Fetching rows timed out") + break + rows = self.get_some_rows(cursor, fetchmany_size) + if not rows: + # Read all the rows, row_count should match + self.assertEqual(n, row_count) + + num_fetches = max(math.ceil(n / 10000), 1) + latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 + print('Fetched {} rows with an avg latency of {} per fetch, '.format(n, latency_ms) + + 'assuming 10K fetch size.') + + def test_query_with_large_wide_result_set(self): + resultSize = 300 * 1000 * 1000 # 300 MB + width = 8192 # B + rows = resultSize // width + cols = width // 36 + + # Set the fetchmany_size to get 10MB of data a go + fetchmany_size = 10 * 1024 * 1024 // width + # This is used by PyHive tests to determine the buffer size + self.arraysize = 1000 + with self.cursor() as cursor: + uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) + cursor.execute("SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows)) + for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): + self.assertEqual(row[0], row_id) # Verify no rows are dropped in the middle. + self.assertEqual(len(row[1]), 36) + + def test_query_with_large_narrow_result_set(self): + resultSize = 300 * 1000 * 1000 # 300 MB + width = 8 # sizeof(long) + rows = resultSize / width + + # Set the fetchmany_size to get 10MB of data a go + fetchmany_size = 10 * 1024 * 1024 // width + # This is used by PyHive tests to determine the buffer size + self.arraysize = 10000000 + with self.cursor() as cursor: + cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) + for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): + self.assertEqual(row[0], row_id) + + def test_long_running_query(self): + """ Incrementally increase query size until it takes at least 5 minutes, + and asserts that the query completes successfully. + """ + minutes = 60 + min_duration = 5 * minutes + + duration = -1 + scale0 = 10000 + scale_factor = 1 + with self.cursor() as cursor: + while duration < min_duration: + self.assertLess(scale_factor, 512, msg="Detected infinite loop") + start = time.time() + + cursor.execute("""SELECT count(*) + FROM RANGE({scale}) x + JOIN RANGE({scale0}) y + ON from_unixtime(x.id * y.id, "yyyy-MM-dd") LIKE "%not%a%date%" + """.format(scale=scale_factor * scale0, scale0=scale0)) + + n, = cursor.fetchone() + self.assertEqual(n, 0) + + duration = time.time() - start + current_fraction = duration / min_duration + print('Took {} s with scale factor={}'.format(duration, scale_factor)) + # Extrapolate linearly to reach 5 min and add 50% padding to push over the limit + scale_factor = math.ceil(1.5 * scale_factor / current_fraction) diff --git a/tests/e2e/common/predicates.py b/tests/e2e/common/predicates.py new file mode 100644 index 000000000..3450087f5 --- /dev/null +++ b/tests/e2e/common/predicates.py @@ -0,0 +1,101 @@ +import functools +from packaging.version import parse as parse_version +import unittest + +MAJOR_DBR_V_KEY = "major_dbr_version" +MINOR_DBR_V_KEY = "minor_dbr_version" +ENDPOINT_TEST_KEY = "is_endpoint_test" + + +def pysql_supports_arrow(): + """Import databricks.sql and test whether Cursor has fetchall_arrow.""" + from databricks.sql.client import Cursor + return hasattr(Cursor, 'fetchall_arrow') + + +def pysql_has_version(compare, version): + """Import databricks.sql, and return compare_module_version(...). + + Expected use: + from common.predicates import pysql_has_version + from databricks import sql as pysql + ... + @unittest.skipIf(pysql_has_version('<', '2')) + def test_some_pyhive_v1_stuff(): + ... + """ + from databricks import sql + return compare_module_version(sql, compare, version) + + +def is_endpoint_test(cli_args=None): + + # Currently only supporting tests against DBSQL Endpoints + # So we don't read `is_endpoint_test` from the CLI args + return True + + +def compare_dbr_versions(cli_args, compare, major_version, minor_version): + if MAJOR_DBR_V_KEY in cli_args and MINOR_DBR_V_KEY in cli_args: + if cli_args[MINOR_DBR_V_KEY] == "x": + actual_minor_v = float('inf') + else: + actual_minor_v = int(cli_args[MINOR_DBR_V_KEY]) + dbr_version = (int(cli_args[MAJOR_DBR_V_KEY]), actual_minor_v) + req_version = (major_version, minor_version) + return compare_versions(compare, dbr_version, req_version) + + if not is_endpoint_test(): + raise ValueError( + "DBR version not provided for non-endpoint test. Please pass the {} and {} params". + format(MAJOR_DBR_V_KEY, MINOR_DBR_V_KEY)) + + +def is_thrift_v5_plus(cli_args): + return compare_dbr_versions(cli_args, ">=", 10, 2) or is_endpoint_test(cli_args) + + +_compare_fns = { + '<': '__lt__', + '<=': '__le__', + '>': '__gt__', + '>=': '__ge__', + '==': '__eq__', + '!=': '__ne__', +} + + +def compare_versions(compare, v1_tuple, v2_tuple): + compare_fn_name = _compare_fns.get(compare) + assert compare_fn_name, 'Received invalid compare string: ' + compare + return getattr(v1_tuple, compare_fn_name)(v2_tuple) + + +def compare_module_version(module, compare, version): + """Compare `module`'s version as specified, returning True/False. + + @unittest.skipIf(compare_module_version(sql, '<', '2')) + def test_some_pyhive_v1_stuff(): + ... + + `module`: the module whose version will be compared + `compare`: one of '<', '<=', '>', '>=', '==', '!=' + `version`: a version string, of the form 'x[.y[.z]] + + Asserts module and compare to be truthy, and casts version to string. + + NOTE: This comparison leverages packaging.version.parse, and compares _release_ versions, + thus ignoring pre/post release tags (eg -rc1, -dev, etc). + """ + assert module, 'Received invalid module: ' + module + assert getattr(module, '__version__'), 'Received module with no version: ' + module + + def validate_version(version): + v = parse_version(str(version)) + # assert that we get a PEP-440 Version back -- LegacyVersion doesn't have major/minor. + assert hasattr(v, 'major'), 'Module has incompatible "Legacy" version: ' + version + return (v.major, v.minor, v.micro) + + mod_version = validate_version(module.__version__) + req_version = validate_version(version) + return compare_versions(compare, mod_version, req_version) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py new file mode 100644 index 000000000..a088ba1e3 --- /dev/null +++ b/tests/e2e/common/retry_test_mixins.py @@ -0,0 +1,38 @@ +class Client429ResponseMixin: + def test_client_should_retry_automatically_when_getting_429(self): + with self.cursor() as cursor: + for _ in range(10): + cursor.execute("SELECT 1") + rows = cursor.fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], 1) + + def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self): + with self.assertRaises(self.error_type) as cm: + with self.cursor(self.conf_to_disable_rate_limit_retries) as cursor: + for _ in range(10): + cursor.execute("SELECT 1") + rows = cursor.fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0][0], 1) + expected = "Maximum rate of 1 requests per SECOND has been exceeded. " \ + "Please reduce the rate of requests and try again after 1 seconds." + exception_str = str(cm.exception) + + # FIXME (Ali Smesseim, 7-Jul-2020): ODBC driver does not always return the + # X-Thriftserver-Error-Message as-is. Re-enable once Simba resolves this flakiness. + # Simba support ticket: https://magnitudesoftware.force.com/support/5001S000018RlaD + # self.assertIn(expected, exception_str) + + +class Client503ResponseMixin: + def test_wait_cluster_startup(self): + with self.cursor() as cursor: + cursor.execute("SELECT 1") + cursor.fetchall() + + def _test_retry_disabled_with_message(self, error_msg_substring, exception_type): + with self.assertRaises(exception_type) as cm: + with self.connection(self.conf_to_disable_temporarily_unavailable_retries): + pass + self.assertIn(error_msg_substring, str(cm.exception)) diff --git a/tests/e2e/common/timestamp_tests.py b/tests/e2e/common/timestamp_tests.py new file mode 100644 index 000000000..38b14e9e8 --- /dev/null +++ b/tests/e2e/common/timestamp_tests.py @@ -0,0 +1,74 @@ +import datetime + +from .predicates import compare_dbr_versions, is_thrift_v5_plus, pysql_has_version + + +class TimestampTestsMixin: + timestamp_and_expected_results = [ + ('2021-09-30 11:27:35.123+04:00', datetime.datetime(2021, 9, 30, 7, 27, 35, 123000)), + ('2021-09-30 11:27:35+04:00', datetime.datetime(2021, 9, 30, 7, 27, 35)), + ('2021-09-30 11:27:35.123', datetime.datetime(2021, 9, 30, 11, 27, 35, 123000)), + ('2021-09-30 11:27:35', datetime.datetime(2021, 9, 30, 11, 27, 35)), + ('2021-09-30 11:27', datetime.datetime(2021, 9, 30, 11, 27)), + ('2021-09-30 11', datetime.datetime(2021, 9, 30, 11)), + ('2021-09-30', datetime.datetime(2021, 9, 30)), + ('2021-09', datetime.datetime(2021, 9, 1)), + ('2021', datetime.datetime(2021, 1, 1)), + ('9999-12-31T15:59:59', datetime.datetime(9999, 12, 31, 15, 59, 59)), + ('9999-99-31T15:59:59', None), + ] + + date_and_expected_results = [ + ('2021-09-30', datetime.date(2021, 9, 30)), + ('2021-09', datetime.date(2021, 9, 1)), + ('2021', datetime.date(2021, 1, 1)), + ('9999-12-31', datetime.date(9999, 12, 31)), + ('9999-99-31', None), + ] + + def should_add_timezone(self): + return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments) + + def maybe_add_timezone_to_timestamp(self, ts): + """If we're using DBR >= 10.2, then we expect back aware timestamps, so add timezone to `ts` + Otherwise we have naive timestamps, so no change is needed + """ + if ts and self.should_add_timezone(): + return ts.replace(tzinfo=datetime.timezone.utc) + else: + return ts + + def assertTimestampsEqual(self, result, expected): + self.assertEqual(result, self.maybe_add_timezone_to_timestamp(expected)) + + def multi_query(self, n_rows=10): + row_sql = "SELECT " + ", ".join( + ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results]) + query = " UNION ALL ".join([row_sql for _ in range(n_rows)]) + expected_matrix = [[dt for (_, dt) in self.timestamp_and_expected_results] + for _ in range(n_rows)] + return query, expected_matrix + + def test_timestamps(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + for (timestamp, expected) in self.timestamp_and_expected_results: + cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + result = cursor.fetchone()[0] + self.assertTimestampsEqual(result, expected) + + def test_multi_timestamps(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + query, expected = self.multi_query() + cursor.execute(query) + result = cursor.fetchall() + # We list-ify the rows because PyHive will return a tuple for a row + self.assertEqual([list(r) for r in result], + [[self.maybe_add_timezone_to_timestamp(ts) for ts in r] + for r in expected]) + + def test_dates(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + for (date, expected) in self.date_and_expected_results: + cursor.execute("SELECT DATE('{date}')".format(date=date)) + result = cursor.fetchone()[0] + self.assertEqual(result, expected) diff --git a/tests/e2e/driver_tests.py b/tests/e2e/driver_tests.py new file mode 100644 index 000000000..358f0b263 --- /dev/null +++ b/tests/e2e/driver_tests.py @@ -0,0 +1,589 @@ +from contextlib import contextmanager +from collections import OrderedDict +import datetime +import io +import logging +import os +import sys +import threading +import time +from unittest import loader, skipIf, skipUnless, TestCase +from uuid import uuid4 + +import numpy as np +import pyarrow +import pytz +import thrift + +import databricks.sql as sql +from databricks.sql import STRING, BINARY, NUMBER, DATETIME, DATE, DatabaseError, Error, OperationalError +from tests.e2e.common.predicates import pysql_has_version, pysql_supports_arrow, compare_dbr_versions, is_thrift_v5_plus +from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin +from tests.e2e.common.large_queries_mixin import LargeQueriesMixin +from tests.e2e.common.timestamp_tests import TimestampTestsMixin +from tests.e2e.common.decimal_tests import DecimalTestsMixin +from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin + +log = logging.getLogger(__name__) + +# manually decorate DecimalTestsMixin to need arrow support +for name in loader.getTestCaseNames(DecimalTestsMixin, 'test_'): + fn = getattr(DecimalTestsMixin, name) + decorated = skipUnless(pysql_supports_arrow(), 'Decimal tests need arrow support')(fn) + setattr(DecimalTestsMixin, name, decorated) + +get_args_from_env = True + + +class PySQLTestCase(TestCase): + error_type = Error + conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} + conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} + + def __init__(self, method_name): + super().__init__(method_name) + # If running in local mode, just use environment variables for params. + self.arguments = os.environ if get_args_from_env else {} + self.arraysize = 1000 + + def connection_params(self, arguments): + params = { + "server_hostname": arguments["host"], + "http_path": arguments["http_path"], + **self.auth_params(arguments) + } + + return params + + def auth_params(self, arguments): + return { + "_username": arguments.get("rest_username"), + "_password": arguments.get("rest_password"), + "access_token": arguments.get("access_token") + } + + @contextmanager + def connection(self, extra_params=()): + connection_params = dict(self.connection_params(self.arguments), **dict(extra_params)) + + log.info("Connecting with args: {}".format(connection_params)) + conn = sql.connect(**connection_params) + + try: + yield conn + finally: + conn.close() + + @contextmanager + def cursor(self, extra_params=()): + with self.connection(extra_params) as conn: + cursor = conn.cursor(arraysize=self.arraysize) + try: + yield cursor + finally: + cursor.close() + + def assertEqualRowValues(self, actual, expected): + self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) + for act, exp in zip(actual, expected): + self.assertSequenceEqual(act, exp) + + +class PySQLLargeQueriesSuite(PySQLTestCase, LargeQueriesMixin): + def get_some_rows(self, cursor, fetchmany_size): + row = cursor.fetchone() + if row: + return [row] + else: + return None + + +# Exclude Retry tests because they require specific setups, and LargeQueries too slow for core +# tests +class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, TimestampTestsMixin, + PySQLTestCase): + validate_row_value_type = True + validate_result = True + + # An output column in description evaluates to equal to multiple types + # - type code returned by the client as string. + # - also potentially a PEP-249 object like NUMBER, DATETIME etc. + def expected_column_types(self, type_): + type_mappings = { + 'boolean': ['boolean', NUMBER], + 'byte': ['tinyint', NUMBER], + 'short': ['smallint', NUMBER], + 'integer': ['int', NUMBER], + 'long': ['bigint', NUMBER], + 'decimal': ['decimal', NUMBER], + 'timestamp': ['timestamp', DATETIME], + 'date': ['date', DATE], + 'binary': ['binary', BINARY], + 'string': ['string', STRING], + 'array': ['array'], + 'struct': ['struct'], + 'map': ['map'], + 'double': ['double', NUMBER], + 'null': ['null'] + } + return type_mappings[type_] + + def test_queries(self): + if not self._should_have_native_complex_types(): + array_type = str + array_val = "[1,2,3]" + struct_type = str + struct_val = "{\"a\":1,\"b\":2}" + map_type = str + map_val = "{1:2,3:4}" + else: + array_type = np.ndarray + array_val = np.array([1, 2, 3]) + struct_type = dict + struct_val = {"a": 1, "b": 2} + map_type = list + map_val = [(1, 2), (3, 4)] + + null_type = "null" if float(sql.__version__[0:2]) < 2.0 else "string" + self.range_queries = CoreTestMixin.range_queries + [ + ("NULL", null_type, type(None), None), + ("array(1, 2, 3)", 'array', array_type, array_val), + ("struct(1 as a, 2 as b)", 'struct', struct_type, struct_val), + ("map(1, 2, 3, 4)", 'map', map_type, map_val), + ] + + self.run_tests_on_queries({}) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_incorrect_query_throws_exception(self): + with self.cursor({}) as cursor: + # Syntax errors should contain the invalid SQL + with self.assertRaises(DatabaseError) as cm: + cursor.execute("^ FOO BAR") + self.assertIn("FOO BAR", str(cm.exception)) + + # Database error should contain the missing database + with self.assertRaises(DatabaseError) as cm: + cursor.execute("USE foo234823498ydfsiusdhf") + self.assertIn("foo234823498ydfsiusdhf", str(cm.exception)) + + # SQL with Extraneous input should send back the extraneous input + with self.assertRaises(DatabaseError) as cm: + cursor.execute("CREATE TABLE IF NOT EXISTS TABLE table_234234234") + self.assertIn("table_234234234", str(cm.exception)) + + def test_create_table_will_return_empty_result_set(self): + with self.cursor({}) as cursor: + table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + try: + cursor.execute( + "CREATE TABLE IF NOT EXISTS {} AS (SELECT 1 AS col_1, '2' AS col_2)".format( + table_name)) + self.assertEqual(cursor.fetchall(), []) + finally: + cursor.execute("DROP TABLE IF EXISTS {}".format(table_name)) + + def test_get_tables(self): + with self.cursor({}) as cursor: + table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + table_names = [table_name + '_1', table_name + '_2'] + + try: + for table in table_names: + cursor.execute( + "CREATE TABLE IF NOT EXISTS {} AS (SELECT 1 AS col_1, '2' AS col_2)".format( + table)) + cursor.tables(schema_name="defa%") + tables = cursor.fetchall() + tables_desc = cursor.description + + for table in table_names: + # Test only schema name and table name. + # From other columns, what is supported depends on DBR version. + self.assertIn(['default', table], [list(table[1:3]) for table in tables]) + self.assertEqual( + tables_desc, + [('TABLE_CAT', 'string', None, None, None, None, None), + ('TABLE_SCHEM', 'string', None, None, None, None, None), + ('TABLE_NAME', 'string', None, None, None, None, None), + ('TABLE_TYPE', 'string', None, None, None, None, None), + ('REMARKS', 'string', None, None, None, None, None), + ('TYPE_CAT', 'string', None, None, None, None, None), + ('TYPE_SCHEM', 'string', None, None, None, None, None), + ('TYPE_NAME', 'string', None, None, None, None, None), + ('SELF_REFERENCING_COL_NAME', 'string', None, None, None, None, None), + ('REF_GENERATION', 'string', None, None, None, None, None)]) + finally: + for table in table_names: + cursor.execute('DROP TABLE IF EXISTS {}'.format(table)) + + def test_get_columns(self): + with self.cursor({}) as cursor: + table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + table_names = [table_name + '_1', table_name + '_2'] + + try: + for table in table_names: + cursor.execute("CREATE TABLE IF NOT EXISTS {} AS (SELECT " + "1 AS col_1, " + "'2' AS col_2, " + "named_struct('name', 'alice', 'age', 28) as col_3, " + "map('items', 45, 'cost', 228) as col_4, " + "array('item1', 'item2', 'item3') as col_5)".format(table)) + + cursor.columns(schema_name="defa%", table_name=table_name + '%') + cols = cursor.fetchall() + cols_desc = cursor.description + + # Catalogue name not consistent across DBR versions, so we skip that + cleaned_response = [list(col[1:6]) for col in cols] + # We also replace ` as DBR changes how it represents struct names + for col in cleaned_response: + col[4] = col[4].replace("`", "") + + self.assertEqual(cleaned_response, [ + ['default', table_name + '_1', 'col_1', 4, 'INT'], + ['default', table_name + '_1', 'col_2', 12, 'STRING'], + ['default', table_name + '_1', 'col_3', 2002, 'STRUCT'], + ['default', table_name + '_1', 'col_4', 2000, 'MAP'], + ['default', table_name + '_1', 'col_5', 2003, 'ARRAY'], + ['default', table_name + '_2', 'col_1', 4, 'INT'], + ['default', table_name + '_2', 'col_2', 12, 'STRING'], + ['default', table_name + '_2', 'col_3', 2002, 'STRUCT'], + ['default', table_name + '_2', 'col_4', 2000, 'MAP'], + [ + 'default', + table_name + '_2', + 'col_5', + 2003, + 'ARRAY', + ] + ]) + + self.assertEqual(cols_desc, + [('TABLE_CAT', 'string', None, None, None, None, None), + ('TABLE_SCHEM', 'string', None, None, None, None, None), + ('TABLE_NAME', 'string', None, None, None, None, None), + ('COLUMN_NAME', 'string', None, None, None, None, None), + ('DATA_TYPE', 'int', None, None, None, None, None), + ('TYPE_NAME', 'string', None, None, None, None, None), + ('COLUMN_SIZE', 'int', None, None, None, None, None), + ('BUFFER_LENGTH', 'tinyint', None, None, None, None, None), + ('DECIMAL_DIGITS', 'int', None, None, None, None, None), + ('NUM_PREC_RADIX', 'int', None, None, None, None, None), + ('NULLABLE', 'int', None, None, None, None, None), + ('REMARKS', 'string', None, None, None, None, None), + ('COLUMN_DEF', 'string', None, None, None, None, None), + ('SQL_DATA_TYPE', 'int', None, None, None, None, None), + ('SQL_DATETIME_SUB', 'int', None, None, None, None, None), + ('CHAR_OCTET_LENGTH', 'int', None, None, None, None, None), + ('ORDINAL_POSITION', 'int', None, None, None, None, None), + ('IS_NULLABLE', 'string', None, None, None, None, None), + ('SCOPE_CATALOG', 'string', None, None, None, None, None), + ('SCOPE_SCHEMA', 'string', None, None, None, None, None), + ('SCOPE_TABLE', 'string', None, None, None, None, None), + ('SOURCE_DATA_TYPE', 'smallint', None, None, None, None, None), + ('IS_AUTO_INCREMENT', 'string', None, None, None, None, None)]) + finally: + for table in table_names: + cursor.execute('DROP TABLE IF EXISTS {}'.format(table)) + + def test_get_schemas(self): + with self.cursor({}) as cursor: + database_name = 'db_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + try: + cursor.execute('CREATE DATABASE IF NOT EXISTS {}'.format(database_name)) + cursor.schemas() + schemas = cursor.fetchall() + schemas_desc = cursor.description + # Catalogue name not consistent across DBR versions, so we skip that + self.assertIn(database_name, [schema[0] for schema in schemas]) + self.assertEqual(schemas_desc, + [('TABLE_SCHEM', 'string', None, None, None, None, None), + ('TABLE_CATALOG', 'string', None, None, None, None, None)]) + finally: + cursor.execute('DROP DATABASE IF EXISTS {}'.format(database_name)) + + def test_get_catalogs(self): + with self.cursor({}) as cursor: + cursor.catalogs() + cursor.fetchall() + catalogs_desc = cursor.description + self.assertEqual(catalogs_desc, [('TABLE_CAT', 'string', None, None, None, None, None)]) + + @skipUnless(pysql_supports_arrow(), 'arrow test need arrow support') + def test_get_arrow(self): + # These tests are quite light weight as the arrow fetch methods are used internally + # by everything else + with self.cursor({}) as cursor: + cursor.execute("SELECT * FROM range(10)") + table_1 = cursor.fetchmany_arrow(1).to_pydict() + self.assertEqual(table_1, OrderedDict([("id", [0])])) + + table_2 = cursor.fetchall_arrow().to_pydict() + self.assertEqual(table_2, OrderedDict([("id", [1, 2, 3, 4, 5, 6, 7, 8, 9])])) + + def test_unicode(self): + unicode_str = "数据砖" + with self.cursor({}) as cursor: + cursor.execute("SELECT '{}'".format(unicode_str)) + results = cursor.fetchall() + self.assertTrue(len(results) == 1 and len(results[0]) == 1) + self.assertEqual(results[0][0], unicode_str) + + def test_cancel_during_execute(self): + with self.cursor({}) as cursor: + + def execute_really_long_query(): + cursor.execute("SELECT SUM(A.id - B.id) " + + "FROM range(1000000000) A CROSS JOIN range(100000000) B " + + "GROUP BY (A.id - B.id)") + + exec_thread = threading.Thread(target=execute_really_long_query) + + exec_thread.start() + # Make sure the query has started before cancelling + time.sleep(15) + cursor.cancel() + exec_thread.join(5) + self.assertFalse(exec_thread.is_alive()) + + # Fetching results should throw an exception + with self.assertRaises((Error, thrift.Thrift.TException)): + cursor.fetchall() + with self.assertRaises((Error, thrift.Thrift.TException)): + cursor.fetchone() + with self.assertRaises((Error, thrift.Thrift.TException)): + cursor.fetchmany(10) + + # We should be able to execute a new command on the cursor + cursor.execute("SELECT * FROM range(3)") + self.assertEqual(len(cursor.fetchall()), 3) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_can_execute_command_after_failure(self): + with self.cursor({}) as cursor: + with self.assertRaises(DatabaseError): + cursor.execute("this is a sytnax error") + + cursor.execute("SELECT 1;") + + res = cursor.fetchall() + self.assertEqualRowValues(res, [[1]]) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_can_execute_command_after_success(self): + with self.cursor({}) as cursor: + cursor.execute("SELECT 1;") + cursor.execute("SELECT 2;") + + res = cursor.fetchall() + self.assertEqualRowValues(res, [[2]]) + + def generate_multi_row_query(self): + query = "SELECT * FROM range(3);" + return query + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchone(self): + with self.cursor({}) as cursor: + query = self.generate_multi_row_query() + cursor.execute(query) + + self.assertSequenceEqual(cursor.fetchone(), [0]) + self.assertSequenceEqual(cursor.fetchone(), [1]) + self.assertSequenceEqual(cursor.fetchone(), [2]) + + self.assertEqual(cursor.fetchone(), None) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchall(self): + with self.cursor({}) as cursor: + query = self.generate_multi_row_query() + cursor.execute(query) + + self.assertEqualRowValues(cursor.fetchall(), [[0], [1], [2]]) + + self.assertEqual(cursor.fetchone(), None) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchmany_when_stride_fits(self): + with self.cursor({}) as cursor: + query = "SELECT * FROM range(4)" + cursor.execute(query) + + self.assertEqualRowValues(cursor.fetchmany(2), [[0], [1]]) + self.assertEqualRowValues(cursor.fetchmany(2), [[2], [3]]) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_fetchmany_in_excess(self): + with self.cursor({}) as cursor: + query = "SELECT * FROM range(4)" + cursor.execute(query) + + self.assertEqualRowValues(cursor.fetchmany(3), [[0], [1], [2]]) + self.assertEqualRowValues(cursor.fetchmany(3), [[3]]) + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_iterator_api(self): + with self.cursor({}) as cursor: + query = "SELECT * FROM range(4)" + cursor.execute(query) + + expected_results = [[0], [1], [2], [3]] + for (i, row) in enumerate(cursor): + self.assertSequenceEqual(row, expected_results[i]) + + def test_temp_view_fetch(self): + with self.cursor({}) as cursor: + query = "create temporary view f as select * from range(10)" + cursor.execute(query) + # TODO assert on a result + # once what is being returned has stabilised + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_socket_timeout(self): + # We we expect to see a BlockingIO error when the socket is opened + # in non-blocking mode, since no poll is done before the read + with self.assertRaises(OperationalError) as cm: + with self.cursor({"_socket_timeout": 0}): + pass + + self.assertIsInstance(cm.exception.args[1], io.BlockingIOError) + + def test_ssp_passthrough(self): + for enable_ansi in (True, False): + with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor: + cursor.execute("SET ansi_mode") + self.assertEqual(list(cursor.fetchone()), ["ansi_mode", str(enable_ansi)]) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_timestamps_arrow(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + for (timestamp, expected) in self.timestamp_and_expected_results: + cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + arrow_table = cursor.fetchmany_arrow(1) + if self.should_add_timezone(): + ts_type = pyarrow.timestamp("us", tz="Etc/UTC") + else: + ts_type = pyarrow.timestamp("us") + self.assertEqual(arrow_table.field(0).type, ts_type) + result_value = arrow_table.column(0).combine_chunks()[0].value + # To work consistently across different local timezones, we specify the timezone + # of the expected result to + # be UTC (what it should be by default on the server) + aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) + self.assertEqual(result_value, aware_timestamp and + aware_timestamp.timestamp() * 1000000, + "timestamp {} did not match {}".format(timestamp, expected)) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_multi_timestamps_arrow(self): + with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: + query, expected = self.multi_query() + expected = [[self.maybe_add_timezone_to_timestamp(ts) for ts in row] + for row in expected] + cursor.execute(query) + table = cursor.fetchall_arrow() + # Transpose columnar result to list of rows + list_of_cols = [c.to_pylist() for c in table] + result = [[col[row_index] for col in list_of_cols] + for row_index in range(table.num_rows)] + self.assertEqual(result, expected) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_timezone_with_timestamp(self): + if self.should_add_timezone(): + with self.cursor() as cursor: + cursor.execute("SET TIME ZONE 'Europe/Amsterdam'") + cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") + amsterdam = pytz.timezone("Europe/Amsterdam") + expected = amsterdam.localize(datetime.datetime(2022, 3, 2, 12, 54, 56)) + result = cursor.fetchone()[0] + self.assertEqual(result, expected) + + cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") + arrow_result_table = cursor.fetchmany_arrow(1) + arrow_result_value = arrow_result_table.column(0).combine_chunks()[0].value + ts_type = pyarrow.timestamp("us", tz="Europe/Amsterdam") + + self.assertEqual(arrow_result_table.field(0).type, ts_type) + self.assertEqual(arrow_result_value, expected.timestamp() * 1000000) + + def _should_have_native_complex_types(self): + return pysql_has_version(">=", 2) and is_thrift_v5_plus(self.arguments) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_arrays_are_not_returned_as_strings_arrow(self): + if self._should_have_native_complex_types(): + with self.cursor() as cursor: + cursor.execute("SELECT array(1,2,3,4)") + arrow_df = cursor.fetchall_arrow() + + list_type = arrow_df.field(0).type + self.assertTrue(pyarrow.types.is_list(list_type)) + self.assertTrue(pyarrow.types.is_integer(list_type.value_type)) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_structs_are_not_returned_as_strings_arrow(self): + if self._should_have_native_complex_types(): + with self.cursor() as cursor: + cursor.execute("SELECT named_struct('foo', 42, 'bar', 'baz')") + arrow_df = cursor.fetchall_arrow() + + struct_type = arrow_df.field(0).type + self.assertTrue(pyarrow.types.is_struct(struct_type)) + + @skipUnless(pysql_supports_arrow(), 'arrow test needs arrow support') + def test_decimal_not_returned_as_strings_arrow(self): + if self._should_have_native_complex_types(): + with self.cursor() as cursor: + cursor.execute("SELECT 5E3BD") + arrow_df = cursor.fetchall_arrow() + + decimal_type = arrow_df.field(0).type + self.assertTrue(pyarrow.types.is_decimal(decimal_type)) + + +# use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep +# the 429/503 subsuites separate since they execute under different circumstances. +class PySQLRetryTestSuite: + class HTTP429Suite(Client429ResponseMixin, PySQLTestCase): + pass # Mixin covers all + + class HTTP503Suite(Client503ResponseMixin, PySQLTestCase): + # 503Response suite gets custom error here vs PyODBC + def test_retry_disabled(self): + self._test_retry_disabled_with_message("TEMPORARILY_UNAVAILABLE", OperationalError) + + +class PySQLUnityCatalogTestSuite(PySQLTestCase): + """Simple namespace tests that should be run against a unity-catalog-enabled cluster""" + + @skipIf(pysql_has_version('<', '2'), 'requires pysql v2') + def test_initial_namespace(self): + table_name = 'table_{uuid}'.format(uuid=str(uuid4()).replace('-', '_')) + with self.cursor() as cursor: + cursor.execute("USE CATALOG {}".format(self.arguments["catA"])) + cursor.execute("CREATE TABLE table_{} (col1 int)".format(table_name)) + with self.connection({ + "catalog": self.arguments["catA"], + "schema": table_name + }) as connection: + cursor = connection.cursor() + cursor.execute("select current_catalog()") + self.assertEqual(cursor.fetchone()[0], self.arguments["catA"]) + cursor.execute("select current_database()") + self.assertEqual(cursor.fetchone()[0], table_name) + + +def main(cli_args): + global get_args_from_env + get_args_from_env = True + print(f"Running tests with version: {sql.__version__}") + logging.getLogger("databricks.sql").setLevel(logging.INFO) + unittest.main(module=__file__, argv=sys.argv[0:1] + cli_args) + + +if __name__ == "__main__": + main(sys.argv[1:]) \ No newline at end of file diff --git a/tests/test_arrow_queue.py b/tests/unit/test_arrow_queue.py similarity index 100% rename from tests/test_arrow_queue.py rename to tests/unit/test_arrow_queue.py diff --git a/tests/test_fetches.py b/tests/unit/test_fetches.py similarity index 100% rename from tests/test_fetches.py rename to tests/unit/test_fetches.py diff --git a/tests/test_fetches_bench.py b/tests/unit/test_fetches_bench.py similarity index 100% rename from tests/test_fetches_bench.py rename to tests/unit/test_fetches_bench.py diff --git a/tests/test_thrift_backend.py b/tests/unit/test_thrift_backend.py similarity index 100% rename from tests/test_thrift_backend.py rename to tests/unit/test_thrift_backend.py diff --git a/tests/tests.py b/tests/unit/tests.py similarity index 100% rename from tests/tests.py rename to tests/unit/tests.py From 1f30744a9befd2219eb7c73150c3aca63ed8e1cc Mon Sep 17 00:00:00 2001 From: Jesse Date: Mon, 1 Aug 2022 10:29:06 -0500 Subject: [PATCH 23/57] Indicate that Python 3.10 is not supported (#27) Advise developers to use Python 3.7, 3.8, or 3.9 until #26 is fixed. Signed-off-by: Moe Derakhshani --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 45251507f..d849a1544 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ You are welcome to file an issue here for general use cases. You can also contac ## Requirements -Python 3.7 or above is required. +A development machine running Python >=3.7, <3.10. ## Documentation From e851ea419bd7ac55056d9701089c8659b8365056 Mon Sep 17 00:00:00 2001 From: Jesse Date: Mon, 1 Aug 2022 11:14:39 -0500 Subject: [PATCH 24/57] Add Developer Certificate of Origin requirement (#13) Includes a GitHub Action which checks for a valid sign-off on every proposed commit Signed-off-by: Moe Derakhshani --- .github/.github/pull_request_template.md | 23 +++++++ .github/workflows/dco-check.yml | 25 ++++++++ CONTRIBUTING.md | 80 ++++++++++++++++++++---- 3 files changed, 117 insertions(+), 11 deletions(-) create mode 100644 .github/.github/pull_request_template.md create mode 100644 .github/workflows/dco-check.yml diff --git a/.github/.github/pull_request_template.md b/.github/.github/pull_request_template.md new file mode 100644 index 000000000..8ce224e83 --- /dev/null +++ b/.github/.github/pull_request_template.md @@ -0,0 +1,23 @@ + + + +## What type of PR is this? + + +- [ ] Refactor +- [ ] Feature +- [ ] Bug Fix +- [ ] Other + +## Description + +## How is this tested? + +- [ ] Unit tests +- [ ] E2E Tests +- [ ] Manually +- [ ] N/A + + + +## Related Tickets & Documents diff --git a/.github/workflows/dco-check.yml b/.github/workflows/dco-check.yml new file mode 100644 index 000000000..5a5c60607 --- /dev/null +++ b/.github/workflows/dco-check.yml @@ -0,0 +1,25 @@ +name: DCO Check + +on: [pull_request] + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Check for DCO + id: dco-check + uses: tisonkun/actions-dco@v1.1 + - name: Comment about DCO status + uses: actions/github-script@v6 + if: ${{ failure() }} + with: + script: | + github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: `Thanks for your contribution! To satisfy the DCO policy in our \ + [contributing guide](https://github.com/databricks/databricks-sql-python/blob/main/CONTRIBUTING.md) \ + every commit message must include a sign-off message. One or more of your commits is missing this message. \ + You can reword previous commit messages with an interactive rebase (\`git rebase -i main\`).` + }) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cfc34a320..386ba5da0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,73 @@ -# Contributing +# Contributing Guide -To contribute to this repository, fork it and send pull requests. +We happily welcome contributions to the `databricks-sql-connector` package. We use [GitHub Issues](https://github.com/databricks/databricks-sql-python/issues) to track community reported issues and [GitHub Pull Requests](https://github.com/databricks/databricks-sql-python/pulls) for accepting changes. + +Contributions are licensed on a license-in/license-out basis. + +## Communication +Before starting work on a major feature, please reach out to us via GitHub, Slack, email, etc. We will make sure no one else is already working on it and ask you to open a GitHub issue. +A "major feature" is defined as any change that is > 100 LOC altered (not including tests), or changes any user-facing behavior. +We will use the GitHub issue to discuss the feature and come to agreement. +This is to prevent your time being wasted, as well as ours. +The GitHub review process for major features is also important so that organizations with commit access can come to agreement on design. +If it is appropriate to write a design document, the document must be hosted either in the GitHub tracking issue, or linked to from the issue and hosted in a world-readable location. +Specifically, if the goal is to add a new extension, please read the extension policy. +Small patches and bug fixes don't need prior communication. + +## Coding Style +We follow [PEP 8](https://www.python.org/dev/peps/pep-0008/) with one exception: lines can be up to 100 characters in length, not 79. + +## Sign your work +The sign-off is a simple line at the end of the explanation for the patch. Your signature certifies that you wrote the patch or otherwise have the right to pass it on as an open-source patch. The rules are pretty simple: if you can certify the below (from developercertificate.org): + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +1 Letterman Drive +Suite D4700 +San Francisco, CA, 94129 + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +Then you just add a line to every git commit message: + +``` +Signed-off-by: Joe Smith +Use your real name (sorry, no pseudonyms or anonymous contributions.) +``` + +If you set your `user.name` and `user.email` git configs, you can sign your commit automatically with `git commit -s`. ## Set up your environment @@ -67,12 +134,3 @@ poetry run python3 -m black src --check Remove the `--check` flag to write reformatted files to disk. To simplify reviews you can format your changes in a separate commit. -## Pull Request Process - -1. Update the [CHANGELOG.md](README.md) or similar documentation with details of changes you wish to make, if applicable. -2. Add any appropriate tests. -3. Make your code or other changes. -4. Review guidelines such as - [How to write the perfect pull request][github-perfect-pr], thanks! - -[github-perfect-pr]: https://blog.github.com/2015-01-21-how-to-write-the-perfect-pull-request/ From 63894d77bfc722dc0d6c1715fd2585ad9a242d45 Mon Sep 17 00:00:00 2001 From: Jesse Date: Fri, 5 Aug 2022 16:23:17 -0500 Subject: [PATCH 25/57] Retry attempts that fail due to a connection timeout (#24) * Isolate delay bounding logic * Move error details scope up one-level. * Retry GetOperationStatus if an OSError was raised during execution. Add retry_delay_default to use in this case. * Log when a request is retried due to an OSError. Emit warnings for unexpected OSError codes * Update docstring for make_request * Nit: unit tests show the .warn message is deprecated. DeprecationWarning: The 'warn' function is deprecated, use 'warning' instead Signed-off-by: Jesse Whitehouse --- src/databricks/sql/thrift_backend.py | 78 ++++++++++++++++++++++------ tests/unit/test_thrift_backend.py | 57 ++++++++++++++++++++ 2 files changed, 120 insertions(+), 15 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index dfadc0716..066396c37 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -1,4 +1,7 @@ from decimal import Decimal + +import errno +import logging import math import time import threading @@ -14,6 +17,9 @@ from databricks.sql.auth.authenticators import CredentialsProvider from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes from databricks.sql import * +from databricks.sql.thrift_api.TCLIService.TCLIService import ( + Client as TCLIServiceClient, +) from databricks.sql.utils import ( ArrowQueue, ExecuteResponse, @@ -38,6 +44,7 @@ "_retry_delay_max": (float, 60, 5, 3600), "_retry_stop_after_attempts_count": (int, 30, 1, 60), "_retry_stop_after_attempts_duration": (float, 900, 1, 86400), + "_retry_delay_default": (float, 5, 1, 60), } @@ -70,6 +77,8 @@ def __init__( # _retry_delay_min (default: 1) # _retry_delay_max (default: 60) # {min,max} pre-retry delay bounds + # _retry_delay_default (default: 5) + # Only used when GetOperationStatus fails due to a TCP/OS Error. # _retry_stop_after_attempts_count (default: 30) # total max attempts during retry sequence # _retry_stop_after_attempts_duration (default: 900) @@ -160,7 +169,7 @@ def _initialize_retry_args(self, kwargs): "retry parameter: {} given_or_default {}".format(key, given_or_default) ) if bound != given_or_default: - logger.warn( + logger.warning( "Override out of policy retry parameter: " + "{} given {}, restricted to {}".format( key, given_or_default, bound @@ -245,7 +254,9 @@ def _handle_request_error(self, error_info, attempt, elapsed): # FUTURE: Consider moving to https://github.com/litl/backoff or # https://github.com/jd/tenacity for retry logic. def make_request(self, method, request): - """Execute given request, attempting retries when receiving HTTP 429/503. + """Execute given request, attempting retries when + 1. Receiving HTTP 429/503 from server + 2. OSError is raised during a GetOperationStatus For delay between attempts, honor the given Retry-After header, but with bounds. Use lower bound of expontial-backoff based on _retry_delay_min, @@ -262,6 +273,13 @@ def make_request(self, method, request): def get_elapsed(): return time.time() - t0 + def bound_retry_delay(attempt, proposed_delay): + """bound delay (seconds) by [min_delay*1.5^(attempt-1), max_delay]""" + delay = int(proposed_delay) + delay = max(delay, self._retry_delay_min * math.pow(1.5, attempt - 1)) + delay = min(delay, self._retry_delay_max) + return delay + def extract_retry_delay(attempt): # encapsulate retry checks, returns None || delay-in-secs # Retry IFF 429/503 code + Retry-After header set @@ -269,10 +287,7 @@ def extract_retry_delay(attempt): retry_after = getattr(self._transport, "headers", {}).get("Retry-After") if http_code in [429, 503] and retry_after: # bound delay (seconds) by [min_delay*1.5^(attempt-1), max_delay] - delay = int(retry_after) - delay = max(delay, self._retry_delay_min * math.pow(1.5, attempt - 1)) - delay = min(delay, self._retry_delay_max) - return delay + return bound_retry_delay(attempt, int(retry_after)) return None def attempt_request(attempt): @@ -281,24 +296,57 @@ def attempt_request(attempt): # - non-None method_return -> success, return and be done # - non-None retry_delay -> sleep delay before retry # - error, error_message always set when available + + error, error_message, retry_delay = None, None, None try: logger.debug("Sending request: {}".format(request)) response = method(request) logger.debug("Received response: {}".format(response)) return response - except Exception as error: + except OSError as err: + error = err + error_message = str(err) + + gos_name = TCLIServiceClient.GetOperationStatus.__name__ + if method.__name__ == gos_name: + retry_delay = bound_retry_delay(attempt, self._retry_delay_default) + + # fmt: off + # The built-in errno package encapsulates OSError codes, which are OS-specific. + # log.info for errors we believe are not unusual or unexpected. log.warn for + # for others like EEXIST, EBADF, ERANGE which are not expected in this context. + # + # I manually tested this retry behaviour using mitmweb and confirmed that + # GetOperationStatus requests are retried when I forced network connection + # interruptions / timeouts / reconnects. See #24 for more info. + # | Debian | Darwin | + info_errs = [ # |--------|--------| + errno.ESHUTDOWN, # | 32 | 32 | + errno.EAFNOSUPPORT, # | 97 | 47 | + errno.ECONNRESET, # | 104 | 54 | + errno.ETIMEDOUT, # | 110 | 60 | + ] + + # fmt: on + log_string = f"{gos_name} failed with code {err.errno} and will attempt to retry" + if err.errno in info_errs: + logger.info(log_string) + else: + logger.warning(log_string) + except Exception as err: + error = err retry_delay = extract_retry_delay(attempt) error_message = ThriftBackend._extract_error_message_from_headers( getattr(self._transport, "headers", {}) ) - return RequestErrorInfo( - error=error, - error_message=error_message, - retry_delay=retry_delay, - http_code=getattr(self._transport, "code", None), - method=method.__name__, - request=request, - ) + return RequestErrorInfo( + error=error, + error_message=error_message, + retry_delay=retry_delay, + http_code=getattr(self._transport, "code", None), + method=method.__name__, + request=request, + ) # The real work: # - for each available attempt: diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index d411df76d..e8c5a727f 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -19,6 +19,7 @@ def retry_policy_factory(): "_retry_delay_max": (float, 60, None, None), "_retry_stop_after_attempts_count": (int, 30, None, None), "_retry_stop_after_attempts_duration": (float, 900, None, None), + "_retry_delay_default": (float, 5, 1, 60) } @@ -968,6 +969,62 @@ def test_handle_execute_response_sets_active_op_handle(self): self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) + @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") + @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + def test_make_request_will_retry_GetOperationStatus( + self, mock_retry_policy, mock_GetOperationStatus, t_transport_class): + + import thrift, errno + from databricks.sql.thrift_api.TCLIService.TCLIService import Client + from databricks.sql.exc import RequestError + from databricks.sql.utils import NoRetryReason + + this_gos_name = "GetOperationStatus" + mock_GetOperationStatus.__name__ = this_gos_name + mock_GetOperationStatus.side_effect = OSError(errno.ETIMEDOUT, "Connection timed out") + + protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(t_transport_class) + client = Client(protocol) + + req = ttypes.TGetOperationStatusReq( + operationHandle=self.operation_handle, + getProgressUpdate=False, + ) + + EXPECTED_RETRIES = 2 + + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", [], + _retry_stop_after_attempts_count=EXPECTED_RETRIES, + _retry_delay_default=1) + + + with self.assertRaises(RequestError) as cm: + thrift_backend.make_request(client.GetOperationStatus, req) + + self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"]) + self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) + + # Unusual OSError code + mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") + + with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: + with self.assertRaises(RequestError): + thrift_backend.make_request(client.GetOperationStatus, req) + + # There should be two warning log messages: one for each retry + self.assertEqual(len(cm.output), EXPECTED_RETRIES) + + # The warnings should be identical + self.assertEqual(cm.output[1], cm.output[0]) + + # The warnings should include this text + self.assertIn(f"{this_gos_name} failed with code {errno.EEXIST} and will attempt to retry", cm.output[0]) + + @patch("thrift.transport.THttpClient.THttpClient") def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class): t_transport_instance = t_transport_class.return_value From 36c6f4d405685e5dc9a728f1ff598791b412398a Mon Sep 17 00:00:00 2001 From: Jesse Date: Fri, 5 Aug 2022 16:45:13 -0500 Subject: [PATCH 26/57] Bump to v2.0.3 (#28) Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- CHANGELOG.md | 5 ++++- pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2730415a0..aeaba1745 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,10 @@ # Release History -## 2.x.x (Unreleased) +## 2.0.x (Unreleased) +## 2.0.3 (2022-08-05) + +- Add retry logic for `GetOperationStatus` requests that fail with an `OSError` - Reorganised code to use Poetry for dependency management. ## 2.0.2 (2022-05-04) - Better exception handling in automatic connection close diff --git a/pyproject.toml b/pyproject.toml index 548ea8b2b..8dad75569 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "2.0.2" +version = "2.0.3" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index c94cbc31e..34656625d 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -28,7 +28,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "2.0.2" +__version__ = "2.0.3" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 78015e57b7d31e17ceb6fa9f117fc3a41962c78f Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 10 Aug 2022 15:40:18 -0500 Subject: [PATCH 27/57] Bump version to 2.0.4-dev (#29) Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8dad75569..e178e6d5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "2.0.3" +version = "2.0.4-dev" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 34656625d..190d3d6ba 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -28,7 +28,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "2.0.3" +__version__ = "2.0.4-dev" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From cdef1a55b4eb199dbae4a2a149bbcae3c17dca5c Mon Sep 17 00:00:00 2001 From: David Black Date: Thu, 18 Aug 2022 05:02:37 +1000 Subject: [PATCH 28/57] [PECO-197] Support Python 3.10 (#31) * Test with multiple python versions. * Update pyarrow to version 9.0.0 to address issue in relation to python 3.10 & a specific version of numpy being pulled in by pyarrow. Closes #26 Signed-off-by: David Black --- .github/workflows/code-quality-checks.yml | 23 +++-- README.md | 2 +- poetry.lock | 114 ++++++++++++++++------ pyproject.toml | 4 +- 4 files changed, 103 insertions(+), 40 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index a6f44144c..70e9de70b 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -3,17 +3,20 @@ on: [push] jobs: run-unit-tests: runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8, 3.9, "3.10"] steps: #---------------------------------------------- # check-out repo and set-up python #---------------------------------------------- - name: Check out repository uses: actions/checkout@v2 - - name: Set up python + - name: Set up python ${{ matrix.python-version }} id: setup-python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- @@ -51,17 +54,20 @@ jobs: run: poetry run python -m pytest tests/unit check-linting: runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8, 3.9, "3.10"] steps: #---------------------------------------------- # check-out repo and set-up python #---------------------------------------------- - name: Check out repository uses: actions/checkout@v2 - - name: Set up python + - name: Set up python ${{ matrix.python-version }} id: setup-python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- @@ -100,17 +106,20 @@ jobs: check-types: runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7, 3.8, 3.9, "3.10"] steps: #---------------------------------------------- # check-out repo and set-up python #---------------------------------------------- - name: Check out repository uses: actions/checkout@v2 - - name: Set up python + - name: Set up python ${{ matrix.python-version }} id: setup-python uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: ${{ matrix.python-version }} #---------------------------------------------- # ----- install & configure poetry ----- #---------------------------------------------- @@ -145,4 +154,4 @@ jobs: # black the code #---------------------------------------------- - name: Mypy - run: poetry run mypy src \ No newline at end of file + run: poetry run mypy src diff --git a/README.md b/README.md index d849a1544..45251507f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ You are welcome to file an issue here for general use cases. You can also contac ## Requirements -A development machine running Python >=3.7, <3.10. +Python 3.7 or above is required. ## Documentation diff --git a/poetry.lock b/poetry.lock index 231d71a73..49c722166 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8,17 +8,17 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [[package]] name = "attrs" -version = "21.4.0" +version = "22.1.0" description = "Classes Without Boilerplate" category = "dev" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +python-versions = ">=3.5" [package.extras] -dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] -tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] -tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "cloudpickle"] [[package]] name = "black" @@ -53,7 +53,7 @@ python-versions = ">=3.6" [[package]] name = "charset-normalizer" -version = "2.1.0" +version = "2.1.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." category = "main" optional = false @@ -144,19 +144,11 @@ python-versions = "*" [[package]] name = "numpy" -version = "1.21.6" -description = "NumPy is the fundamental package for array computing with Python." -category = "main" -optional = false -python-versions = ">=3.7,<3.11" - -[[package]] -name = "numpy" -version = "1.23.1" +version = "1.21.1" description = "NumPy is the fundamental package for array computing with Python." category = "main" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" [[package]] name = "oauthlib" @@ -248,11 +240,11 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "pyarrow" -version = "5.0.0" +version = "9.0.0" description = "Python library for Apache Arrow" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] numpy = ">=1.16.6" @@ -303,7 +295,7 @@ six = ">=1.5" [[package]] name = "pytz" -version = "2022.1" +version = "2022.2.1" description = "World timezone definitions, modern and historical" category = "main" optional = false @@ -377,7 +369,7 @@ python-versions = ">=3.7" [[package]] name = "urllib3" -version = "1.26.10" +version = "1.26.12" description = "HTTP library with thread-safe connection pooling, file post, and more." category = "main" optional = false @@ -385,7 +377,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, [package.extras] brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] -secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] +secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "urllib3-secure-extra", "ipaddress"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] @@ -403,19 +395,22 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "56ae2d2b29d59b5eb54bf00fd1259d2d62e61892f237cae399ad608c627ab228" +content-hash = "8ac5d8721267bc81ca501a41697d1b3883f334ac69cb09b297bbe33215c4e204" [metadata.files] atomicwrites = [] -attrs = [] +attrs = [ + {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, + {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, +] black = [] certifi = [ {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"}, {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"}, ] charset-normalizer = [ - {file = "charset-normalizer-2.1.0.tar.gz", hash = "sha256:575e708016ff3a5e3681541cb9d79312c416835686d054a23accb873b254f413"}, - {file = "charset_normalizer-2.1.0-py3-none-any.whl", hash = "sha256:5189b6f22b01957427f35b6a08d9a0bc45b46d3788ef5a92e978433c7a35f8a5"}, + {file = "charset-normalizer-2.1.1.tar.gz", hash = "sha256:5a3d016c7c547f69d6f81fb0db9449ce888b418b5b9952cc5e6e66843e9dd845"}, + {file = "charset_normalizer-2.1.1-py3-none-any.whl", hash = "sha256:83e9a75d1911279afd89352c68b45348559d1fc0506b054b346651b5e7fee29f"}, ] click = [] colorama = [] @@ -451,7 +446,36 @@ mypy = [ {file = "mypy-0.950.tar.gz", hash = "sha256:1b333cfbca1762ff15808a0ef4f71b5d3eed8528b23ea1c3fb50543c867d68de"}, ] mypy-extensions = [] -numpy = [] +numpy = [ + {file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"}, + {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"}, + {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"}, + {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"}, + {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"}, + {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"}, + {file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"}, + {file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"}, + {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"}, + {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"}, + {file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"}, + {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"}, + {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"}, + {file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"}, + {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"}, + {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"}, + {file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"}, + {file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"}, + {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"}, + {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"}, + {file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"}, + {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"}, + {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"}, + {file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"}, + {file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"}, + {file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"}, + {file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"}, + {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, +] oauthlib = [] packaging = [] pandas = [] @@ -459,14 +483,44 @@ pathspec = [] platformdirs = [] pluggy = [] py = [] -pyarrow = [] +pyarrow = [ + {file = "pyarrow-9.0.0-cp310-cp310-macosx_10_13_universal2.whl", hash = "sha256:767cafb14278165ad539a2918c14c1b73cf20689747c21375c38e3fe62884902"}, + {file = "pyarrow-9.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0238998dc692efcb4e41ae74738d7c1234723271ccf520bd8312dca07d49ef8d"}, + {file = "pyarrow-9.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:55328348b9139c2b47450d512d716c2248fd58e2f04e2fc23a65e18726666d42"}, + {file = "pyarrow-9.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fc856628acd8d281652c15b6268ec7f27ebcb015abbe99d9baad17f02adc51f1"}, + {file = "pyarrow-9.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29eb3e086e2b26202f3a4678316b93cfb15d0e2ba20f3ec12db8fd9cc07cde63"}, + {file = "pyarrow-9.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e753f8fcf07d8e3a0efa0c8bd51fef5c90281ffd4c5637c08ce42cd0ac297de"}, + {file = "pyarrow-9.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:3eef8a981f45d89de403e81fb83b8119c20824caddf1404274e41a5d66c73806"}, + {file = "pyarrow-9.0.0-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:7fa56cbd415cef912677270b8e41baad70cde04c6d8a8336eeb2aba85aa93706"}, + {file = "pyarrow-9.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f8c46bde1030d704e2796182286d1c56846552c50a39ad5bf5a20c0d8159fc35"}, + {file = "pyarrow-9.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ad430cee28ebc4d6661fc7315747c7a18ae2a74e67498dcb039e1c762a2fb67"}, + {file = "pyarrow-9.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a60bb291a964f63b2717fb1b28f6615ffab7e8585322bfb8a6738e6b321282"}, + {file = "pyarrow-9.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:9cef618159567d5f62040f2b79b1c7b38e3885f4ffad0ec97cd2d86f88b67cef"}, + {file = "pyarrow-9.0.0-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:5526a3bfb404ff6d31d62ea582cf2466c7378a474a99ee04d1a9b05de5264541"}, + {file = "pyarrow-9.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:da3e0f319509a5881867effd7024099fb06950a0768dad0d6873668bb88cfaba"}, + {file = "pyarrow-9.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c715eca2092273dcccf6f08437371e04d112f9354245ba2fbe6c801879450b7"}, + {file = "pyarrow-9.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f11a645a41ee531c3a5edda45dea07c42267f52571f818d388971d33fc7e2d4a"}, + {file = "pyarrow-9.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5b390bdcfb8c5b900ef543f911cdfec63e88524fafbcc15f83767202a4a2491"}, + {file = "pyarrow-9.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:d9eb04db626fa24fdfb83c00f76679ca0d98728cdbaa0481b6402bf793a290c0"}, + {file = "pyarrow-9.0.0-cp39-cp39-macosx_10_13_universal2.whl", hash = "sha256:4eebdab05afa23d5d5274b24c1cbeb1ba017d67c280f7d39fd8a8f18cbad2ec9"}, + {file = "pyarrow-9.0.0-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:02b820ecd1da02012092c180447de449fc688d0c3f9ff8526ca301cdd60dacd0"}, + {file = "pyarrow-9.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:92f3977e901db1ef5cba30d6cc1d7942b8d94b910c60f89013e8f7bb86a86eef"}, + {file = "pyarrow-9.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f241bd488c2705df930eedfe304ada71191dcf67d6b98ceda0cc934fd2a8388e"}, + {file = "pyarrow-9.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c5a073a930c632058461547e0bc572da1e724b17b6b9eb31a97da13f50cb6e0"}, + {file = "pyarrow-9.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f59bcd5217a3ae1e17870792f82b2ff92df9f3862996e2c78e156c13e56ff62e"}, + {file = "pyarrow-9.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fe2ce795fa1d95e4e940fe5661c3c58aee7181c730f65ac5dd8794a77228de59"}, + {file = "pyarrow-9.0.0.tar.gz", hash = "sha256:7fb02bebc13ab55573d1ae9bb5002a6d20ba767bf8569b52fce5301d42495ab7"}, +] pyparsing = [] pytest = [] python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] -pytz = [] +pytz = [ + {file = "pytz-2022.2.1-py2.py3-none-any.whl", hash = "sha256:220f481bdafa09c3955dfbdddb7b57780e9a94f5127e35456a48589b9e0c0197"}, + {file = "pytz-2022.2.1.tar.gz", hash = "sha256:cea221417204f2d1a2aa03ddae3e867921971d0d76f14d87abb4414415bbdcf5"}, +] requests = [ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, @@ -477,7 +531,7 @@ tomli = [] typed-ast = [] typing-extensions = [] urllib3 = [ - {file = "urllib3-1.26.10-py2.py3-none-any.whl", hash = "sha256:8298d6d56d39be0e3bc13c1c97d133f9b45d797169a0e11cdd0e0489d786f7ec"}, - {file = "urllib3-1.26.10.tar.gz", hash = "sha256:879ba4d1e89654d9769ce13121e0f94310ea32e8d2f8cf587b77c08bbcdb30d6"}, + {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, + {file = "urllib3-1.26.12.tar.gz", hash = "sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e"}, ] zipp = [] diff --git a/pyproject.toml b/pyproject.toml index e178e6d5e..cc3084d91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,8 @@ include = ["CHANGELOG.md"] [tool.poetry.dependencies] python = "^3.7.1" thrift = "^0.13.0" -pyarrow = "^5.0.0" pandas = "^1.3.0" +pyarrow = "^9.0.0" requests=">2.18.1" oauthlib=">=3.1.0" @@ -30,4 +30,4 @@ ignore_missing_imports = "true" exclude = ['ttypes\.py$', 'TCLIService\.py$'] [tool.black] -exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' \ No newline at end of file +exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' From 39efd0a8a5ee0db92b5e2dce92af1f5dadb28841 Mon Sep 17 00:00:00 2001 From: Jesse Date: Wed, 17 Aug 2022 14:14:29 -0500 Subject: [PATCH 29/57] Update changelog and bump to v2.0.4 (#34) * Update changelog and bump to v2.0.4 * Specifically thank @dbaxa for this change. Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- CHANGELOG.md | 7 +++++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aeaba1745..953103a2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,13 @@ ## 2.0.x (Unreleased) +## 2.0.4 (2022-08-17) + +- Add support for Python 3.10 +- Add unit test matrix for supported Python versions + +Huge thanks to @dbaxa for contributing this change! + ## 2.0.3 (2022-08-05) - Add retry logic for `GetOperationStatus` requests that fail with an `OSError` diff --git a/pyproject.toml b/pyproject.toml index cc3084d91..da7fe9ddc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "2.0.4-dev" +version = "2.0.4" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 190d3d6ba..db6750479 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -28,7 +28,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "2.0.4-dev" +__version__ = "2.0.4" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 643425fc5d27f218507d4d13305f011de1e7f129 Mon Sep 17 00:00:00 2001 From: Jesse Date: Fri, 19 Aug 2022 11:17:31 -0500 Subject: [PATCH 30/57] Bump to 2.0.5-dev on main (#35) Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index da7fe9ddc..fb4fc4eba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "2.0.4" +version = "2.0.5-dev" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index db6750479..9f0afbc81 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -28,7 +28,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "2.0.4" +__version__ = "2.0.5-dev" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 69b7b8c22ecf4a94d97bc68463c392940862ce35 Mon Sep 17 00:00:00 2001 From: Jesse Date: Fri, 19 Aug 2022 11:22:00 -0500 Subject: [PATCH 31/57] On Pypi, display the "Project Links" sidebar. (#36) Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index fb4fc4eba..e0813715d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,10 @@ pytest = "^7.1.2" mypy = "^0.950" black = "^22.3.0" +[tool.poetry.urls] +"Homepage" = "https://github.com/databricks/databricks-sql-python" +"Bug Tracker" = "https://github.com/databricks/databricks-sql-python/issues" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From 23cd570c347cc70d7fb387e21f31451ac97b7340 Mon Sep 17 00:00:00 2001 From: Jesse Date: Tue, 23 Aug 2022 12:29:28 -0500 Subject: [PATCH 32/57] [ES-402013] Close cursors before closing connection (#38) * Add test: cursors are closed when connection closes Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- src/databricks/sql/client.py | 5 ++--- tests/e2e/driver_tests.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a07cd3c50..ef551daec 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -161,12 +161,11 @@ def close(self) -> None: self._close() def _close(self, close_cursors=True) -> None: - self.thrift_backend.close_session(self._session_handle) - self.open = False - if close_cursors: for cursor in self._cursors: cursor.close() + self.thrift_backend.close_session(self._session_handle) + self.open = False def commit(self): """No-op because Databricks does not support transactions""" diff --git a/tests/e2e/driver_tests.py b/tests/e2e/driver_tests.py index 358f0b263..9e400770d 100644 --- a/tests/e2e/driver_tests.py +++ b/tests/e2e/driver_tests.py @@ -544,6 +544,32 @@ def test_decimal_not_returned_as_strings_arrow(self): decimal_type = arrow_df.field(0).type self.assertTrue(pyarrow.types.is_decimal(decimal_type)) + def test_close_connection_closes_cursors(self): + + from databricks.sql.thrift_api.TCLIService import ttypes + + with self.connection() as conn: + cursor = conn.cursor() + cursor.execute('SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()') + ars = cursor.active_result_set + + # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True + + # Cursor op state should be open before connection is closed + status_request = ttypes.TGetOperationStatusReq(operationHandle=ars.command_id, getProgressUpdate=False) + op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) + assert op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE + + conn.close() + + # When connection closes, any cursor operations should no longer exist at the server + with self.assertRaises(thrift.Thrift.TApplicationException) as cm: + op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) + if hasattr(cm, "exception"): + assert "RESOURCE_DOES_NOT_EXIST" in cm.exception.message + + + # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. From 353f41305353914be74a36567d14dd4ccbb19403 Mon Sep 17 00:00:00 2001 From: Jesse Date: Tue, 23 Aug 2022 13:01:30 -0500 Subject: [PATCH 33/57] Bump version to 2.0.5 and improve CHANGELOG (#40) Signed-off-by: Jesse Whitehouse Signed-off-by: Moe Derakhshani --- CHANGELOG.md | 5 +++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 953103a2c..aeb78efe9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## 2.0.x (Unreleased) +## 2.0.5 (2022-08-23) + +- Fix: closing a connection now closes any open cursors from that connection at the server +- Other: Add project links to pyproject.toml (helpful for visitors from PyPi) + ## 2.0.4 (2022-08-17) - Add support for Python 3.10 diff --git a/pyproject.toml b/pyproject.toml index e0813715d..5b9d176e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "2.0.5-dev" +version = "2.0.5" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 9f0afbc81..e457a4358 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -28,7 +28,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "2.0.5-dev" +__version__ = "2.0.5" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From 737fd7e13d777dfc6e5e734ce4ccaa46ee44e350 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 24 Aug 2022 15:16:20 -0700 Subject: [PATCH 34/57] minor fixes Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 2 +- tests/unit/__init__.py | 0 tests/{ => unit}/test_auth.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/test_auth.py (100%) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 0c0e5938f..6dad102df 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -77,7 +77,7 @@ def get_auth_provider(cfg: ClientContext): OAUTH_CLIENT_ID = "databricks-cli" -def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence, **kwargs): +def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs): cfg = ClientContext(hostname=hostname, auth_type=kwargs.get("auth_type"), access_token=kwargs.get("access_token"), diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_auth.py b/tests/unit/test_auth.py similarity index 100% rename from tests/test_auth.py rename to tests/unit/test_auth.py From 9e40f1848b0b1a3a0a1befc628f030689ccee026 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 25 Aug 2022 11:27:04 -0700 Subject: [PATCH 35/57] fixed token refresh and persistent Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 75 +++++++++++++---------- src/databricks/sql/auth/oauth.py | 5 +- tests/unit/test_oauth_persistence.py | 55 +++++++++++++++++ 3 files changed, 103 insertions(+), 32 deletions(-) create mode 100644 tests/unit/test_oauth_persistence.py diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 09fde1b9e..c27a96ab6 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging from databricks.sql.auth.oauth import get_tokens, check_and_refresh_access_token import base64 @@ -64,17 +64,20 @@ class DatabricksOAuthProvider(CredentialsProvider): SCOPE_DELIM = ' ' def __init__(self, hostname, oauth_persistence, client_id, scopes): - self._hostname = self._normalize_host_name(hostname=hostname) - self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) - self._oauth_persistence = oauth_persistence - self._client_id = client_id - self._get_tokens() + try: + self._hostname = self._normalize_host_name(hostname=hostname) + self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) + self._oauth_persistence = oauth_persistence + self._client_id = client_id + self._access_token = None + self._refresh_token = None + self._initial_get_token() + except Exception as e: + logging.error(f"unexpected error", e, exc_info=True) + raise e def add_headers(self, request_headers): - check_and_refresh_access_token(hostname=self._hostname, - client_id=self._client_id, - access_token=self._access_token, - refresh_token=self._refresh_token) + self._update_token_if_expired() request_headers['Authorization'] = f"Bearer {self._access_token}" @staticmethod @@ -83,12 +86,16 @@ def _normalize_host_name(hostname): maybe_trailing_slash = "/" if not hostname.endswith("/") else "" return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" - def _get_tokens(self): - if self._oauth_persistence: - token = self._oauth_persistence.read() - if token: - self._access_token = token.get_access_token() - self._refresh_token = token.get_refresh_token() + def _initial_get_token(self): + try: + if self._access_token is None or self._refresh_token is None: + if self._oauth_persistence: + token = self._oauth_persistence.read() + if token: + self._access_token = token.get_access_token() + self._refresh_token = token.get_refresh_token() + + if self._access_token and self._refresh_token: self._update_token_if_expired() else: (access_token, refresh_token) = get_tokens(hostname=self._hostname, @@ -97,20 +104,26 @@ def _get_tokens(self): self._access_token = access_token self._refresh_token = refresh_token self._oauth_persistence.persist(OAuthToken(access_token, refresh_token)) + except Exception as e: + logging.error(f"unexpected error in oauth initialization", e, exc_info=True) + raise e def _update_token_if_expired(self): - (fresh_access_token, fresh_refresh_token, is_refreshed) = check_and_refresh_access_token( - hostname=self._hostname, - client_id=self._client_id, - access_token=self._access_token, - refresh_token=self._refresh_token) - - if not is_refreshed: - return - else: - self._access_token = fresh_access_token - self._refresh_token = fresh_refresh_token - - if self._oauth_persistence: - token = OAuthToken(self._access_token, self._refresh_token) - self._oauth_persistence.persist(token) + try: + (fresh_access_token, fresh_refresh_token, is_refreshed) = check_and_refresh_access_token( + hostname=self._hostname, + client_id=self._client_id, + access_token=self._access_token, + refresh_token=self._refresh_token) + if not is_refreshed: + return + else: + self._access_token = fresh_access_token + self._refresh_token = fresh_refresh_token + + if self._oauth_persistence: + token = OAuthToken(self._access_token, self._refresh_token) + self._oauth_persistence.persist(token) + except Exception as e: + logging.error(f"unexpected error in oauth token update", e, exc_info=True) + raise e diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index f89fcdbce..63cea7a69 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -229,7 +229,10 @@ def check_and_refresh_access_token(hostname, client_id, access_token, refresh_to # If it has been tampered with, it will be rejected on the server side. # This avoids having to fetch the public key from the issuer and perform # an unnecessary signature verification. - decoded = json.loads(base64.standard_b64decode(access_token.split(".")[1])) + access_token_payload = access_token.split(".")[1] + # add padding + access_token_payload = access_token_payload + '=' * (-len(access_token_payload) % 4) + decoded = json.loads(base64.standard_b64decode(access_token_payload)) expiration_time = datetime.fromtimestamp(decoded["exp"], tz=UTC) except Exception as e: logger.error(e) diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py new file mode 100644 index 000000000..18d7dc4ad --- /dev/null +++ b/tests/unit/test_oauth_persistence.py @@ -0,0 +1,55 @@ +# Copyright 2022 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, CredentialsProvider +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence, OAuthToken +import tempfile +import os + + +class OAuthPersistenceTests(unittest.TestCase): + + def test_DevOnlyFilePersistence_read_my_write(self): + with tempfile.TemporaryDirectory() as tempdir: + test_json_file_path = os.path.join(tempdir, 'test.json') + persistence_manager = DevOnlyFilePersistence(test_json_file_path) + access_token = "abc#$%%^&^*&*()()_=-/" + refresh_token = "#$%%^^&**()+)_gter243]xyz" + token = OAuthToken(access_token=access_token, refresh_token=refresh_token) + persistence_manager.persist(token) + new_token = persistence_manager.read() + + self.assertEqual(new_token.get_access_token(), access_token) + self.assertEqual(new_token.get_refresh_token(), refresh_token) + + def test_DevOnlyFilePersistence_file_does_not_exist(self): + with tempfile.TemporaryDirectory() as tempdir: + test_json_file_path = os.path.join(tempdir, 'test.json') + persistence_manager = DevOnlyFilePersistence(test_json_file_path) + new_token = persistence_manager.read() + + self.assertEqual(new_token, None) + + # TODO moderakh add test for file with invalid format From f9cb9f9bc6dfb47a697062dd4c3f54c2ab1169aa Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 25 Aug 2022 12:27:45 -0700 Subject: [PATCH 36/57] updated comment Signed-off-by: Moe Derakhshani --- tests/unit/test_oauth_persistence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index 18d7dc4ad..bd44654cc 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -52,4 +52,4 @@ def test_DevOnlyFilePersistence_file_does_not_exist(self): self.assertEqual(new_token, None) - # TODO moderakh add test for file with invalid format + # TODO moderakh add test for file with invalid format (should return None) From b6e1fdeee64f023253a6646c8c496641ad4bc96e Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 25 Aug 2022 13:18:39 -0700 Subject: [PATCH 37/57] added type annotation Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 19 ++++++++++--------- src/databricks/sql/auth/thrift_http_client.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index c27a96ab6..d3c839d84 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TypedDict, List from databricks.sql.auth.oauth import get_tokens, check_and_refresh_access_token import base64 @@ -27,34 +28,34 @@ # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -from databricks.sql.experimental.oauth_persistence import OAuthToken +from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence class CredentialsProvider: - def add_headers(self, request_headers): + def add_headers(self, request_headers: TypedDict): pass # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class AccessTokenAuthProvider(CredentialsProvider): - def __init__(self, access_token): + def __init__(self, access_token: str): self.__authorization_header_value = "Bearer {}".format(access_token) - def add_headers(self, request_headers): + def add_headers(self, request_headers: TypedDict): request_headers['Authorization'] = self.__authorization_header_value # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class BasicAuthProvider(CredentialsProvider): - def __init__(self, username, password): + def __init__(self, username: str, password: str): auth_credentials = f"{username}:{password}".encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") self.__authorization_header_value = f"Basic {auth_credentials_base64}" - def add_headers(self, request_headers): + def add_headers(self, request_headers: TypedDict): request_headers['Authorization'] = self.__authorization_header_value @@ -63,7 +64,7 @@ def add_headers(self, request_headers): class DatabricksOAuthProvider(CredentialsProvider): SCOPE_DELIM = ' ' - def __init__(self, hostname, oauth_persistence, client_id, scopes): + def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, client_id: str, scopes: List[str]): try: self._hostname = self._normalize_host_name(hostname=hostname) self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) @@ -76,12 +77,12 @@ def __init__(self, hostname, oauth_persistence, client_id, scopes): logging.error(f"unexpected error", e, exc_info=True) raise e - def add_headers(self, request_headers): + def add_headers(self, request_headers: TypedDict): self._update_token_if_expired() request_headers['Authorization'] = f"Bearer {self._access_token}" @staticmethod - def _normalize_host_name(hostname): + def _normalize_host_name(hostname: str): maybe_scheme = "https://" if not hostname.startswith("https://") else "" maybe_trailing_slash = "/" if not hostname.endswith("/") else "" return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index daf0fec49..acf1e2615 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -43,4 +43,4 @@ def flush(self): self.__auth_provider.add_headers(headers) self._headers = headers self.setCustomHeaders(self._headers) - super().flush() \ No newline at end of file + super().flush() \ No newline at end of file From 2a697153151a79cd7f56fae325a468aaf6aeed60 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 25 Aug 2022 13:53:57 -0700 Subject: [PATCH 38/57] addressed code review comment (use python3 api) Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/oauth.py | 25 +++---------------- src/databricks/sql/auth/thrift_http_client.py | 2 +- 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 63cea7a69..bdc78b43d 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -26,7 +26,8 @@ import os import webbrowser import json -from datetime import datetime, timedelta, tzinfo +import secrets +from datetime import datetime, timedelta, tzinfo, timezone import logging @@ -45,32 +46,14 @@ logger = logging.getLogger(__name__) -# This could use 'import secrets' in Python 3 def token_urlsafe(nbytes=32): - tok = os.urandom(nbytes) - return base64.urlsafe_b64encode(tok).rstrip(b"=").decode("ascii") - - -# This could be datetime.timezone.utc in Python 3 -class UTCTimeZone(tzinfo): - """UTC""" - def utcoffset(self, dt): - #pylint: disable=unused-argument - return timedelta(0) - - def tzname(self, dt): - #pylint: disable=unused-argument - return "UTC" - - def dst(self, dt): - #pylint: disable=unused-argument - return timedelta(0) + return secrets.token_urlsafe(nbytes) # Some constant values OIDC_REDIRECTOR_PATH = "oidc" REDIRECT_PORT = 8020 -UTC = UTCTimeZone() +UTC = timezone.utc def get_redirect_url(port=REDIRECT_PORT): diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index acf1e2615..daf0fec49 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -43,4 +43,4 @@ def flush(self): self.__auth_provider.add_headers(headers) self._headers = headers self.setCustomHeaders(self._headers) - super().flush() \ No newline at end of file + super().flush() \ No newline at end of file From beb1f64d0e0964d11cd7fc6fa4527b110955a3de Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Thu, 25 Aug 2022 16:08:57 -0700 Subject: [PATCH 39/57] made http request handler an independent class, removed global arg, responded to a code review comment Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/oauth.py | 52 +++---------------- src/databricks/sql/auth/oauth_http_handler.py | 42 +++++++++++++++ 2 files changed, 48 insertions(+), 46 deletions(-) create mode 100644 src/databricks/sql/auth/oauth_http_handler.py diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index bdc78b43d..068062336 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -37,6 +37,7 @@ import requests from requests.exceptions import RequestException +from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler try: from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer @@ -98,49 +99,6 @@ def get_challenge(verifier_string=token_urlsafe(32)): return verifier_string, challenge_string -# This is a janky global that is used to store the path of the single request the HTTP server -# will receive. -global_request_path = None - - -def set_request_path(path): - global global_request_path - global_request_path = path - - -class SingleRequestHandler(BaseHTTPRequestHandler): - RESPONSE_BODY = """ - - Close this Tab - - - -

Please close this tab.

-

- The Databricks Python Sql Connector received a response. You may close this tab. -

- -""".encode("utf-8") - - def do_GET(self): # nopep8 - self.send_response(200, "Success") - self.send_header("Content-type", "text/html") - self.end_headers() - self.wfile.write(self.RESPONSE_BODY) - set_request_path(self.path) - - def log_message(self, format, *args): - #pylint: disable=redefined-builtin - #pylint: disable=unused-argument - return - - def get_authorization_code(client, auth_url, redirect_url, scope, state, challenge, port): (auth_req_uri, _, _) = client.prepare_authorization_request( authorization_url=auth_url, @@ -151,18 +109,20 @@ def get_authorization_code(client, auth_url, redirect_url, scope, state, challen code_challenge_method="S256") logger.info(f"Opening {auth_req_uri}") - with HTTPServer(("", port), SingleRequestHandler) as httpd: + handler = OAuthHttpSingleRequestHandler("Databricks Sql Connector") + + with HTTPServer(("", port), handler) as httpd: webbrowser.open_new(auth_req_uri) logger.info(f"Listening for OAuth authorization callback at {redirect_url}") httpd.handle_request() - if not global_request_path: + if not handler.request_path: msg = f"No path parameters were returned to the callback at {redirect_url}" logger.error(msg) raise RuntimeError(msg) # This is a kludge because the parsing library expects https callbacks # We should probably set it up using https - full_redirect_url = f"https://localhost:{port}/{global_request_path}" + full_redirect_url = f"https://localhost:{port}/{handler.request_path}" try: authorization_code_response = \ client.parse_request_uri_response(full_redirect_url, state=state) diff --git a/src/databricks/sql/auth/oauth_http_handler.py b/src/databricks/sql/auth/oauth_http_handler.py new file mode 100644 index 000000000..a3f6cc161 --- /dev/null +++ b/src/databricks/sql/auth/oauth_http_handler.py @@ -0,0 +1,42 @@ +from http.server import BaseHTTPRequestHandler + + +class OAuthHttpSingleRequestHandler(BaseHTTPRequestHandler): + RESPONSE_BODY_TEMPLATE = """ + + Close this Tab + + + +

Please close this tab.

+

+ The {!!!PLACE_HOLDER!!!} received a response. You may close this tab. +

+ +""" + + def __init__(self, tool_name): + self.response_body = self.RESPONSE_BODY_TEMPLATE.replace("{!!!PLACE_HOLDER!!!}", tool_name).encode("utf-8") + self.request_path = None + + def __call__(self, *args, **kwargs): + """Handle a request.""" + super().__init__(*args, **kwargs) + + def do_GET(self): # nopep8 + self.send_response(200, "Success") + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(self.response_body) + self.request_path = self.path + + def log_message(self, format, *args): + #pylint: disable=redefined-builtin + #pylint: disable=unused-argument + return From 930cea8c74ac914ae4d834c8f75bd9a64b363579 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 26 Aug 2022 13:20:28 -0700 Subject: [PATCH 40/57] added pull request ci trigger Signed-off-by: Moe Derakhshani --- .github/workflows/code-quality-checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index 70e9de70b..b6462322a 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -1,5 +1,5 @@ name: Code Quality Checks -on: [push] +on: [pull_request, push] jobs: run-unit-tests: runs-on: ubuntu-latest From 8af7b5f7969711c1b42b56dade0f94398b374647 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 26 Aug 2022 13:36:08 -0700 Subject: [PATCH 41/57] restructured the code as class Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 7 +- src/databricks/sql/auth/authenticators.py | 24 +- src/databricks/sql/auth/oauth.py | 380 +++++++++++----------- 3 files changed, 212 insertions(+), 199 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 6dad102df..c09f823dc 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -42,6 +42,7 @@ def __init__(self, auth_type: str = None, oauth_scopes: List[str] = None, oauth_client_id: str = None, + oauth_redirect_port_range: List[int] = None, use_cert_as_auth: str = None, tls_client_cert_file: str = None, oauth_persistence=None @@ -53,6 +54,7 @@ def __init__(self, self.auth_type = auth_type self.oauth_scopes = oauth_scopes self.oauth_client_id = oauth_client_id + self.oauth_redirect_port_range = oauth_redirect_port_range self.use_cert_as_auth = use_cert_as_auth self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence @@ -60,7 +62,7 @@ def __init__(self, def get_auth_provider(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: - return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_persistence, cfg.oauth_client_id, cfg.oauth_scopes) + return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) elif cfg.username is not None and cfg.password is not None: @@ -75,7 +77,7 @@ def get_auth_provider(cfg: ClientContext): OAUTH_SCOPES = ["sql", "offline_access"] # TODO: moderakh to be changed once registered on the service side OAUTH_CLIENT_ID = "databricks-cli" - +OAUTH_REDIRECT_PORT_RANGE = range(8020, 8025) def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs): cfg = ClientContext(hostname=hostname, @@ -87,6 +89,7 @@ def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAu tls_client_cert_file=kwargs.get("_tls_client_cert_file"), oauth_scopes=OAUTH_SCOPES, oauth_client_id=OAUTH_CLIENT_ID, + oauth_redirect_port_range=OAUTH_REDIRECT_PORT_RANGE, oauth_persistence=oauth_persistence) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index d3c839d84..bba1f7914 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -20,9 +20,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TypedDict, List +from typing import Dict, List -from databricks.sql.auth.oauth import get_tokens, check_and_refresh_access_token +from databricks.sql.auth.oauth import OAuthManager import base64 @@ -32,7 +32,7 @@ class CredentialsProvider: - def add_headers(self, request_headers: TypedDict): + def add_headers(self, request_headers: Dict[str, str]): pass @@ -42,7 +42,7 @@ class AccessTokenAuthProvider(CredentialsProvider): def __init__(self, access_token: str): self.__authorization_header_value = "Bearer {}".format(access_token) - def add_headers(self, request_headers: TypedDict): + def add_headers(self, request_headers: Dict[str, str]): request_headers['Authorization'] = self.__authorization_header_value @@ -55,7 +55,7 @@ def __init__(self, username: str, password: str): self.__authorization_header_value = f"Basic {auth_credentials_base64}" - def add_headers(self, request_headers: TypedDict): + def add_headers(self, request_headers: Dict[str, str]): request_headers['Authorization'] = self.__authorization_header_value @@ -64,8 +64,9 @@ def add_headers(self, request_headers: TypedDict): class DatabricksOAuthProvider(CredentialsProvider): SCOPE_DELIM = ' ' - def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, client_id: str, scopes: List[str]): + def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, redirect_port_range: List[int], client_id: str, scopes: List[str]): try: + self.oauth_manager = OAuthManager(port_range=redirect_port_range, client_id=client_id) self._hostname = self._normalize_host_name(hostname=hostname) self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) self._oauth_persistence = oauth_persistence @@ -77,7 +78,7 @@ def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, client_id logging.error(f"unexpected error", e, exc_info=True) raise e - def add_headers(self, request_headers: TypedDict): + def add_headers(self, request_headers: Dict[str, str]): self._update_token_if_expired() request_headers['Authorization'] = f"Bearer {self._access_token}" @@ -99,9 +100,9 @@ def _initial_get_token(self): if self._access_token and self._refresh_token: self._update_token_if_expired() else: - (access_token, refresh_token) = get_tokens(hostname=self._hostname, - client_id=self._client_id, - scope=self._scopes_as_str) + (access_token, refresh_token) = self.oauth_manager.get_tokens( + hostname=self._hostname, + scope=self._scopes_as_str) self._access_token = access_token self._refresh_token = refresh_token self._oauth_persistence.persist(OAuthToken(access_token, refresh_token)) @@ -111,9 +112,8 @@ def _initial_get_token(self): def _update_token_if_expired(self): try: - (fresh_access_token, fresh_refresh_token, is_refreshed) = check_and_refresh_access_token( + (fresh_access_token, fresh_refresh_token, is_refreshed) = self.oauth_manager.check_and_refresh_access_token( hostname=self._hostname, - client_id=self._client_id, access_token=self._access_token, refresh_token=self._refresh_token) if not is_refreshed: diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 068062336..6dc630454 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -23,206 +23,216 @@ import base64 import hashlib -import os import webbrowser import json import secrets -from datetime import datetime, timedelta, tzinfo, timezone - +from datetime import datetime, timezone +from http.server import HTTPServer import logging import oauthlib.oauth2 from oauthlib.oauth2.rfc6749.errors import OAuth2Error +from typing import Dict, List import requests from requests.exceptions import RequestException from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler -try: - from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer -except ImportError: - from http.server import BaseHTTPRequestHandler, HTTPServer - logger = logging.getLogger(__name__) -def token_urlsafe(nbytes=32): - return secrets.token_urlsafe(nbytes) - - # Some constant values -OIDC_REDIRECTOR_PATH = "oidc" -REDIRECT_PORT = 8020 -UTC = timezone.utc - - -def get_redirect_url(port=REDIRECT_PORT): - return f"http://localhost:{port}" - - -def fetch_well_known_config(idp_url): - known_config_url = f"{idp_url}/.well-known/oauth-authorization-server" - try: - response = requests.get(url=known_config_url) - except RequestException as e: - logger.error(f"Unable to fetch OAuth configuration from {idp_url}.\n" - "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.") - raise e - - if response.status_code != 200: - msg = (f"Received status {response.status_code} OAuth configuration from " - f"{idp_url}.\n Verify it is a valid workspace URL and " - "that OAuth is enabled on this account." - ) - logger.error(msg) - raise RuntimeError(msg) - try: + +class OAuthManager: + OIDC_REDIRECTOR_PATH = "oidc" + + def __init__(self, port_range: List[int], client_id: str): + self.port_range = port_range + self.client_id = client_id + self.redirect_port = None + + @staticmethod + def __token_urlsafe(nbytes=32): + return secrets.token_urlsafe(nbytes) + + @staticmethod + def __get_redirect_url(redirect_port: int): + return f"http://localhost:{redirect_port}" + + @staticmethod + def __fetch_well_known_config(idp_url: str): + known_config_url = f"{idp_url}/.well-known/oauth-authorization-server" + try: + response = requests.get(url=known_config_url) + except RequestException as e: + logger.error(f"Unable to fetch OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account.") + raise e + + if response.status_code != 200: + msg = (f"Received status {response.status_code} OAuth configuration from " + f"{idp_url}.\n Verify it is a valid workspace URL and " + "that OAuth is enabled on this account." + ) + logger.error(msg) + raise RuntimeError(msg) + try: + return response.json() + except requests.exceptions.JSONDecodeError as e: + logger.error(f"Unable to decode OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account.") + raise e + + @staticmethod + def __get_idp_url(host: str): + maybe_scheme = "https://" if not host.startswith("https://") else "" + maybe_trailing_slash = "/" if not host.endswith("/") else "" + return f"{maybe_scheme}{host}{maybe_trailing_slash}{OAuthManager.OIDC_REDIRECTOR_PATH}" + + @staticmethod + def __get_challenge(): + verifier_string = OAuthManager.__token_urlsafe(32) + digest = hashlib.sha256(verifier_string.encode("UTF-8")).digest() + challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") + return verifier_string, challenge_string + + def __get_authorization_code(self, client, auth_url, scope, state, challenge): + handler = OAuthHttpSingleRequestHandler("Databricks Sql Connector") + + last_error = None + for port in self.port_range: + try: + with HTTPServer(("", port), handler) as httpd: + redirect_url = OAuthManager.__get_redirect_url(port) + (auth_req_uri, _, _) = client.prepare_authorization_request( + authorization_url=auth_url, + redirect_url=redirect_url, + scope=scope, + state=state, + code_challenge=challenge, + code_challenge_method="S256") + logger.info(f"Opening {auth_req_uri}") + + webbrowser.open_new(auth_req_uri) + logger.info(f"Listening for OAuth authorization callback at {redirect_url}") + httpd.handle_request() + self.redirect_port = port + break + except OSError as e: + if e.errno == 48: + logger.info(f"Port {port} is in use") + last_error = e + except Exception as e: + logger.error("unexpected error", e) + if self.redirect_port is None: + logger.error(f"Tried all the ports {self.port_range} for oauth redirect, but can't find free port") + raise last_error + + if not handler.request_path: + msg = f"No path parameters were returned to the callback at {redirect_url}" + logger.error(msg) + raise RuntimeError(msg) + # This is a kludge because the parsing library expects https callbacks + # We should probably set it up using https + full_redirect_url = f"https://localhost:{self.redirect_port}/{handler.request_path}" + try: + authorization_code_response = \ + client.parse_request_uri_response(full_redirect_url, state=state) + except OAuth2Error as e: + logger.error(f"OAuth Token Request error {e.description}") + raise e + return authorization_code_response + + def __send_auth_code_token_request(self, client, token_request_url, redirect_url, code, verifier): + token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url) + data = f"{token_request_body}&code_verifier={verifier}" + return self.__send_token_request(token_request_url, data) + + @staticmethod + def __send_token_request(token_request_url, data): + headers = { + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded" + } + response = requests.post(url=token_request_url, data=data, headers=headers) return response.json() - except requests.exceptions.JSONDecodeError as e: - logger.error(f"Unable to decode OAuth configuration from {idp_url}.\n" - "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.") - raise e - - -def get_idp_url(host): - maybe_scheme = "https://" if not host.startswith("https://") else "" - maybe_trailing_slash = "/" if not host.endswith("/") else "" - return f"{maybe_scheme}{host}{maybe_trailing_slash}{OIDC_REDIRECTOR_PATH}" - - -def get_challenge(verifier_string=token_urlsafe(32)): - digest = hashlib.sha256(verifier_string.encode("UTF-8")).digest() - challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") - return verifier_string, challenge_string - - -def get_authorization_code(client, auth_url, redirect_url, scope, state, challenge, port): - (auth_req_uri, _, _) = client.prepare_authorization_request( - authorization_url=auth_url, - redirect_url=redirect_url, - scope=scope, - state=state, - code_challenge=challenge, - code_challenge_method="S256") - logger.info(f"Opening {auth_req_uri}") - - handler = OAuthHttpSingleRequestHandler("Databricks Sql Connector") - - with HTTPServer(("", port), handler) as httpd: - webbrowser.open_new(auth_req_uri) - logger.info(f"Listening for OAuth authorization callback at {redirect_url}") - httpd.handle_request() - - if not handler.request_path: - msg = f"No path parameters were returned to the callback at {redirect_url}" - logger.error(msg) - raise RuntimeError(msg) - # This is a kludge because the parsing library expects https callbacks - # We should probably set it up using https - full_redirect_url = f"https://localhost:{port}/{handler.request_path}" - try: - authorization_code_response = \ - client.parse_request_uri_response(full_redirect_url, state=state) - except OAuth2Error as e: - logger.error(f"OAuth Token Request error {e.description}") - raise e - return authorization_code_response - - -def send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier): - token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url) - data = f"{token_request_body}&code_verifier={verifier}" - return send_token_request(token_request_url, data) - - -def send_token_request(token_request_url, data): - headers = { - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded" - } - response = requests.post(url=token_request_url, data=data, headers=headers) - return response.json() - - -def send_refresh_token_request(hostname, client_id, refresh_token): - idp_url = get_idp_url(hostname) - oauth_config = fetch_well_known_config(idp_url) - token_request_url = oauth_config["token_endpoint"] - client = oauthlib.oauth2.WebApplicationClient(client_id) - token_request_body = client.prepare_refresh_body( - refresh_token=refresh_token, client_id=client.client_id) - return send_token_request(token_request_url, token_request_body) - - -def get_tokens_from_response(oauth_response): - access_token = oauth_response["access_token"] - refresh_token = oauth_response["refresh_token"] if "refresh_token" in oauth_response else None - return access_token, refresh_token - - -def check_and_refresh_access_token(hostname, client_id, access_token, refresh_token): - now = datetime.now(tz=UTC) - # If we can't decode an expiration time, this will be expired by default. - expiration_time = now - try: - # This token has already been verified and we are just parsing it. - # If it has been tampered with, it will be rejected on the server side. - # This avoids having to fetch the public key from the issuer and perform - # an unnecessary signature verification. - access_token_payload = access_token.split(".")[1] - # add padding - access_token_payload = access_token_payload + '=' * (-len(access_token_payload) % 4) - decoded = json.loads(base64.standard_b64decode(access_token_payload)) - expiration_time = datetime.fromtimestamp(decoded["exp"], tz=UTC) - except Exception as e: - logger.error(e) - raise e - - if expiration_time > now: - # The access token is fine. Just return it. - return access_token, refresh_token, False - - if not refresh_token: - msg = f"OAuth access token expired on {expiration_time}." - logger.error(msg) - raise RuntimeError(msg) - - # Try to refresh using the refresh token - logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}") - oauth_response = send_refresh_token_request(hostname, client_id, refresh_token) - fresh_access_token, fresh_refresh_token = get_tokens_from_response(oauth_response) - return fresh_access_token, fresh_refresh_token, True - - -def get_tokens(hostname, client_id, scope=None): - idp_url = get_idp_url(hostname) - oauth_config = fetch_well_known_config(idp_url) - # We are going to override oauth_config["authorization_endpoint"] use the - # /oidc redirector on the hostname, which may inject additional parameters. - auth_url = f"{hostname}oidc/v1/authorize" - state = token_urlsafe(16) - (verifier, challenge) = get_challenge() - client = oauthlib.oauth2.WebApplicationClient(client_id) - redirect_url = get_redirect_url() - try: - auth_response = get_authorization_code( - client, - auth_url, - redirect_url, - scope, - state, - challenge, - REDIRECT_PORT) - except OAuth2Error as e: - msg = f"OAuth Authorization Error: {e.description}" - logger.error(msg) - raise e - - token_request_url = oauth_config["token_endpoint"] - code = auth_response["code"] - oauth_response = \ - send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier) - return get_tokens_from_response(oauth_response) + + def __send_refresh_token_request(self, hostname, refresh_token): + idp_url = OAuthManager.__get_idp_url(hostname) + oauth_config = OAuthManager.__fetch_well_known_config(idp_url) + token_request_url = oauth_config["token_endpoint"] + client = oauthlib.oauth2.WebApplicationClient(self.client_id) + token_request_body = client.prepare_refresh_body( + refresh_token=refresh_token, client_id=client.client_id) + return OAuthManager.__send_token_request(token_request_url, token_request_body) + + @staticmethod + def __get_tokens_from_response(oauth_response): + access_token = oauth_response["access_token"] + refresh_token = oauth_response["refresh_token"] if "refresh_token" in oauth_response else None + return access_token, refresh_token + + def check_and_refresh_access_token(self, hostname: str, access_token: str, refresh_token: str): + now = datetime.now(tz=timezone.utc) + # If we can't decode an expiration time, this will be expired by default. + expiration_time = now + try: + # This token has already been verified and we are just parsing it. + # If it has been tampered with, it will be rejected on the server side. + # This avoids having to fetch the public key from the issuer and perform + # an unnecessary signature verification. + access_token_payload = access_token.split(".")[1] + # add padding + access_token_payload = access_token_payload + '=' * (-len(access_token_payload) % 4) + decoded = json.loads(base64.standard_b64decode(access_token_payload)) + expiration_time = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc) + except Exception as e: + logger.error(e) + raise e + + if expiration_time > now: + # The access token is fine. Just return it. + return access_token, refresh_token, False + + if not refresh_token: + msg = f"OAuth access token expired on {expiration_time}." + logger.error(msg) + raise RuntimeError(msg) + + # Try to refresh using the refresh token + logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}") + oauth_response = self.__send_refresh_token_request(hostname, refresh_token) + fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response(oauth_response) + return fresh_access_token, fresh_refresh_token, True + + def get_tokens(self, hostname: str, scope=None): + idp_url = self.__get_idp_url(hostname) + oauth_config = self.__fetch_well_known_config(idp_url) + # We are going to override oauth_config["authorization_endpoint"] use the + # /oidc redirector on the hostname, which may inject additional parameters. + auth_url = f"{hostname}oidc/v1/authorize" + state = OAuthManager.__token_urlsafe(16) + (verifier, challenge) = OAuthManager.__get_challenge() + client = oauthlib.oauth2.WebApplicationClient(self.client_id) + try: + auth_response = self.__get_authorization_code( + client, + auth_url, + scope, + state, + challenge) + except OAuth2Error as e: + msg = f"OAuth Authorization Error: {e.description}" + logger.error(msg) + raise e + + redirect_url = OAuthManager.__get_redirect_url(self.redirect_port) + + token_request_url = oauth_config["token_endpoint"] + code = auth_response["code"] + oauth_response = \ + self.__send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier) + return self.__get_tokens_from_response(oauth_response) From dac3c1ef70d6e47b851b57dfd9c355d112b1b8e9 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Fri, 26 Aug 2022 16:17:03 -0700 Subject: [PATCH 42/57] fixed test_thrift_backend.py tests Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 2 +- src/databricks/sql/client.py | 2 +- src/databricks/sql/thrift_backend.py | 7 +- tests/unit/test_thrift_backend.py | 114 ++++++++++++++------------- 4 files changed, 65 insertions(+), 60 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index c09f823dc..aef434de5 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -76,7 +76,7 @@ def get_auth_provider(cfg: ClientContext): OAUTH_SCOPES = ["sql", "offline_access"] # TODO: moderakh to be changed once registered on the service side -OAUTH_CLIENT_ID = "databricks-cli" +OAUTH_CLIENT_ID = "databricks-sql-python" OAUTH_REDIRECT_PORT_RANGE = range(8020, 8025) def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs): diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index ef551daec..0164e1869 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -101,8 +101,8 @@ def __init__( self.host, self.port, http_path, - auth_provider, (http_headers or []) + base_headers, + auth_provider, **kwargs ) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index a2434582c..e077518b7 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -12,13 +12,14 @@ import thrift.transport.TSocket import thrift.transport.TTransport -from databricks.sql.auth.thrift_http_client import THttpClient +import databricks.sql.auth.thrift_http_client from databricks.sql.auth.authenticators import CredentialsProvider from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes from databricks.sql import * from databricks.sql.thrift_api.TCLIService.TCLIService import ( Client as TCLIServiceClient, ) + from databricks.sql.utils import ( ArrowQueue, ExecuteResponse, @@ -53,7 +54,7 @@ class ThriftBackend: BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( - self, server_hostname: str, port, http_path: str, auth_provider: CredentialsProvider, http_headers, **kwargs + self, server_hostname: str, port, http_path: str, http_headers, auth_provider: CredentialsProvider, **kwargs ): # Internal arguments in **kwargs: # _user_agent_entry @@ -135,7 +136,7 @@ def __init__( self._auth_provider = auth_provider - self._transport = THttpClient( + self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri, ssl_context=ssl_context, diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index e8c5a727f..1120c0525 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -10,6 +10,7 @@ import databricks.sql from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * +from databricks.sql.auth.authenticators import CredentialsProvider from databricks.sql.thrift_backend import ThriftBackend @@ -59,7 +60,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend("foo", 123, "bar", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -67,7 +68,7 @@ def _make_type_desc(self, type): return ttypes.TTypeDesc(types=[ttypes.TTypeEntry(ttypes.TPrimitiveTypeEntry(type=type))]) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -137,12 +138,12 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend("foo", 123, "bar", [("header", "value")]) + ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=CredentialsProvider()) t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend.create_default_context") def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): mock_cert_key_file = Mock() @@ -154,6 +155,7 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ "foo", 123, "bar", [], + auth_provider=CredentialsProvider(), _tls_client_cert_file=mock_cert_file, _tls_client_cert_key_file=mock_cert_key_file, _tls_client_cert_key_password=mock_cert_key_password, @@ -167,40 +169,40 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend.create_default_context") def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): - ThriftBackend("foo", 123, "bar", [], _tls_no_verify=True) + ThriftBackend("foo", 123, "bar", [], auth_provider=CredentialsProvider(), _tls_no_verify=True) mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend.create_default_context") def test_tls_verify_hostname_is_respected(self, mock_create_default_context, t_http_client_class): - ThriftBackend("foo", 123, "bar", [], _tls_verify_hostname=False) + ThriftBackend("foo", 123, "bar", [], auth_provider=CredentialsProvider(), _tls_verify_hostname=False) mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", []) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider()) self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value") - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], _socket_timeout=129) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider(), _socket_timeout=129) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) - ThriftBackend("hostname", 123, "path_value", [], _socket_timeout=0) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider(), _socket_timeout=0) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend("hostname", 123, "path_value", [], _socket_timeout=None) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider(), _socket_timeout=None) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) def test_non_primitive_types_raise_error(self): @@ -268,7 +270,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): def test_make_request_checks_status_code(self): error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] - thrift_backend = ThriftBackend("foo", 123, "bar", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) for code in error_codes: mock_error_response = Mock() @@ -301,7 +303,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): resultSetMetadata=None, resultSet=None, closeOperation=None)) - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -329,7 +331,7 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv operationHandle=self.operation_handle) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -354,7 +356,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock()) @@ -384,7 +386,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock()) @@ -427,7 +429,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -463,7 +465,7 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se tcli_service_instance.GetOperationStatus.side_effect = [ op_state_1, op_state_2, op_state_3 ] - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) results_message_response = thrift_backend._handle_execute_response( execute_resp, Mock()) self.assertEqual(results_message_response.status, @@ -488,7 +490,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): directResults=direct_results_message, operationHandle=self.operation_handle) - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -642,7 +644,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): pyarrow.field("column3", pyarrow.binary()) ]).serialize().to_pybytes() - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) arrow_queue, has_more_results = thrift_backend.fetch_results( op_handle=Mock(), max_rows=1, @@ -658,7 +660,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -676,7 +678,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -693,7 +695,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -718,7 +720,7 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -747,7 +749,7 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -776,14 +778,14 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend.close_command(self.operation_handle) self.assertEqual(tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, self.operation_handle) @@ -791,7 +793,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) thrift_backend.close_session(self.session_handle) self.assertEqual(tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle) @@ -826,7 +828,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, None, Mock()) @@ -834,7 +836,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): @patch.object(ThriftBackend, "_convert_column_based_set_to_arrow_table") def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mock, convert_arrow_mock): - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -974,7 +976,7 @@ def test_handle_execute_response_sets_active_op_handle(self): @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class): - + import thrift, errno from databricks.sql.thrift_api.TCLIService.TCLIService import Client from databricks.sql.exc import RequestError @@ -991,13 +993,14 @@ def test_make_request_will_retry_GetOperationStatus( operationHandle=self.operation_handle, getProgressUpdate=False, ) - + EXPECTED_RETRIES = 2 thrift_backend = ThriftBackend( "foobar", 443, "path", [], + auth_provider=CredentialsProvider(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1) @@ -1009,7 +1012,7 @@ def test_make_request_will_retry_GetOperationStatus( self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) # Unusual OSError code - mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") + mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: with self.assertRaises(RequestError): @@ -1017,14 +1020,14 @@ def test_make_request_will_retry_GetOperationStatus( # There should be two warning log messages: one for each retry self.assertEqual(len(cm.output), EXPECTED_RETRIES) - + # The warnings should be identical self.assertEqual(cm.output[1], cm.output[0]) - + # The warnings should include this text self.assertIn(f"{this_gos_name} failed with code {errno.EEXIST} and will attempt to retry", cm.output[0]) - - + + @patch("thrift.transport.THttpClient.THttpClient") def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class): t_transport_instance = t_transport_class.return_value @@ -1034,7 +1037,7 @@ def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class) mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1050,14 +1053,14 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) self.assertIn("This method fails", str(cm.exception.message_with_context())) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class): @@ -1072,6 +1075,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( "foobar", 443, "path", [], + auth_provider=CredentialsProvider(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, _retry_delay_min=0) @@ -1084,14 +1088,14 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self.assertEqual(mock_method.call_count, 14) - @patch("thrift.transport.THttpClient.THttpClient") + @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_make_request_will_read_error_message_headers_if_set(self, t_transport_class): t_transport_instance = t_transport_class.return_value mock_method = Mock() mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", []) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) error_headers = [[("x-thriftserver-error-message", "thrift server error message")], [("x-databricks-error-or-redirect-message", "databricks error message")], @@ -1173,7 +1177,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_count": 1, "_retry_stop_after_attempts_duration": 100 } - backend = ThriftBackend("foobar", 443, "path", [], **retry_delay_args) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider(), **retry_delay_args) for (arg, val) in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @@ -1188,7 +1192,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend("foobar", 443, "path", [], **retry_delay_args) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider(), **retry_delay_args) retry_delay_expected_vals = { k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() @@ -1208,7 +1212,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42" } - backend = ThriftBackend("foobar", 443, "path", []) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) backend.open_session(mock_config, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1219,7 +1223,7 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend("foobar", 443, "path", []) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(databricks.sql.Error) as cm: backend.open_session(mock_config, None, None) @@ -1237,7 +1241,7 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", []) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] for cat, schem in initial_cat_schem_args: @@ -1256,7 +1260,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend("foobar", 443, "path", []) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) backend.open_session({}, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1266,7 +1270,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", []) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False # is fine @@ -1300,7 +1304,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem") ) - backend = ThriftBackend("foobar", 443, "path", []) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") @@ -1324,7 +1328,7 @@ def test_execute_command_sets_complex_type_fields_correctly(self, mock_handle_ex if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend("foobar", 443, "path", [], **complex_arg_types) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider(), **complex_arg_types) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] From 0fc1684aeaac6a02c6f5778dcb07319c57023b22 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 29 Aug 2022 12:50:32 -0700 Subject: [PATCH 43/57] cleanup Signed-off-by: Moe Derakhshani --- poetry.lock | 237 +++++++++++++++++++++++++++++-- pyproject.toml | 1 + src/databricks/sql/auth/auth.py | 14 +- src/databricks/sql/auth/oauth.py | 2 - 4 files changed, 231 insertions(+), 23 deletions(-) diff --git a/poetry.lock b/poetry.lock index 49c722166..391bbcb57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,3 +1,17 @@ +[[package]] +name = "astroid" +version = "2.11.7" +description = "An abstract syntax tree for Python with inference support." +category = "dev" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +lazy-object-proxy = ">=1.4.0" +typed-ast = {version = ">=1.4.0,<2.0", markers = "implementation_name == \"cpython\" and python_version < \"3.8\""} +typing-extensions = {version = ">=3.10", markers = "python_version < \"3.10\""} +wrapt = ">=1.11,<2" + [[package]] name = "atomicwrites" version = "1.4.1" @@ -15,10 +29,10 @@ optional = false python-versions = ">=3.5" [package.extras] -dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit", "cloudpickle"] -docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] -tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] -tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "mypy (>=0.900,!=0.940)", "pytest-mypy-plugins", "cloudpickle"] +dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"] +docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] +tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] +tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] [[package]] name = "black" @@ -82,6 +96,17 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "dill" +version = "0.3.5.1" +description = "serialize all of python" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + [[package]] name = "idna" version = "3.3" @@ -103,9 +128,9 @@ typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} zipp = ">=0.5" [package.extras] -docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"] +docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"] perf = ["ipython"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] [[package]] name = "iniconfig" @@ -115,6 +140,36 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "isort" +version = "5.10.1" +description = "A Python utility / library to sort Python imports." +category = "dev" +optional = false +python-versions = ">=3.6.1,<4.0" + +[package.extras] +colors = ["colorama (>=0.4.3,<0.5.0)"] +pipfile_deprecated_finder = ["pipreqs", "requirementslib"] +plugins = ["setuptools"] +requirements_deprecated_finder = ["pip-api", "pipreqs"] + +[[package]] +name = "lazy-object-proxy" +version = "1.7.1" +description = "A fast and thorough lazy object proxy." +category = "dev" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +category = "dev" +optional = false +python-versions = ">=3.6" + [[package]] name = "mypy" version = "0.950" @@ -212,8 +267,8 @@ optional = false python-versions = ">=3.7" [package.extras] -docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)", "sphinx (>=4)"] -test = ["appdirs (==1.4.4)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)", "pytest (>=6)"] +docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"] +test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] [[package]] name = "pluggy" @@ -249,6 +304,27 @@ python-versions = ">=3.7" [package.dependencies] numpy = ">=1.16.6" +[[package]] +name = "pylint" +version = "2.13.9" +description = "python code static checker" +category = "dev" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +astroid = ">=2.11.5,<=2.12.0-dev0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +dill = ">=0.2" +isort = ">=4.2.5,<6" +mccabe = ">=0.6,<0.8" +platformdirs = ">=2.2.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +testutil = ["gitpython (>3)"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -258,7 +334,7 @@ optional = false python-versions = ">=3.6.8" [package.extras] -diagrams = ["railroad-diagrams", "jinja2"] +diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" @@ -376,10 +452,18 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, <4" [package.extras] -brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] -secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "urllib3-secure-extra", "ipaddress"] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] +secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "wrapt" +version = "1.14.1" +description = "Module for decorators, wrappers and monkey patching." +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + [[package]] name = "zipp" version = "3.8.1" @@ -389,15 +473,19 @@ optional = false python-versions = ">=3.7" [package.extras] -docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "jaraco.tidelift (>=1.4)"] -testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.3)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"] +docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"] +testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "8ac5d8721267bc81ca501a41697d1b3883f334ac69cb09b297bbe33215c4e204" +content-hash = "f283eca35466a0294e09deb8535da2633219db696ad8bbc74dffd4592b0d66ad" [metadata.files] +astroid = [ + {file = "astroid-2.11.7-py3-none-any.whl", hash = "sha256:86b0a340a512c65abf4368b80252754cda17c02cdbbd3f587dddf98112233e7b"}, + {file = "astroid-2.11.7.tar.gz", hash = "sha256:bb24615c77f4837c707669d16907331374ae8a964650a66999da3f5ca68dc946"}, +] atomicwrites = [] attrs = [ {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, @@ -414,12 +502,63 @@ charset-normalizer = [ ] click = [] colorama = [] +dill = [ + {file = "dill-0.3.5.1-py2.py3-none-any.whl", hash = "sha256:33501d03270bbe410c72639b350e941882a8b0fd55357580fbc873fba0c59302"}, + {file = "dill-0.3.5.1.tar.gz", hash = "sha256:d75e41f3eff1eee599d738e76ba8f4ad98ea229db8b085318aa2b3333a208c86"}, +] idna = [ {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, ] importlib-metadata = [] iniconfig = [] +isort = [ + {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"}, + {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"}, +] +lazy-object-proxy = [ + {file = "lazy-object-proxy-1.7.1.tar.gz", hash = "sha256:d609c75b986def706743cdebe5e47553f4a5a1da9c5ff66d76013ef396b5a8a4"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bb8c5fd1684d60a9902c60ebe276da1f2281a318ca16c1d0a96db28f62e9166b"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a57d51ed2997e97f3b8e3500c984db50a554bb5db56c50b5dab1b41339b37e36"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd45683c3caddf83abbb1249b653a266e7069a09f486daa8863fb0e7496a9fdb"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8561da8b3dd22d696244d6d0d5330618c993a215070f473b699e00cf1f3f6443"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fccdf7c2c5821a8cbd0a9440a456f5050492f2270bd54e94360cac663398739b"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-win32.whl", hash = "sha256:898322f8d078f2654d275124a8dd19b079080ae977033b713f677afcfc88e2b9"}, + {file = "lazy_object_proxy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:85b232e791f2229a4f55840ed54706110c80c0a210d076eee093f2b2e33e1bfd"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:46ff647e76f106bb444b4533bb4153c7370cdf52efc62ccfc1a28bdb3cc95442"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12f3bb77efe1367b2515f8cb4790a11cffae889148ad33adad07b9b55e0ab22c"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c19814163728941bb871240d45c4c30d33b8a2e85972c44d4e63dd7107faba44"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:e40f2013d96d30217a51eeb1db28c9ac41e9d0ee915ef9d00da639c5b63f01a1"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:2052837718516a94940867e16b1bb10edb069ab475c3ad84fd1e1a6dd2c0fcfc"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-win32.whl", hash = "sha256:6a24357267aa976abab660b1d47a34aaf07259a0c3859a34e536f1ee6e76b5bb"}, + {file = "lazy_object_proxy-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:6aff3fe5de0831867092e017cf67e2750c6a1c7d88d84d2481bd84a2e019ec35"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6a6e94c7b02641d1311228a102607ecd576f70734dc3d5e22610111aeacba8a0"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4ce15276a1a14549d7e81c243b887293904ad2d94ad767f42df91e75fd7b5b6"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e368b7f7eac182a59ff1f81d5f3802161932a41dc1b1cc45c1f757dc876b5d2c"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6ecbb350991d6434e1388bee761ece3260e5228952b1f0c46ffc800eb313ff42"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:553b0f0d8dbf21890dd66edd771f9b1b5f51bd912fa5f26de4449bfc5af5e029"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-win32.whl", hash = "sha256:c7a683c37a8a24f6428c28c561c80d5f4fd316ddcf0c7cab999b15ab3f5c5c69"}, + {file = "lazy_object_proxy-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:df2631f9d67259dc9620d831384ed7732a198eb434eadf69aea95ad18c587a28"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:07fa44286cda977bd4803b656ffc1c9b7e3bc7dff7d34263446aec8f8c96f88a"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4dca6244e4121c74cc20542c2ca39e5c4a5027c81d112bfb893cf0790f96f57e"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91ba172fc5b03978764d1df5144b4ba4ab13290d7bab7a50f12d8117f8630c38"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:043651b6cb706eee4f91854da4a089816a6606c1428fd391573ef8cb642ae4f7"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b9e89b87c707dd769c4ea91f7a31538888aad05c116a59820f28d59b3ebfe25a"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-win32.whl", hash = "sha256:9d166602b525bf54ac994cf833c385bfcc341b364e3ee71e3bf5a1336e677b55"}, + {file = "lazy_object_proxy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:8f3953eb575b45480db6568306893f0bd9d8dfeeebd46812aa09ca9579595148"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dd7ed7429dbb6c494aa9bc4e09d94b778a3579be699f9d67da7e6804c422d3de"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ed0c2b380eb6248abdef3cd425fc52f0abd92d2b07ce26359fcbc399f636ad"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7096a5e0c1115ec82641afbdd70451a144558ea5cf564a896294e346eb611be1"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f769457a639403073968d118bc70110e7dce294688009f5c24ab78800ae56dc8"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:39b0e26725c5023757fc1ab2a89ef9d7ab23b84f9251e28f9cc114d5b59c1b09"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-win32.whl", hash = "sha256:2130db8ed69a48a3440103d4a520b89d8a9405f1b06e2cc81640509e8bf6548f"}, + {file = "lazy_object_proxy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:677ea950bef409b47e51e733283544ac3d660b709cfce7b187f5ace137960d61"}, + {file = "lazy_object_proxy-1.7.1-pp37.pp38-none-any.whl", hash = "sha256:d66906d5785da8e0be7360912e99c9188b70f52c422f9fc18223347235691a84"}, +] +mccabe = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] mypy = [ {file = "mypy-0.950-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cf9c261958a769a3bd38c3e133801ebcd284ffb734ea12d01457cb09eacf7d7b"}, {file = "mypy-0.950-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b5b5bd0ffb11b4aba2bb6d31b8643902c48f990cc92fda4e21afac658044f0c0"}, @@ -511,6 +650,10 @@ pyarrow = [ {file = "pyarrow-9.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fe2ce795fa1d95e4e940fe5661c3c58aee7181c730f65ac5dd8794a77228de59"}, {file = "pyarrow-9.0.0.tar.gz", hash = "sha256:7fb02bebc13ab55573d1ae9bb5002a6d20ba767bf8569b52fce5301d42495ab7"}, ] +pylint = [ + {file = "pylint-2.13.9-py3-none-any.whl", hash = "sha256:705c620d388035bdd9ff8b44c5bcdd235bfb49d276d488dd2c8ff1736aa42526"}, + {file = "pylint-2.13.9.tar.gz", hash = "sha256:095567c96e19e6f57b5b907e67d265ff535e588fe26b12b5ebe1fc5645b2c731"}, +] pyparsing = [] pytest = [] python-dateutil = [ @@ -534,4 +677,70 @@ urllib3 = [ {file = "urllib3-1.26.12-py2.py3-none-any.whl", hash = "sha256:b930dd878d5a8afb066a637fbb35144fe7901e3b209d1cd4f524bd0e9deee997"}, {file = "urllib3-1.26.12.tar.gz", hash = "sha256:3fa96cf423e6987997fc326ae8df396db2a8b7c667747d47ddd8ecba91f4a74e"}, ] +wrapt = [ + {file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:5a9a0d155deafd9448baff28c08e150d9b24ff010e899311ddd63c45c2445e28"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ddaea91abf8b0d13443f6dac52e89051a5063c7d014710dcb4d4abb2ff811a59"}, + {file = "wrapt-1.14.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:36f582d0c6bc99d5f39cd3ac2a9062e57f3cf606ade29a0a0d6b323462f4dd87"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:7ef58fb89674095bfc57c4069e95d7a31cfdc0939e2a579882ac7d55aadfd2a1"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:e2f83e18fe2f4c9e7db597e988f72712c0c3676d337d8b101f6758107c42425b"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ee2b1b1769f6707a8a445162ea16dddf74285c3964f605877a20e38545c3c462"}, + {file = "wrapt-1.14.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:833b58d5d0b7e5b9832869f039203389ac7cbf01765639c7309fd50ef619e0b1"}, + {file = "wrapt-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80bb5c256f1415f747011dc3604b59bc1f91c6e7150bd7db03b19170ee06b320"}, + {file = "wrapt-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:07f7a7d0f388028b2df1d916e94bbb40624c59b48ecc6cbc232546706fac74c2"}, + {file = "wrapt-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:02b41b633c6261feff8ddd8d11c711df6842aba629fdd3da10249a53211a72c4"}, + {file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fe803deacd09a233e4762a1adcea5db5d31e6be577a43352936179d14d90069"}, + {file = "wrapt-1.14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:257fd78c513e0fb5cdbe058c27a0624c9884e735bbd131935fd49e9fe719d310"}, + {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4fcc4649dc762cddacd193e6b55bc02edca674067f5f98166d7713b193932b7f"}, + {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:11871514607b15cfeb87c547a49bca19fde402f32e2b1c24a632506c0a756656"}, + {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"}, + {file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"}, + {file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"}, + {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:a85d2b46be66a71bedde836d9e41859879cc54a2a04fad1191eb50c2066f6e9d"}, + {file = "wrapt-1.14.1-cp35-cp35m-win32.whl", hash = "sha256:dbcda74c67263139358f4d188ae5faae95c30929281bc6866d00573783c422b7"}, + {file = "wrapt-1.14.1-cp35-cp35m-win_amd64.whl", hash = "sha256:b21bb4c09ffabfa0e85e3a6b623e19b80e7acd709b9f91452b8297ace2a8ab00"}, + {file = "wrapt-1.14.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9e0fd32e0148dd5dea6af5fee42beb949098564cc23211a88d799e434255a1f4"}, + {file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9736af4641846491aedb3c3f56b9bc5568d92b0692303b5a305301a95dfd38b1"}, + {file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b02d65b9ccf0ef6c34cba6cf5bf2aab1bb2f49c6090bafeecc9cd81ad4ea1c1"}, + {file = "wrapt-1.14.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21ac0156c4b089b330b7666db40feee30a5d52634cc4560e1905d6529a3897ff"}, + {file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:9f3e6f9e05148ff90002b884fbc2a86bd303ae847e472f44ecc06c2cd2fcdb2d"}, + {file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:6e743de5e9c3d1b7185870f480587b75b1cb604832e380d64f9504a0535912d1"}, + {file = "wrapt-1.14.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d79d7d5dc8a32b7093e81e97dad755127ff77bcc899e845f41bf71747af0c569"}, + {file = "wrapt-1.14.1-cp36-cp36m-win32.whl", hash = "sha256:81b19725065dcb43df02b37e03278c011a09e49757287dca60c5aecdd5a0b8ed"}, + {file = "wrapt-1.14.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b014c23646a467558be7da3d6b9fa409b2c567d2110599b7cf9a0c5992b3b471"}, + {file = "wrapt-1.14.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88bd7b6bd70a5b6803c1abf6bca012f7ed963e58c68d76ee20b9d751c74a3248"}, + {file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5901a312f4d14c59918c221323068fad0540e34324925c8475263841dbdfe68"}, + {file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77c85fedff92cf788face9bfa3ebaa364448ebb1d765302e9af11bf449ca36d"}, + {file = "wrapt-1.14.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d649d616e5c6a678b26d15ece345354f7c2286acd6db868e65fcc5ff7c24a77"}, + {file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7d2872609603cb35ca513d7404a94d6d608fc13211563571117046c9d2bcc3d7"}, + {file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015"}, + {file = "wrapt-1.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2b39d38039a1fdad98c87279b48bc5dce2c0ca0d73483b12cb72aa9609278e8a"}, + {file = "wrapt-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:60db23fa423575eeb65ea430cee741acb7c26a1365d103f7b0f6ec412b893853"}, + {file = "wrapt-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:709fe01086a55cf79d20f741f39325018f4df051ef39fe921b1ebe780a66184c"}, + {file = "wrapt-1.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456"}, + {file = "wrapt-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f"}, + {file = "wrapt-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc"}, + {file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1"}, + {file = "wrapt-1.14.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af"}, + {file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b"}, + {file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0"}, + {file = "wrapt-1.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57"}, + {file = "wrapt-1.14.1-cp38-cp38-win32.whl", hash = "sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5"}, + {file = "wrapt-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d"}, + {file = "wrapt-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383"}, + {file = "wrapt-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7"}, + {file = "wrapt-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86"}, + {file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735"}, + {file = "wrapt-1.14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b"}, + {file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3"}, + {file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3"}, + {file = "wrapt-1.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe"}, + {file = "wrapt-1.14.1-cp39-cp39-win32.whl", hash = "sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5"}, + {file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"}, + {file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"}, +] zipp = [] diff --git a/pyproject.toml b/pyproject.toml index 5b9d176e1..04428e552 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ oauthlib=">=3.1.0" [tool.poetry.dev-dependencies] pytest = "^7.1.2" mypy = "^0.950" +pylint = ">=2.12.0" black = "^22.3.0" [tool.poetry.urls] diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index aef434de5..5ac9e7f84 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -74,10 +74,10 @@ def get_auth_provider(cfg: ClientContext): raise RuntimeError("No valid authentication settings!") -OAUTH_SCOPES = ["sql", "offline_access"] -# TODO: moderakh to be changed once registered on the service side -OAUTH_CLIENT_ID = "databricks-sql-python" -OAUTH_REDIRECT_PORT_RANGE = range(8020, 8025) +PYSQL_OAUTH_SCOPES = ["sql", "offline_access"] +PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python" +PYSQL_OAUTH_REDIRECT_PORT_RANGE = range(8020, 8025) + def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs): cfg = ClientContext(hostname=hostname, @@ -87,9 +87,9 @@ def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAu password=kwargs.get("_password"), use_cert_as_auth=kwargs.get("_use_cert_as_auth"), tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - oauth_scopes=OAUTH_SCOPES, - oauth_client_id=OAUTH_CLIENT_ID, - oauth_redirect_port_range=OAUTH_REDIRECT_PORT_RANGE, + oauth_scopes=PYSQL_OAUTH_SCOPES, + oauth_client_id=PYSQL_OAUTH_CLIENT_ID, + oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE, oauth_persistence=oauth_persistence) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 6dc630454..6ca4803d6 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -42,8 +42,6 @@ logger = logging.getLogger(__name__) -# Some constant values - class OAuthManager: OIDC_REDIRECTOR_PATH = "oidc" From 4fe5a58bac6572de2db780152da9cf610a2c7657 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 29 Aug 2022 13:00:58 -0700 Subject: [PATCH 44/57] Update src/databricks/sql/experimental/oauth_persistence.py Co-authored-by: Jesse --- src/databricks/sql/experimental/oauth_persistence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index ceccdde11..2d6bdaff0 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -7,8 +7,8 @@ class OAuthToken: def __init__(self, access_token, refresh_token): self._access_token = access_token self._refresh_token = refresh_token - - def get_access_token(self) -> str: + @property + def access_token(self) -> str: return self._access_token def get_refresh_token(self) -> str: From 8ab4e3f9143a4657a6940f775ceab28f9b60359a Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 29 Aug 2022 13:53:32 -0700 Subject: [PATCH 45/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/__init__.py | 21 --------------- src/databricks/sql/auth/auth.py | 22 ---------------- src/databricks/sql/auth/authenticators.py | 25 ++---------------- src/databricks/sql/auth/oauth.py | 23 ---------------- src/databricks/sql/auth/thrift_http_client.py | 26 ++----------------- .../sql/experimental/oauth_persistence.py | 8 +++--- tests/unit/test_oauth_persistence.py | 4 +-- 7 files changed, 11 insertions(+), 118 deletions(-) diff --git a/src/databricks/sql/auth/__init__.py b/src/databricks/sql/auth/__init__.py index 2cb62b15f..e69de29bb 100644 --- a/src/databricks/sql/auth/__init__.py +++ b/src/databricks/sql/auth/__init__.py @@ -1,21 +0,0 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 5ac9e7f84..9be9ec5b6 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -1,25 +1,3 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from typing import List from enum import Enum from databricks.sql.auth.authenticators import CredentialsProvider, \ diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index bba1f7914..3ecb33264 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -1,24 +1,3 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import logging from typing import Dict, List @@ -94,8 +73,8 @@ def _initial_get_token(self): if self._oauth_persistence: token = self._oauth_persistence.read() if token: - self._access_token = token.get_access_token() - self._refresh_token = token.get_refresh_token() + self._access_token = token.access_token + self._refresh_token = token.refresh_token if self._access_token and self._refresh_token: self._update_token_if_expired() diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 6ca4803d6..2f17d3025 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -1,26 +1,3 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - import base64 import hashlib import webbrowser diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index daf0fec49..696897230 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,27 +1,6 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import logging +from typing import Dict, List import thrift.transport.THttpClient logger = logging.getLogger(__name__) @@ -33,12 +12,11 @@ def __init__(self, auth_provider, uri_or_host, port=None, path=None, cafile=None super().__init__(uri_or_host, port, path, cafile, cert_file, key_file, ssl_context) self.__auth_provider = auth_provider - def setCustomHeaders(self, headers): + def setCustomHeaders(self, headers: Dict[str, str]): self._headers = headers super().setCustomHeaders(headers) def flush(self): - # TODO retry behaviour headers = dict(self._headers) self.__auth_provider.add_headers(headers) self._headers = headers diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index 2d6bdaff0..f2dd5c429 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -7,11 +7,13 @@ class OAuthToken: def __init__(self, access_token, refresh_token): self._access_token = access_token self._refresh_token = refresh_token + @property def access_token(self) -> str: return self._access_token - def get_refresh_token(self) -> str: + @property + def refresh_token(self) -> str: return self._refresh_token @@ -34,8 +36,8 @@ def persist(self, token: OAuthToken): # Data to be written dictionary = { - "refresh_token": token.get_refresh_token(), - "access_token": token.get_access_token() + "refresh_token": token.refresh_token, + "access_token": token.access_token } # Serializing json diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index bd44654cc..01c199033 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -41,8 +41,8 @@ def test_DevOnlyFilePersistence_read_my_write(self): persistence_manager.persist(token) new_token = persistence_manager.read() - self.assertEqual(new_token.get_access_token(), access_token) - self.assertEqual(new_token.get_refresh_token(), refresh_token) + self.assertEqual(new_token.access_token, access_token) + self.assertEqual(new_token.refresh_token, refresh_token) def test_DevOnlyFilePersistence_file_does_not_exist(self): with tempfile.TemporaryDirectory() as tempdir: From fb012b0e9a2d9724f4f2241df70460af38c7c0d9 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 29 Aug 2022 14:14:04 -0700 Subject: [PATCH 46/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 13 ++++++---- src/databricks/sql/auth/authenticators.py | 7 +++--- src/databricks/sql/auth/oauth.py | 9 ++++--- src/databricks/sql/auth/oauth_http_handler.py | 4 ++-- src/databricks/sql/auth/thrift_http_client.py | 7 +++--- tests/unit/test_auth.py | 24 +------------------ 6 files changed, 22 insertions(+), 42 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 9be9ec5b6..b61835a0f 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -1,5 +1,6 @@ -from typing import List from enum import Enum +from typing import List + from databricks.sql.auth.authenticators import CredentialsProvider, \ AccessTokenAuthProvider, BasicAuthProvider, DatabricksOAuthProvider from databricks.sql.experimental.oauth_persistence import OAuthPersistence @@ -40,7 +41,11 @@ def __init__(self, def get_auth_provider(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: - return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes) + return DatabricksOAuthProvider(cfg.hostname, + cfg.oauth_persistence, + cfg.oauth_redirect_port_range, + cfg.oauth_client_id, + cfg.oauth_scopes) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) elif cfg.username is not None and cfg.password is not None: @@ -54,7 +59,7 @@ def get_auth_provider(cfg: ClientContext): PYSQL_OAUTH_SCOPES = ["sql", "offline_access"] PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python" -PYSQL_OAUTH_REDIRECT_PORT_RANGE = range(8020, 8025) +PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025)) def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs): @@ -70,5 +75,3 @@ def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAu oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE, oauth_persistence=oauth_persistence) return get_auth_provider(cfg) - - diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 3ecb33264..eac62fddd 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -1,10 +1,8 @@ +import base64 import logging from typing import Dict, List from databricks.sql.auth.oauth import OAuthManager -import base64 - - # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence @@ -43,7 +41,8 @@ def add_headers(self, request_headers: Dict[str, str]): class DatabricksOAuthProvider(CredentialsProvider): SCOPE_DELIM = ' ' - def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, redirect_port_range: List[int], client_id: str, scopes: List[str]): + def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, redirect_port_range: List[int], + client_id: str, scopes: List[str]): try: self.oauth_manager = OAuthManager(port_range=redirect_port_range, client_id=client_id) self._hostname = self._normalize_host_name(hostname=hostname) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 2f17d3025..570494c5e 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -1,17 +1,16 @@ import base64 import hashlib -import webbrowser import json +import logging import secrets +import webbrowser from datetime import datetime, timezone from http.server import HTTPServer -import logging +from typing import List import oauthlib.oauth2 -from oauthlib.oauth2.rfc6749.errors import OAuth2Error -from typing import Dict, List - import requests +from oauthlib.oauth2.rfc6749.errors import OAuth2Error from requests.exceptions import RequestException from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler diff --git a/src/databricks/sql/auth/oauth_http_handler.py b/src/databricks/sql/auth/oauth_http_handler.py index a3f6cc161..a6cb1e1ec 100644 --- a/src/databricks/sql/auth/oauth_http_handler.py +++ b/src/databricks/sql/auth/oauth_http_handler.py @@ -37,6 +37,6 @@ def do_GET(self): # nopep8 self.request_path = self.path def log_message(self, format, *args): - #pylint: disable=redefined-builtin - #pylint: disable=unused-argument + # pylint: disable=redefined-builtin + # pylint: disable=unused-argument return diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 696897230..afffab5f7 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,6 +1,6 @@ import logging +from typing import Dict -from typing import Dict, List import thrift.transport.THttpClient logger = logging.getLogger(__name__) @@ -8,7 +8,8 @@ class THttpClient(thrift.transport.THttpClient.THttpClient): - def __init__(self, auth_provider, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): + def __init__(self, auth_provider, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, + ssl_context=None): super().__init__(uri_or_host, port, path, cafile, cert_file, key_file, ssl_context) self.__auth_provider = auth_provider @@ -21,4 +22,4 @@ def flush(self): self.__auth_provider.add_headers(headers) self._headers = headers self.setCustomHeaders(self._headers) - super().flush() \ No newline at end of file + super().flush() diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index a39e97995..22b4336bf 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,25 +1,3 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import unittest from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, CredentialsProvider @@ -85,6 +63,6 @@ def test_get_python_sql_connector_auth_provider_noop(self): tls_client_cert_file = "fake.cert" use_cert_as_auth = "abc" hostname = "moderakh-test.cloud.databricks.com" - kwargs = {'_tls_client_cert_file': tls_client_cert_file, '_use_cert_as_auth': use_cert_as_auth } + kwargs = {'_tls_client_cert_file': tls_client_cert_file, '_use_cert_as_auth': use_cert_as_auth} auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") From 68812262ac4dc220b063975072516cf0503e7a4b Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 29 Aug 2022 15:44:38 -0700 Subject: [PATCH 47/57] cleanup Signed-off-by: Moe Derakhshani --- .github/workflows/code-quality-checks.yml | 2 +- src/databricks/sql/auth/auth.py | 4 ++++ src/databricks/sql/auth/oauth.py | 1 + src/databricks/sql/experimental/oauth_persistence.py | 6 ++++-- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index b6462322a..6648242c1 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -154,4 +154,4 @@ jobs: # black the code #---------------------------------------------- - name: Mypy - run: poetry run mypy src + run: poetry run mypy --install-types --non-interactive src diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index b61835a0f..1ff6e3201 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -41,6 +41,10 @@ def __init__(self, def get_auth_provider(cfg: ClientContext): if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value: + assert cfg.oauth_redirect_port_range is not None + assert cfg.oauth_client_id is not None + assert cfg.oauth_scopes is not None + return DatabricksOAuthProvider(cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 570494c5e..1d2c113a0 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -203,6 +203,7 @@ def get_tokens(self, hostname: str, scope=None): logger.error(msg) raise e + assert self.redirect_port is not None redirect_url = OAuthManager.__get_redirect_url(self.redirect_port) token_request_url = oauth_config["token_endpoint"] diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index f2dd5c429..ea94098c8 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -1,5 +1,7 @@ import logging import json +from typing import Optional + logger = logging.getLogger(__name__) @@ -21,7 +23,7 @@ class OAuthPersistence: def persist(self, oauth_token: OAuthToken): pass - def read(self) -> OAuthToken: + def read(self) -> Optional[OAuthToken]: pass @@ -46,7 +48,7 @@ def persist(self, token: OAuthToken): with open(self._file_path, "w") as outfile: outfile.write(json_object) - def read(self) -> OAuthToken: + def read(self) -> Optional[OAuthToken]: # TODO: validate the try: with open(self._file_path, "r") as infile: From 3e537d57b0b1b30d4034eb00dca1d3c2edb448b8 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 29 Aug 2022 15:54:46 -0700 Subject: [PATCH 48/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/__init__.py | 5 +- src/databricks/sql/auth/auth.py | 75 ++++++++------- src/databricks/sql/auth/authenticators.py | 42 +++++--- src/databricks/sql/auth/oauth.py | 96 ++++++++++++------- src/databricks/sql/auth/oauth_http_handler.py | 4 +- src/databricks/sql/auth/thrift_http_client.py | 18 +++- src/databricks/sql/client.py | 5 +- .../sql/experimental/oauth_persistence.py | 7 +- src/databricks/sql/thrift_backend.py | 8 +- 9 files changed, 171 insertions(+), 89 deletions(-) diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index e457a4358..ffc99f4ba 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -46,4 +46,7 @@ def TimestampFromTicks(ticks): def connect(server_hostname, http_path, experimental_oauth_persistence=None, **kwargs): from .client import Connection - return Connection(server_hostname, http_path, experimental_oauth_persistence, **kwargs) + + return Connection( + server_hostname, http_path, experimental_oauth_persistence, **kwargs + ) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 1ff6e3201..af61f7ac2 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -1,8 +1,12 @@ from enum import Enum from typing import List -from databricks.sql.auth.authenticators import CredentialsProvider, \ - AccessTokenAuthProvider, BasicAuthProvider, DatabricksOAuthProvider +from databricks.sql.auth.authenticators import ( + CredentialsProvider, + AccessTokenAuthProvider, + BasicAuthProvider, + DatabricksOAuthProvider, +) from databricks.sql.experimental.oauth_persistence import OAuthPersistence @@ -13,19 +17,20 @@ class AuthType(Enum): class ClientContext: - def __init__(self, - hostname: str, - username: str = None, - password: str = None, - access_token: str = None, - auth_type: str = None, - oauth_scopes: List[str] = None, - oauth_client_id: str = None, - oauth_redirect_port_range: List[int] = None, - use_cert_as_auth: str = None, - tls_client_cert_file: str = None, - oauth_persistence=None - ): + def __init__( + self, + hostname: str, + username: str = None, + password: str = None, + access_token: str = None, + auth_type: str = None, + oauth_scopes: List[str] = None, + oauth_client_id: str = None, + oauth_redirect_port_range: List[int] = None, + use_cert_as_auth: str = None, + tls_client_cert_file: str = None, + oauth_persistence=None, + ): self.hostname = hostname self.username = username self.password = password @@ -45,11 +50,13 @@ def get_auth_provider(cfg: ClientContext): assert cfg.oauth_client_id is not None assert cfg.oauth_scopes is not None - return DatabricksOAuthProvider(cfg.hostname, - cfg.oauth_persistence, - cfg.oauth_redirect_port_range, - cfg.oauth_client_id, - cfg.oauth_scopes) + return DatabricksOAuthProvider( + cfg.hostname, + cfg.oauth_persistence, + cfg.oauth_redirect_port_range, + cfg.oauth_client_id, + cfg.oauth_scopes, + ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) elif cfg.username is not None and cfg.password is not None: @@ -66,16 +73,20 @@ def get_auth_provider(cfg: ClientContext): PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025)) -def get_python_sql_connector_auth_provider(hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs): - cfg = ClientContext(hostname=hostname, - auth_type=kwargs.get("auth_type"), - access_token=kwargs.get("access_token"), - username=kwargs.get("_username"), - password=kwargs.get("_password"), - use_cert_as_auth=kwargs.get("_use_cert_as_auth"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - oauth_scopes=PYSQL_OAUTH_SCOPES, - oauth_client_id=PYSQL_OAUTH_CLIENT_ID, - oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE, - oauth_persistence=oauth_persistence) +def get_python_sql_connector_auth_provider( + hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs +): + cfg = ClientContext( + hostname=hostname, + auth_type=kwargs.get("auth_type"), + access_token=kwargs.get("access_token"), + username=kwargs.get("_username"), + password=kwargs.get("_password"), + use_cert_as_auth=kwargs.get("_use_cert_as_auth"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + oauth_scopes=PYSQL_OAUTH_SCOPES, + oauth_client_id=PYSQL_OAUTH_CLIENT_ID, + oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE, + oauth_persistence=oauth_persistence, + ) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index eac62fddd..02bf5816a 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -3,6 +3,7 @@ from typing import Dict, List from databricks.sql.auth.oauth import OAuthManager + # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence @@ -20,7 +21,7 @@ def __init__(self, access_token: str): self.__authorization_header_value = "Bearer {}".format(access_token) def add_headers(self, request_headers: Dict[str, str]): - request_headers['Authorization'] = self.__authorization_header_value + request_headers["Authorization"] = self.__authorization_header_value # Private API: this is an evolving interface and it will change in the future. @@ -28,23 +29,33 @@ def add_headers(self, request_headers: Dict[str, str]): class BasicAuthProvider(CredentialsProvider): def __init__(self, username: str, password: str): auth_credentials = f"{username}:{password}".encode("UTF-8") - auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode("UTF-8") + auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode( + "UTF-8" + ) self.__authorization_header_value = f"Basic {auth_credentials_base64}" def add_headers(self, request_headers: Dict[str, str]): - request_headers['Authorization'] = self.__authorization_header_value + request_headers["Authorization"] = self.__authorization_header_value # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class DatabricksOAuthProvider(CredentialsProvider): - SCOPE_DELIM = ' ' - - def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, redirect_port_range: List[int], - client_id: str, scopes: List[str]): + SCOPE_DELIM = " " + + def __init__( + self, + hostname: str, + oauth_persistence: OAuthPersistence, + redirect_port_range: List[int], + client_id: str, + scopes: List[str], + ): try: - self.oauth_manager = OAuthManager(port_range=redirect_port_range, client_id=client_id) + self.oauth_manager = OAuthManager( + port_range=redirect_port_range, client_id=client_id + ) self._hostname = self._normalize_host_name(hostname=hostname) self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) self._oauth_persistence = oauth_persistence @@ -58,7 +69,7 @@ def __init__(self, hostname: str, oauth_persistence: OAuthPersistence, redirect_ def add_headers(self, request_headers: Dict[str, str]): self._update_token_if_expired() - request_headers['Authorization'] = f"Bearer {self._access_token}" + request_headers["Authorization"] = f"Bearer {self._access_token}" @staticmethod def _normalize_host_name(hostname: str): @@ -79,8 +90,8 @@ def _initial_get_token(self): self._update_token_if_expired() else: (access_token, refresh_token) = self.oauth_manager.get_tokens( - hostname=self._hostname, - scope=self._scopes_as_str) + hostname=self._hostname, scope=self._scopes_as_str + ) self._access_token = access_token self._refresh_token = refresh_token self._oauth_persistence.persist(OAuthToken(access_token, refresh_token)) @@ -90,10 +101,15 @@ def _initial_get_token(self): def _update_token_if_expired(self): try: - (fresh_access_token, fresh_refresh_token, is_refreshed) = self.oauth_manager.check_and_refresh_access_token( + ( + fresh_access_token, + fresh_refresh_token, + is_refreshed, + ) = self.oauth_manager.check_and_refresh_access_token( hostname=self._hostname, access_token=self._access_token, - refresh_token=self._refresh_token) + refresh_token=self._refresh_token, + ) if not is_refreshed: return else: diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 1d2c113a0..0f49aa88f 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -40,24 +40,29 @@ def __fetch_well_known_config(idp_url: str): try: response = requests.get(url=known_config_url) except RequestException as e: - logger.error(f"Unable to fetch OAuth configuration from {idp_url}.\n" - "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.") + logger.error( + f"Unable to fetch OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account." + ) raise e if response.status_code != 200: - msg = (f"Received status {response.status_code} OAuth configuration from " - f"{idp_url}.\n Verify it is a valid workspace URL and " - "that OAuth is enabled on this account." - ) + msg = ( + f"Received status {response.status_code} OAuth configuration from " + f"{idp_url}.\n Verify it is a valid workspace URL and " + "that OAuth is enabled on this account." + ) logger.error(msg) raise RuntimeError(msg) try: return response.json() except requests.exceptions.JSONDecodeError as e: - logger.error(f"Unable to decode OAuth configuration from {idp_url}.\n" - "Verify it is a valid workspace URL and that OAuth is " - "enabled on this account.") + logger.error( + f"Unable to decode OAuth configuration from {idp_url}.\n" + "Verify it is a valid workspace URL and that OAuth is " + "enabled on this account." + ) raise e @staticmethod @@ -70,7 +75,9 @@ def __get_idp_url(host: str): def __get_challenge(): verifier_string = OAuthManager.__token_urlsafe(32) digest = hashlib.sha256(verifier_string.encode("UTF-8")).digest() - challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") + challenge_string = ( + base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "") + ) return verifier_string, challenge_string def __get_authorization_code(self, client, auth_url, scope, state, challenge): @@ -87,11 +94,14 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge): scope=scope, state=state, code_challenge=challenge, - code_challenge_method="S256") + code_challenge_method="S256", + ) logger.info(f"Opening {auth_req_uri}") webbrowser.open_new(auth_req_uri) - logger.info(f"Listening for OAuth authorization callback at {redirect_url}") + logger.info( + f"Listening for OAuth authorization callback at {redirect_url}" + ) httpd.handle_request() self.redirect_port = port break @@ -102,7 +112,9 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge): except Exception as e: logger.error("unexpected error", e) if self.redirect_port is None: - logger.error(f"Tried all the ports {self.port_range} for oauth redirect, but can't find free port") + logger.error( + f"Tried all the ports {self.port_range} for oauth redirect, but can't find free port" + ) raise last_error if not handler.request_path: @@ -111,17 +123,24 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge): raise RuntimeError(msg) # This is a kludge because the parsing library expects https callbacks # We should probably set it up using https - full_redirect_url = f"https://localhost:{self.redirect_port}/{handler.request_path}" + full_redirect_url = ( + f"https://localhost:{self.redirect_port}/{handler.request_path}" + ) try: - authorization_code_response = \ - client.parse_request_uri_response(full_redirect_url, state=state) + authorization_code_response = client.parse_request_uri_response( + full_redirect_url, state=state + ) except OAuth2Error as e: logger.error(f"OAuth Token Request error {e.description}") raise e return authorization_code_response - def __send_auth_code_token_request(self, client, token_request_url, redirect_url, code, verifier): - token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url) + def __send_auth_code_token_request( + self, client, token_request_url, redirect_url, code, verifier + ): + token_request_body = client.prepare_request_body( + code=code, redirect_uri=redirect_url + ) data = f"{token_request_body}&code_verifier={verifier}" return self.__send_token_request(token_request_url, data) @@ -129,7 +148,7 @@ def __send_auth_code_token_request(self, client, token_request_url, redirect_url def __send_token_request(token_request_url, data): headers = { "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded" + "Content-Type": "application/x-www-form-urlencoded", } response = requests.post(url=token_request_url, data=data, headers=headers) return response.json() @@ -140,16 +159,23 @@ def __send_refresh_token_request(self, hostname, refresh_token): token_request_url = oauth_config["token_endpoint"] client = oauthlib.oauth2.WebApplicationClient(self.client_id) token_request_body = client.prepare_refresh_body( - refresh_token=refresh_token, client_id=client.client_id) + refresh_token=refresh_token, client_id=client.client_id + ) return OAuthManager.__send_token_request(token_request_url, token_request_body) @staticmethod def __get_tokens_from_response(oauth_response): access_token = oauth_response["access_token"] - refresh_token = oauth_response["refresh_token"] if "refresh_token" in oauth_response else None + refresh_token = ( + oauth_response["refresh_token"] + if "refresh_token" in oauth_response + else None + ) return access_token, refresh_token - def check_and_refresh_access_token(self, hostname: str, access_token: str, refresh_token: str): + def check_and_refresh_access_token( + self, hostname: str, access_token: str, refresh_token: str + ): now = datetime.now(tz=timezone.utc) # If we can't decode an expiration time, this will be expired by default. expiration_time = now @@ -160,7 +186,9 @@ def check_and_refresh_access_token(self, hostname: str, access_token: str, refre # an unnecessary signature verification. access_token_payload = access_token.split(".")[1] # add padding - access_token_payload = access_token_payload + '=' * (-len(access_token_payload) % 4) + access_token_payload = access_token_payload + "=" * ( + -len(access_token_payload) % 4 + ) decoded = json.loads(base64.standard_b64decode(access_token_payload)) expiration_time = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc) except Exception as e: @@ -177,9 +205,13 @@ def check_and_refresh_access_token(self, hostname: str, access_token: str, refre raise RuntimeError(msg) # Try to refresh using the refresh token - logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}") + logger.debug( + f"Attempting to refresh OAuth access token that expired on {expiration_time}" + ) oauth_response = self.__send_refresh_token_request(hostname, refresh_token) - fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response(oauth_response) + fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response( + oauth_response + ) return fresh_access_token, fresh_refresh_token, True def get_tokens(self, hostname: str, scope=None): @@ -193,11 +225,8 @@ def get_tokens(self, hostname: str, scope=None): client = oauthlib.oauth2.WebApplicationClient(self.client_id) try: auth_response = self.__get_authorization_code( - client, - auth_url, - scope, - state, - challenge) + client, auth_url, scope, state, challenge + ) except OAuth2Error as e: msg = f"OAuth Authorization Error: {e.description}" logger.error(msg) @@ -208,6 +237,7 @@ def get_tokens(self, hostname: str, scope=None): token_request_url = oauth_config["token_endpoint"] code = auth_response["code"] - oauth_response = \ - self.__send_auth_code_token_request(client, token_request_url, redirect_url, code, verifier) + oauth_response = self.__send_auth_code_token_request( + client, token_request_url, redirect_url, code, verifier + ) return self.__get_tokens_from_response(oauth_response) diff --git a/src/databricks/sql/auth/oauth_http_handler.py b/src/databricks/sql/auth/oauth_http_handler.py index a6cb1e1ec..72c6ce517 100644 --- a/src/databricks/sql/auth/oauth_http_handler.py +++ b/src/databricks/sql/auth/oauth_http_handler.py @@ -22,7 +22,9 @@ class OAuthHttpSingleRequestHandler(BaseHTTPRequestHandler): """ def __init__(self, tool_name): - self.response_body = self.RESPONSE_BODY_TEMPLATE.replace("{!!!PLACE_HOLDER!!!}", tool_name).encode("utf-8") + self.response_body = self.RESPONSE_BODY_TEMPLATE.replace( + "{!!!PLACE_HOLDER!!!}", tool_name + ).encode("utf-8") self.request_path = None def __call__(self, *args, **kwargs): diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index afffab5f7..6d68b620c 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -7,10 +7,20 @@ class THttpClient(thrift.transport.THttpClient.THttpClient): - - def __init__(self, auth_provider, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, - ssl_context=None): - super().__init__(uri_or_host, port, path, cafile, cert_file, key_file, ssl_context) + def __init__( + self, + auth_provider, + uri_or_host, + port=None, + path=None, + cafile=None, + cert_file=None, + key_file=None, + ssl_context=None, + ): + super().__init__( + uri_or_host, port, path, cafile, cert_file, key_file, ssl_context + ) self.__auth_provider = auth_provider def setCustomHeaders(self, headers: Dict[str, str]): diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0164e1869..1d413b6a3 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -11,6 +11,7 @@ from databricks.sql.types import Row from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence + logger = logging.getLogger(__name__) DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760 @@ -86,7 +87,9 @@ def __init__( self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) - auth_provider = get_python_sql_connector_auth_provider(server_hostname, oauth_persistence, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, oauth_persistence, **kwargs + ) if not kwargs.get("_user_agent_entry"): useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index ea94098c8..3149ea940 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -29,7 +29,6 @@ def read(self) -> Optional[OAuthToken]: # Note this is only intended to be used for development class DevOnlyFilePersistence(OAuthPersistence): - def __init__(self, file_path): self._file_path = file_path @@ -39,7 +38,7 @@ def persist(self, token: OAuthToken): # Data to be written dictionary = { "refresh_token": token.refresh_token, - "access_token": token.access_token + "access_token": token.access_token, } # Serializing json @@ -55,6 +54,8 @@ def read(self) -> Optional[OAuthToken]: json_as_string = infile.read() token_as_json = json.loads(json_as_string) - return OAuthToken(token_as_json['access_token'], token_as_json['refresh_token']) + return OAuthToken( + token_as_json["access_token"], token_as_json["refresh_token"] + ) except Exception as e: return None diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e077518b7..ca983cbdc 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -54,7 +54,13 @@ class ThriftBackend: BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( - self, server_hostname: str, port, http_path: str, http_headers, auth_provider: CredentialsProvider, **kwargs + self, + server_hostname: str, + port, + http_path: str, + http_headers, + auth_provider: CredentialsProvider, + **kwargs, ): # Internal arguments in **kwargs: # _user_agent_entry From 00c403d0f19911cb580b6d154d1067837c1865d2 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 30 Aug 2022 14:20:29 -0700 Subject: [PATCH 49/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/__init__.py | 6 +-- src/databricks/sql/auth/auth.py | 6 +-- src/databricks/sql/client.py | 85 ++++++++++++++++++++++++++++----- 3 files changed, 78 insertions(+), 19 deletions(-) diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index ffc99f4ba..1f7d30044 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -44,9 +44,7 @@ def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) -def connect(server_hostname, http_path, experimental_oauth_persistence=None, **kwargs): +def connect(server_hostname, http_path, **kwargs): from .client import Connection - return Connection( - server_hostname, http_path, experimental_oauth_persistence, **kwargs - ) + return Connection(server_hostname, http_path, **kwargs) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index af61f7ac2..bb2525f87 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -73,9 +73,7 @@ def get_auth_provider(cfg: ClientContext): PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025)) -def get_python_sql_connector_auth_provider( - hostname: str, oauth_persistence: OAuthPersistence = None, **kwargs -): +def get_python_sql_connector_auth_provider(hostname: str, **kwargs): cfg = ClientContext( hostname=hostname, auth_type=kwargs.get("auth_type"), @@ -87,6 +85,6 @@ def get_python_sql_connector_auth_provider( oauth_scopes=PYSQL_OAUTH_SCOPES, oauth_client_id=PYSQL_OAUTH_CLIENT_ID, oauth_redirect_port_range=PYSQL_OAUTH_REDIRECT_PORT_RANGE, - oauth_persistence=oauth_persistence, + oauth_persistence=kwargs.get("experimental_oauth_persistence"), ) return get_auth_provider(cfg) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 1d413b6a3..0e07f6354 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -23,7 +23,6 @@ def __init__( self, server_hostname: str, http_path: str, - oauth_persistence: OAuthPersistence = None, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Dict[str, Any] = None, catalog: Optional[str] = None, @@ -33,15 +32,79 @@ def __init__( """ Connect to a Databricks SQL endpoint or a Databricks cluster. - :param server_hostname: Databricks instance host name. - :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) - or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) - :param access_token: Http Bearer access token, e.g. Databricks Personal Access Token. - :param http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request - :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. - Execute the SQL command `SET -v` to get a full list of available commands. - :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ - :param schema: An optional initial schema to use. Requires DBR version 9.0+ + Parameters: + :param server_hostname: Databricks instance host name. + :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) + or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) + :param http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request + :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. + Execute the SQL command `SET -v` to get a full list of available commands. + :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ + :param schema: An optional initial schema to use. Requires DBR version 9.0+ + + Other Parameters: + access_token: `str`, optional + Http Bearer access token, e.g. Databricks Personal Access Token. + Unless if you use auth_type=`databricks-oauth` you need to pass `access_token. + Examples: + connection = sql.connect( + server_hostname='dbc-12345.staging.cloud.databricks.com', + http_path='sql/protocolv1/o/6789/12abc567', + access_token='dabpi12345678' + ) + + auth_type: `str`, optional + `databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`. + This is currently in private preview for Databricks accounts on AWS. + This is a beta feature and supports User to Machine OAuth authentication for Databricks on AWS with + any IDP configured. This is only for interactive python applications and open a browser window. + Note this is beta (private preview) + + experimental_oauth_persistence: configures preferred storage for persisting oauth tokens. + This has to be a class implementing `OAuthPersistence`. + When `auth_type` is set to `databricks-oauth` without persisting the oauth token in a persistence storage + the oauth tokens will only be maintained in memory and if the python process restarts the end user + will have to login again. + Note this is beta (private preview) + + For persisting the oauth token in a prod environment you should subclass and implement OAuthPersistence + + + from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken + class MyCustomImplementation(OAuthPersistence): + def __init__(self, file_path): + self._file_path = file_path + + def persist(self, token: OAuthToken): + # implement this method to persist token.refresh_token and token.access_token + + def read(self) -> Optional[OAuthToken]: + # implement this method to return an instance of the persisted token + + + connection = sql.connect( + server_hostname='dbc-12345.staging.cloud.databricks.com', + http_path='sql/protocolv1/o/6789/12abc567', + auth_type="databricks-oauth", + experimental_oauth_persistence=MyCustomImplementation() + ) + + For development purpose you can use the existing `DevOnlyFilePersistence` which stores the + raw oauth token in the provided file path. Please note this is only for development and for prod you should provide your + own implementation of OAuthPersistence. + + Examples: + # for development only + from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence + + connection = sql.connect( + server_hostname='dbc-12345.staging.cloud.databricks.com', + http_path='sql/protocolv1/o/6789/12abc567', + auth_type="databricks-oauth", + experimental_oauth_persistence=DevOnlyFilePersistence("~/dev-oauth.json") + ) + + """ # Internal arguments in **kwargs: @@ -88,7 +151,7 @@ def __init__( self.disable_pandas = kwargs.get("_disable_pandas", False) auth_provider = get_python_sql_connector_auth_provider( - server_hostname, oauth_persistence, **kwargs + server_hostname, **kwargs ) if not kwargs.get("_user_agent_entry"): From e64df63dd842b57f522ccf9fec7b3c7c37c5fdf1 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 30 Aug 2022 14:28:32 -0700 Subject: [PATCH 50/57] Update src/databricks/sql/auth/thrift_http_client.py Co-authored-by: Jesse --- src/databricks/sql/auth/thrift_http_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 6d68b620c..7ed35e54a 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,7 +1,7 @@ import logging from typing import Dict -import thrift.transport.THttpClient +import thrift logger = logging.getLogger(__name__) From ef385e9e28a0e5ef4d99a448fcd1c8bf26e29e86 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 30 Aug 2022 22:15:11 -0700 Subject: [PATCH 51/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 8 +++++++- src/databricks/sql/auth/authenticators.py | 8 +------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index bb2525f87..018bfdb07 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -73,9 +73,15 @@ def get_auth_provider(cfg: ClientContext): PYSQL_OAUTH_REDIRECT_PORT_RANGE = list(range(8020, 8025)) +def normalize_host_name(hostname: str): + maybe_scheme = "https://" if not hostname.startswith("https://") else "" + maybe_trailing_slash = "/" if not hostname.endswith("/") else "" + return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" + + def get_python_sql_connector_auth_provider(hostname: str, **kwargs): cfg = ClientContext( - hostname=hostname, + hostname=normalize_host_name(hostname), auth_type=kwargs.get("auth_type"), access_token=kwargs.get("access_token"), username=kwargs.get("_username"), diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 02bf5816a..6e4944e6b 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -56,7 +56,7 @@ def __init__( self.oauth_manager = OAuthManager( port_range=redirect_port_range, client_id=client_id ) - self._hostname = self._normalize_host_name(hostname=hostname) + self._hostname = hostname self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(scopes) self._oauth_persistence = oauth_persistence self._client_id = client_id @@ -71,12 +71,6 @@ def add_headers(self, request_headers: Dict[str, str]): self._update_token_if_expired() request_headers["Authorization"] = f"Bearer {self._access_token}" - @staticmethod - def _normalize_host_name(hostname: str): - maybe_scheme = "https://" if not hostname.startswith("https://") else "" - maybe_trailing_slash = "/" if not hostname.endswith("/") else "" - return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" - def _initial_get_token(self): try: if self._access_token is None or self._refresh_token is None: From 27fb3b5ac68faf8551d49df5d443b8cedd5717a1 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 30 Aug 2022 22:17:16 -0700 Subject: [PATCH 52/57] cleanup Signed-off-by: Moe Derakhshani --- src/databricks/sql/experimental/oauth_persistence.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index 3149ea940..9f5cc6a59 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -48,7 +48,6 @@ def persist(self, token: OAuthToken): outfile.write(json_object) def read(self) -> Optional[OAuthToken]: - # TODO: validate the try: with open(self._file_path, "r") as infile: json_as_string = infile.read() From 0a6c455c75d12510fbcb2161e94dca8d3c20de44 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 30 Aug 2022 22:23:56 -0700 Subject: [PATCH 53/57] cleanup Signed-off-by: Moe Derakhshani --- tests/unit/test_oauth_persistence.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index 01c199033..4ddb7050e 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -1,24 +1,3 @@ -# Copyright 2022 Databricks, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"), except -# that the use of services to which certain application programming -# interfaces (each, an "API") connect requires that the user first obtain -# a license for the use of the APIs from Databricks, Inc. ("Databricks"), -# by creating an account at www.databricks.com and agreeing to either (a) -# the Community Edition Terms of Service, (b) the Databricks Terms of -# Service, or (c) another written agreement between Licensee and Databricks -# for the use of the APIs. -# -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import unittest From 16aa44e2ab70a2f17eed89738ad00efa0d51da6e Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 6 Sep 2022 13:51:28 -0700 Subject: [PATCH 54/57] moved access_token out of kwargs Signed-off-by: Moe Derakhshani --- src/databricks/sql/__init__.py | 5 ++--- src/databricks/sql/client.py | 21 ++++++++++++--------- src/databricks/sql/thrift_backend.py | 6 ++++++ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 1f7d30044..ef5c62124 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -44,7 +44,6 @@ def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) -def connect(server_hostname, http_path, **kwargs): +def connect(server_hostname, http_path, access_token=None, **kwargs): from .client import Connection - - return Connection(server_hostname, http_path, **kwargs) + return Connection(server_hostname, http_path, access_token, **kwargs) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0e07f6354..e97f360bb 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -23,6 +23,7 @@ def __init__( self, server_hostname: str, http_path: str, + access_token: Optional[str] = None, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Dict[str, Any] = None, catalog: Optional[str] = None, @@ -36,14 +37,7 @@ def __init__( :param server_hostname: Databricks instance host name. :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) - :param http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request - :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. - Execute the SQL command `SET -v` to get a full list of available commands. - :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ - :param schema: An optional initial schema to use. Requires DBR version 9.0+ - - Other Parameters: - access_token: `str`, optional + :param access_token: `str`, optional Http Bearer access token, e.g. Databricks Personal Access Token. Unless if you use auth_type=`databricks-oauth` you need to pass `access_token. Examples: @@ -52,7 +46,13 @@ def __init__( http_path='sql/protocolv1/o/6789/12abc567', access_token='dabpi12345678' ) + :param http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request + :param session_configuration: An optional dictionary of Spark session parameters. Defaults to None. + Execute the SQL command `SET -v` to get a full list of available commands. + :param catalog: An optional initial catalog to use. Requires DBR version 9.0+ + :param schema: An optional initial schema to use. Requires DBR version 9.0+ + Other Parameters: auth_type: `str`, optional `databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`. This is currently in private preview for Databricks accounts on AWS. @@ -69,7 +69,6 @@ def __init__( For persisting the oauth token in a prod environment you should subclass and implement OAuthPersistence - from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken class MyCustomImplementation(OAuthPersistence): def __init__(self, file_path): @@ -145,6 +144,10 @@ def read(self) -> Optional[OAuthToken]: # Databricks runtime will return native Arrow types for timestamps instead of Arrow strings # (True by default) + if access_token: + access_token_kv = {"access_token": access_token} + kwargs = {**kwargs, **access_token_kv} + self.open = False self.host = server_hostname self.port = kwargs.get("_port", 443) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index ca983cbdc..c3cf75f38 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -217,6 +217,12 @@ def _extract_error_message_from_headers(headers): err_msg = headers[DATABRICKS_ERROR_OR_REDIRECT_HEADER] if DATABRICKS_REASON_HEADER in headers: err_msg += ": " + headers[DATABRICKS_REASON_HEADER] + + if not err_msg: + # if authentication token is invalid we need this branch + if DATABRICKS_REASON_HEADER in headers: + err_msg += ": " + headers[DATABRICKS_REASON_HEADER] + return err_msg def _handle_request_error(self, error_info, attempt, elapsed): From be19297478b01b538ed61ac07feccc42167bacc1 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 6 Sep 2022 14:00:03 -0700 Subject: [PATCH 55/57] added hostname to the persitence api Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/authenticators.py | 6 +++--- .../sql/experimental/oauth_persistence.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 6e4944e6b..f006567ae 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -75,7 +75,7 @@ def _initial_get_token(self): try: if self._access_token is None or self._refresh_token is None: if self._oauth_persistence: - token = self._oauth_persistence.read() + token = self._oauth_persistence.read(self._hostname) if token: self._access_token = token.access_token self._refresh_token = token.refresh_token @@ -88,7 +88,7 @@ def _initial_get_token(self): ) self._access_token = access_token self._refresh_token = refresh_token - self._oauth_persistence.persist(OAuthToken(access_token, refresh_token)) + self._oauth_persistence.persist(self._hostname, OAuthToken(access_token, refresh_token)) except Exception as e: logging.error(f"unexpected error in oauth initialization", e, exc_info=True) raise e @@ -112,7 +112,7 @@ def _update_token_if_expired(self): if self._oauth_persistence: token = OAuthToken(self._access_token, self._refresh_token) - self._oauth_persistence.persist(token) + self._oauth_persistence.persist(self._hostname, token) except Exception as e: logging.error(f"unexpected error in oauth token update", e, exc_info=True) raise e diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index 9f5cc6a59..815c2c1e5 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -20,10 +20,10 @@ def refresh_token(self) -> str: class OAuthPersistence: - def persist(self, oauth_token: OAuthToken): + def persist(self, hostname: str, oauth_token: OAuthToken): pass - def read(self) -> Optional[OAuthToken]: + def read(self, hostname: str) -> Optional[OAuthToken]: pass @@ -32,13 +32,14 @@ class DevOnlyFilePersistence(OAuthPersistence): def __init__(self, file_path): self._file_path = file_path - def persist(self, token: OAuthToken): + def persist(self, hostname: str, token: OAuthToken): logger.info(f"persisting token in {self._file_path}") # Data to be written dictionary = { "refresh_token": token.refresh_token, "access_token": token.access_token, + "hostname": hostname } # Serializing json @@ -47,12 +48,19 @@ def persist(self, token: OAuthToken): with open(self._file_path, "w") as outfile: outfile.write(json_object) - def read(self) -> Optional[OAuthToken]: + def read(self, hostname: str) -> Optional[OAuthToken]: try: with open(self._file_path, "r") as infile: json_as_string = infile.read() token_as_json = json.loads(json_as_string) + hostname_in_token = token_as_json["hostname"] + if hostname != hostname_in_token: + msg = f"token was persisted for host {hostname_in_token} does not match {hostname} " \ + f"This is a dev only persistence and it only supports a single Databricks hostname." \ + f"\n manually delete {self._file_path} file and restart this process" + logger.error(msg) + raise Exception(msg) return OAuthToken( token_as_json["access_token"], token_as_json["refresh_token"] ) From 367a3eec857bfa1c9ed9a1a588464b70fd3a1da1 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Sep 2022 20:51:05 -0700 Subject: [PATCH 56/57] responded to review comments Signed-off-by: Moe Derakhshani --- src/databricks/sql/auth/auth.py | 4 +- src/databricks/sql/auth/authenticators.py | 8 +-- src/databricks/sql/client.py | 2 +- src/databricks/sql/thrift_backend.py | 4 +- tests/unit/test_auth.py | 4 +- tests/unit/test_oauth_persistence.py | 2 +- tests/unit/test_thrift_backend.py | 88 +++++++++++------------ 7 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 018bfdb07..31198a617 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -2,7 +2,7 @@ from typing import List from databricks.sql.auth.authenticators import ( - CredentialsProvider, + AuthProvider, AccessTokenAuthProvider, BasicAuthProvider, DatabricksOAuthProvider, @@ -63,7 +63,7 @@ def get_auth_provider(cfg: ClientContext): return BasicAuthProvider(cfg.username, cfg.password) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: # no op authenticator. authentication is performed using ssl certificate outside of headers - return CredentialsProvider() + return AuthProvider() else: raise RuntimeError("No valid authentication settings!") diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index f006567ae..0498cfbb2 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -9,14 +9,14 @@ from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence -class CredentialsProvider: +class AuthProvider: def add_headers(self, request_headers: Dict[str, str]): pass # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class AccessTokenAuthProvider(CredentialsProvider): +class AccessTokenAuthProvider(AuthProvider): def __init__(self, access_token: str): self.__authorization_header_value = "Bearer {}".format(access_token) @@ -26,7 +26,7 @@ def add_headers(self, request_headers: Dict[str, str]): # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class BasicAuthProvider(CredentialsProvider): +class BasicAuthProvider(AuthProvider): def __init__(self, username: str, password: str): auth_credentials = f"{username}:{password}".encode("UTF-8") auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode( @@ -41,7 +41,7 @@ def add_headers(self, request_headers: Dict[str, str]): # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class DatabricksOAuthProvider(CredentialsProvider): +class DatabricksOAuthProvider(AuthProvider): SCOPE_DELIM = " " def __init__( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e97f360bb..e3190d457 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -56,7 +56,7 @@ def __init__( auth_type: `str`, optional `databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`. This is currently in private preview for Databricks accounts on AWS. - This is a beta feature and supports User to Machine OAuth authentication for Databricks on AWS with + This supports User to Machine OAuth authentication for Databricks on AWS with any IDP configured. This is only for interactive python applications and open a browser window. Note this is beta (private preview) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index c3cf75f38..b5e29f394 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -13,7 +13,7 @@ import thrift.transport.TTransport import databricks.sql.auth.thrift_http_client -from databricks.sql.auth.authenticators import CredentialsProvider +from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.thrift_api.TCLIService import TCLIService, ttypes from databricks.sql import * from databricks.sql.thrift_api.TCLIService.TCLIService import ( @@ -59,7 +59,7 @@ def __init__( port, http_path: str, http_headers, - auth_provider: CredentialsProvider, + auth_provider: AuthProvider, **kwargs, ): # Internal arguments in **kwargs: diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 22b4336bf..59660f17c 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,6 +1,6 @@ import unittest -from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, CredentialsProvider +from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider from databricks.sql.auth.auth import get_python_sql_connector_auth_provider @@ -29,7 +29,7 @@ def test_basic_auth_provider(self): self.assertEqual(http_request['myKey'], 'myVal') def test_noop_auth_provider(self): - auth = CredentialsProvider() + auth = AuthProvider() http_request = {'myKey': 'myVal'} auth.add_headers(http_request) diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index 4ddb7050e..6c30888a8 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -1,7 +1,7 @@ import unittest -from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, CredentialsProvider +from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence, OAuthToken import tempfile diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 1120c0525..488de5842 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -10,7 +10,7 @@ import databricks.sql from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * -from databricks.sql.auth.authenticators import CredentialsProvider +from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.thrift_backend import ThriftBackend @@ -60,7 +60,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -68,7 +68,7 @@ def _make_type_desc(self, type): return ttypes.TTypeDesc(types=[ttypes.TTypeEntry(ttypes.TPrimitiveTypeEntry(type=type))]) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -140,7 +140,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=CredentialsProvider()) + ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider()) t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @@ -155,7 +155,7 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ "foo", 123, "bar", [], - auth_provider=CredentialsProvider(), + auth_provider=AuthProvider(), _tls_client_cert_file=mock_cert_file, _tls_client_cert_key_file=mock_cert_key_file, _tls_client_cert_key_password=mock_cert_key_password, @@ -172,7 +172,7 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.thrift_backend.create_default_context") def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): - ThriftBackend("foo", 123, "bar", [], auth_provider=CredentialsProvider(), _tls_no_verify=True) + ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_no_verify=True) mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) @@ -183,7 +183,7 @@ def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_cl @patch("databricks.sql.thrift_backend.create_default_context") def test_tls_verify_hostname_is_respected(self, mock_create_default_context, t_http_client_class): - ThriftBackend("foo", 123, "bar", [], auth_provider=CredentialsProvider(), _tls_verify_hostname=False) + ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_verify_hostname=False) mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) @@ -192,17 +192,17 @@ def test_tls_verify_hostname_is_respected(self, mock_create_default_context, @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider()) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) self.assertEqual(t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value") @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider(), _socket_timeout=129) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider(), _socket_timeout=0) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=0) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=CredentialsProvider(), _socket_timeout=None) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=None) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) def test_non_primitive_types_raise_error(self): @@ -270,7 +270,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): def test_make_request_checks_status_code(self): error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) for code in error_codes: mock_error_response = Mock() @@ -303,7 +303,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): resultSetMetadata=None, resultSet=None, closeOperation=None)) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -331,7 +331,7 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv operationHandle=self.operation_handle) tcli_service_instance.GetOperationStatus.return_value = op_state_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) @@ -356,7 +356,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock()) @@ -386,7 +386,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock()) @@ -429,7 +429,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(error_resp, Mock()) @@ -465,7 +465,7 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se tcli_service_instance.GetOperationStatus.side_effect = [ op_state_1, op_state_2, op_state_3 ] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) results_message_response = thrift_backend._handle_execute_response( execute_resp, Mock()) self.assertEqual(results_message_response.status, @@ -490,7 +490,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): directResults=direct_results_message, operationHandle=self.operation_handle) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._results_message_to_execute_response = Mock() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -644,7 +644,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): pyarrow.field("column3", pyarrow.binary()) ]).serialize().to_pybytes() - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) arrow_queue, has_more_results = thrift_backend.fetch_results( op_handle=Mock(), max_rows=1, @@ -660,7 +660,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -678,7 +678,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -695,7 +695,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -720,7 +720,7 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -749,7 +749,7 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -778,14 +778,14 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend.close_command(self.operation_handle) self.assertEqual(tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, self.operation_handle) @@ -793,7 +793,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) thrift_backend.close_session(self.session_handle) self.assertEqual(tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle) @@ -828,7 +828,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, None, Mock()) @@ -836,7 +836,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): @patch.object(ThriftBackend, "_convert_column_based_set_to_arrow_table") def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mock, convert_arrow_mock): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1000,7 +1000,7 @@ def test_make_request_will_retry_GetOperationStatus( "foobar", 443, "path", [], - auth_provider=CredentialsProvider(), + auth_provider=AuthProvider(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1) @@ -1037,7 +1037,7 @@ def test_make_request_wont_retry_if_headers_not_present(self, t_transport_class) mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1053,7 +1053,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1075,7 +1075,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( "foobar", 443, "path", [], - auth_provider=CredentialsProvider(), + auth_provider=AuthProvider(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, _retry_delay_min=0) @@ -1095,7 +1095,7 @@ def test_make_request_will_read_error_message_headers_if_set(self, t_transport_c mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) error_headers = [[("x-thriftserver-error-message", "thrift server error message")], [("x-databricks-error-or-redirect-message", "databricks error message")], @@ -1177,7 +1177,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_count": 1, "_retry_stop_after_attempts_duration": 100 } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider(), **retry_delay_args) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args) for (arg, val) in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @@ -1192,7 +1192,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider(), **retry_delay_args) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args) retry_delay_expected_vals = { k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() @@ -1212,7 +1212,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42" } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) backend.open_session(mock_config, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1223,7 +1223,7 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(databricks.sql.Error) as cm: backend.open_session(mock_config, None, None) @@ -1241,7 +1241,7 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] for cat, schem in initial_cat_schem_args: @@ -1260,7 +1260,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) backend.open_session({}, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1270,7 +1270,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False # is fine @@ -1304,7 +1304,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem") ) - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") @@ -1328,7 +1328,7 @@ def test_execute_command_sets_complex_type_fields_correctly(self, mock_handle_ex if decimals is not None: complex_arg_types["_use_arrow_native_decimals"] = decimals - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=CredentialsProvider(), **complex_arg_types) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), **complex_arg_types) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] From 847667361867210cd5526982aaefd7de49e55fa5 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Sep 2022 21:06:09 -0700 Subject: [PATCH 57/57] fix lint Signed-off-by: Moe Derakhshani --- src/databricks/sql/__init__.py | 1 + src/databricks/sql/auth/authenticators.py | 4 +++- src/databricks/sql/experimental/oauth_persistence.py | 10 ++++++---- tests/unit/test_oauth_persistence.py | 6 +++--- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index ef5c62124..ce1cf471a 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -46,4 +46,5 @@ def TimestampFromTicks(ticks): def connect(server_hostname, http_path, access_token=None, **kwargs): from .client import Connection + return Connection(server_hostname, http_path, access_token, **kwargs) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 0498cfbb2..8209931da 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -88,7 +88,9 @@ def _initial_get_token(self): ) self._access_token = access_token self._refresh_token = refresh_token - self._oauth_persistence.persist(self._hostname, OAuthToken(access_token, refresh_token)) + self._oauth_persistence.persist( + self._hostname, OAuthToken(access_token, refresh_token) + ) except Exception as e: logging.error(f"unexpected error in oauth initialization", e, exc_info=True) raise e diff --git a/src/databricks/sql/experimental/oauth_persistence.py b/src/databricks/sql/experimental/oauth_persistence.py index 815c2c1e5..bd0066d90 100644 --- a/src/databricks/sql/experimental/oauth_persistence.py +++ b/src/databricks/sql/experimental/oauth_persistence.py @@ -39,7 +39,7 @@ def persist(self, hostname: str, token: OAuthToken): dictionary = { "refresh_token": token.refresh_token, "access_token": token.access_token, - "hostname": hostname + "hostname": hostname, } # Serializing json @@ -56,9 +56,11 @@ def read(self, hostname: str) -> Optional[OAuthToken]: token_as_json = json.loads(json_as_string) hostname_in_token = token_as_json["hostname"] if hostname != hostname_in_token: - msg = f"token was persisted for host {hostname_in_token} does not match {hostname} " \ - f"This is a dev only persistence and it only supports a single Databricks hostname." \ - f"\n manually delete {self._file_path} file and restart this process" + msg = ( + f"token was persisted for host {hostname_in_token} does not match {hostname} " + f"This is a dev only persistence and it only supports a single Databricks hostname." + f"\n manually delete {self._file_path} file and restart this process" + ) logger.error(msg) raise Exception(msg) return OAuthToken( diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index 6c30888a8..10677c160 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -17,8 +17,8 @@ def test_DevOnlyFilePersistence_read_my_write(self): access_token = "abc#$%%^&^*&*()()_=-/" refresh_token = "#$%%^^&**()+)_gter243]xyz" token = OAuthToken(access_token=access_token, refresh_token=refresh_token) - persistence_manager.persist(token) - new_token = persistence_manager.read() + persistence_manager.persist("https://randomserver", token) + new_token = persistence_manager.read("https://randomserver") self.assertEqual(new_token.access_token, access_token) self.assertEqual(new_token.refresh_token, refresh_token) @@ -27,7 +27,7 @@ def test_DevOnlyFilePersistence_file_does_not_exist(self): with tempfile.TemporaryDirectory() as tempdir: test_json_file_path = os.path.join(tempdir, 'test.json') persistence_manager = DevOnlyFilePersistence(test_json_file_path) - new_token = persistence_manager.read() + new_token = persistence_manager.read("https://randomserver") self.assertEqual(new_token, None)