From 56a854fa89d4afb0ac5f7619ce51934095381441 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 05:22:33 +0000 Subject: [PATCH 01/46] initial commit --- .github/workflows/token-federation-test.yml | 166 ++++++++ poetry.lock | 111 +++++- pyproject.toml | 1 + src/databricks/sql/auth/auth.py | 37 ++ src/databricks/sql/auth/authenticators.py | 6 + src/databricks/sql/auth/token_federation.py | 400 ++++++++++++++++++++ 6 files changed, 710 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/token-federation-test.yml create mode 100644 src/databricks/sql/auth/token_federation.py diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml new file mode 100644 index 000000000..98ce336f6 --- /dev/null +++ b/.github/workflows/token-federation-test.yml @@ -0,0 +1,166 @@ +name: Token Federation Test + +# This workflow tests token federation functionality with GitHub Actions OIDC tokens +# in the databricks-sql-python connector to ensure CI/CD functionality + +on: + # Manual trigger with required inputs + workflow_dispatch: + inputs: + databricks_host: + description: 'Databricks host URL (e.g., example.cloud.databricks.com)' + required: true + databricks_http_path: + description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)' + required: true + identity_federation_client_id: + description: 'Identity federation client ID' + required: true + + # Automatically run on PR that changes token federation files + pull_request: + branches: + - main + + # Run on push to main that affects token federation + push: + paths: + - 'src/databricks/sql/auth/token_federation.py' + - 'src/databricks/sql/auth/auth.py' + - 'examples/token_federation_*.py' + branches: + - main + +permissions: + # Required for GitHub OIDC token + id-token: write + contents: read + +jobs: + test-token-federation: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.9 + 
uses: actions/setup-python@v5 + with: + python-version: '3.9' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install pyarrow + + - name: Get GitHub OIDC token + id: get-id-token + uses: actions/github-script@v7 + with: + script: | + const token = await core.getIDToken('https://github.com') + core.setSecret(token) + core.setOutput('token', token) + + - name: Create test script + run: | + cat > test_github_token_federation.py << 'EOF' + #!/usr/bin/env python3 + + """ + Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. + + This script demonstrates how to use the Databricks SQL connector with token federation + using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse, + runs a simple query, and shows the connected user. + """ + + import os + import sys + import json + import base64 + from databricks import sql + + def decode_jwt(token): + """Decode and return the claims from a JWT token.""" + try: + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT format") + + payload = parts[1] + padding = '=' * (4 - len(payload) % 4) + payload += padding + + decoded = base64.b64decode(payload) + return json.loads(decoded) + except Exception as e: + print(f"Failed to decode token: {str(e)}") + return None + + def main(): + # Get GitHub OIDC token + github_token = os.environ.get("OIDC_TOKEN") + if not github_token: + print("GitHub OIDC token not available") + sys.exit(1) + + # Get Databricks connection parameters + host = os.environ.get("DATABRICKS_HOST") + http_path = os.environ.get("DATABRICKS_HTTP_PATH") + identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") + + if not host or not http_path: + print("Missing Databricks connection parameters") + sys.exit(1) + + claims = decode_jwt(github_token) + if claims: + print(f"Token issuer: {claims.get('iss', 'unknown')}") + print(f"Token subject: {claims.get('sub', 'unknown')}") 
+ print(f"Token audience: {claims.get('aud', 'unknown')}") + + try: + # Connect to Databricks using token federation + print(f"Connecting to Databricks at {host}{http_path}") + with sql.connect( + server_hostname=host, + http_path=http_path, + access_token=github_token, + auth_type="token-federation", + identity_federation_client_id=identity_federation_client_id + ) as connection: + print("Connection established successfully") + + # Execute a simple query + cursor = connection.cursor() + cursor.execute("SELECT 1 + 1 as result") + result = cursor.fetchall() + print(f"Query result: {result[0][0]}") + + # Show current user + cursor.execute("SELECT current_user() as user") + result = cursor.fetchall() + print(f"Connected as user: {result[0][0]}") + + print("Token federation test successful!") + return True + except Exception as e: + print(f"Error connecting to Databricks: {str(e)}") + sys.exit(1) + + if __name__ == "__main__": + main() + EOF + chmod +x test_github_token_federation.py + + - name: Test token federation with GitHub OIDC token + env: + DATABRICKS_HOST: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST }} + DATABRICKS_HTTP_PATH: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH }} + IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} + OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} + run: | + python test_github_token_federation.py diff --git a/poetry.lock b/poetry.lock index 1bc396c9d..5d6a0891e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "astroid" @@ -6,6 +6,7 @@ version = "3.2.4" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, @@ -20,6 +21,7 @@ version = "22.12.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, @@ -55,6 +57,7 @@ version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -66,6 +69,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -167,6 +171,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -181,6 +186,8 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -192,6 +199,7 @@ version = "0.3.9" description = "serialize all of Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, @@ -207,6 +215,7 @@ version = "2.0.0" description = "An implementation of lxml.xmlfile for the standard library" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = 
"sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, @@ -218,6 +227,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -232,6 +243,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -246,6 +258,7 @@ version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -257,6 +270,7 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." 
optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -271,6 +285,7 @@ version = "4.3.3" description = "LZ4 Bindings for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, @@ -321,6 +336,7 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -332,6 +348,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -391,6 +408,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -402,6 +420,8 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] +markers = "python_version < \"3.10\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -439,6 +459,8 @@ version = "2.2.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" +groups = ["main", "dev"] +markers = "python_version >= \"3.10\"" files = [ {file = "numpy-2.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8146f3550d627252269ac42ae660281d673eb6f8b32f113538e0cc2a9aed42b9"}, {file = "numpy-2.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e642d86b8f956098b564a45e6f6ce68a22c2c97a04f5acd3f221f57b8cb850ae"}, @@ -503,6 +525,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -519,6 +542,7 @@ version = "3.1.5" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = 
"openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"}, {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, @@ -533,6 +557,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -544,6 +569,8 @@ version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -573,11 +600,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, -] +numpy = {version = ">=1.20.3", markers = "python_version < \"3.10\""} python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.1" @@ -611,6 +634,8 @@ version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" files = [ {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, {file = 
"pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, @@ -657,7 +682,11 @@ files = [ ] [package.dependencies] -numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""} +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, +] python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.7" @@ -693,6 +722,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -704,6 +734,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -720,6 +751,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -735,6 +767,8 @@ version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, @@ -786,6 +820,8 @@ version = "19.0.1" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, @@ -834,12 +870,51 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyjwt" +version = "2.9.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = 
">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, + {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + +[[package]] +name = "pyjwt" +version = "2.10.1" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"}, + {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pylint" version = "3.2.7" description = "python code static checker" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "pylint-3.2.7-py3-none-any.whl", hash = "sha256:02f4aedeac91be69fb3b4bea997ce580a4ac68ce58b89eaefeaf06749df73f4b"}, {file = "pylint-3.2.7.tar.gz", hash = "sha256:1b7a721b575eaeaa7d39db076b6e7743c993ea44f57979127c517c6c572c803e"}, @@ -851,7 +926,7 @@ colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < 
\"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=0.3.6", markers = "python_version == \"3.11\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -870,6 +945,7 @@ version = "7.4.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, @@ -892,6 +968,7 @@ version = "0.5.2" description = "A py.test plugin that parses environment files before running tests" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, @@ -907,6 +984,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -921,6 +999,7 @@ version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash 
= "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -935,6 +1014,7 @@ version = "2025.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, @@ -946,6 +1026,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -967,6 +1048,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -978,6 +1060,7 @@ version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] @@ -996,6 +1079,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1037,6 +1122,7 @@ version = "0.13.2" description = "Style preserving TOML library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, @@ -1048,6 +1134,7 @@ version = "4.13.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5"}, {file = "typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b"}, @@ -1059,6 +1146,7 @@ version = "2025.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, @@ -1070,13 +1158,14 @@ version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -1085,6 +1174,6 @@ zstd = ["zstandard (>=0.18.0)"] pyarrow = ["pyarrow", "pyarrow"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0bd6a6a019693a69a3da5ae312cea625ea73dfc5832b1e4051c7c7d1e76553d8" +content-hash = "118b7702637d44a7fee4107b471528b14c436bdb01d3618676bc50bbebc6ab65" diff --git a/pyproject.toml b/pyproject.toml index 7b95a5097..d40255a24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] python-dateutil = "^2.8.0" +PyJWT = ">=2.0.0" [tool.poetry.extras] pyarrow = ["pyarrow"] diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 347934ee4..635563ce0 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -5,6 +5,7 @@ AuthProvider, AccessTokenAuthProvider, ExternalAuthProvider, + CredentialsProvider, DatabricksOAuthProvider, ) @@ -12,6 +13,7 @@ class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" AZURE_OAUTH = "azure-oauth" + TOKEN_FEDERATION = "token-federation" # other supported types (access_token) can be inferred # we can add more types as needed later @@ -29,6 +31,7 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + identity_federation_client_id: Optional[str] 
= None, ): self.hostname = hostname self.access_token = access_token @@ -40,11 +43,44 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider + self.identity_federation_client_id = identity_federation_client_id def get_auth_provider(cfg: ClientContext): if cfg.credentials_provider: + # If token federation is enabled and credentials provider is provided, + # wrap the credentials provider with DatabricksTokenFederationProvider + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value: + from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider + federation_provider = DatabricksTokenFederationProvider( + cfg.credentials_provider, + cfg.hostname, + cfg.identity_federation_client_id + ) + return ExternalAuthProvider(federation_provider) + + # If access token is provided with token federation, create a SimpleCredentialsProvider + elif cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: + from databricks.sql.auth.token_federation import create_token_federation_provider + federation_provider = create_token_federation_provider( + cfg.access_token, + cfg.hostname, + cfg.identity_federation_client_id + ) + return ExternalAuthProvider(federation_provider) + return ExternalAuthProvider(cfg.credentials_provider) + + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: + # If only access_token is provided with token federation, use create_token_federation_provider + from databricks.sql.auth.token_federation import create_token_federation_provider + federation_provider = create_token_federation_provider( + cfg.access_token, + cfg.hostname, + cfg.identity_federation_client_id + ) + return ExternalAuthProvider(federation_provider) + if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None @@ -125,5 +161,6 @@ 
def get_python_sql_connector_auth_provider(hostname: str, **kwargs): else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), + identity_federation_client_id=kwargs.get("identity_federation_client_id"), ) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 64eb91bb0..c425f0888 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -26,10 +26,16 @@ class CredentialsProvider(abc.ABC): @abc.abstractmethod def auth_type(self) -> str: + """ + Returns the authentication type for this provider + """ ... @abc.abstractmethod def __call__(self, *args, **kwargs) -> HeaderFactory: + """ + Configure and return a HeaderFactory that provides authentication headers + """ ... diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py new file mode 100644 index 000000000..c20dd0eb1 --- /dev/null +++ b/src/databricks/sql/auth/token_federation.py @@ -0,0 +1,400 @@ +import base64 +import json +import logging +import urllib.parse +from datetime import datetime, timezone, timedelta +from typing import Dict, Optional, Any, Tuple, List, Union +from urllib.parse import urlparse + +import requests +from requests.exceptions import RequestException + +from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory +from databricks.sql.auth.endpoint import get_databricks_oidc_url, get_oauth_endpoints, infer_cloud_from_host + +logger = logging.getLogger(__name__) + +TOKEN_EXCHANGE_PARAMS = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "scope": "sql", + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "return_original_token_if_authenticated": "true" +} + +# Special client IDs for different IdPs +AZURE_AD_MULTI_TENANT_APP_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d" + +# Buffer 
time in seconds before token expiry to trigger a refresh (5 minutes) +TOKEN_REFRESH_BUFFER_SECONDS = 300 + +class Token: + """Represents an OAuth token with expiry information.""" + + def __init__(self, access_token: str, token_type: str, refresh_token: str = "", expiry: Optional[datetime] = None): + self.access_token = access_token + self.token_type = token_type + self.refresh_token = refresh_token + self.expiry = expiry or datetime.now(tz=timezone.utc) + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return datetime.now(tz=timezone.utc) >= self.expiry + + def needs_refresh(self) -> bool: + """Check if the token needs to be refreshed soon.""" + buffer_time = timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS) + return datetime.now(tz=timezone.utc) >= (self.expiry - buffer_time) + + def __str__(self) -> str: + return f"{self.token_type} {self.access_token}" + + +class DatabricksTokenFederationProvider(CredentialsProvider): + """ + Implementation of the Credential Provider that exchanges the third party access token + for a Databricks InHouse Token. This class exchanges the access token if the issued token + is not from the same host as the Databricks host. + """ + + def __init__(self, credentials_provider: CredentialsProvider, hostname: str, + identity_federation_client_id: Optional[str] = None): + """ + Initialize the token federation provider. 
+ + Args: + credentials_provider: The underlying credentials provider + hostname: The Databricks hostname + identity_federation_client_id: Optional client ID for identity federation + """ + self.credentials_provider = credentials_provider + self.hostname = hostname + self.identity_federation_client_id = identity_federation_client_id + self.external_provider_headers = {} + self.token = None + self.token_endpoint = None + self.idp_endpoints = None + self.openid_config = None + self.last_exchanged_token = None + self.last_external_token = None + + def auth_type(self) -> str: + """Return the auth type from the underlying credentials provider.""" + return self.credentials_provider.auth_type() + + def __call__(self, *args, **kwargs) -> HeaderFactory: + """ + Configure and return a HeaderFactory that provides authentication headers. + + This is called by the ExternalAuthProvider to get headers for authentication. + """ + # First call the underlying credentials provider to get its headers + header_factory = self.credentials_provider(*args, **kwargs) + + # Initialize OIDC discovery + self._init_oidc_discovery() + + def get_headers() -> Dict[str, str]: + # Get headers from the underlying provider + self.external_provider_headers = header_factory() + + # Extract the token from the headers + token_info = self._extract_token_info_from_header(self.external_provider_headers) + token_type, access_token = token_info + + try: + # Check if we need to refresh the token + if (self.last_exchanged_token and self.last_external_token == access_token and + self.last_exchanged_token.needs_refresh()): + # The token is approaching expiry, try to refresh + logger.debug("Exchanged token approaching expiry, refreshing...") + return self._refresh_token(access_token, token_type) + + # Parse the JWT to get claims + token_claims = self._parse_jwt_claims(access_token) + + # Check if token needs to be exchanged + if self._is_same_host(token_claims.get("iss", ""), self.hostname): + # Token is from the 
same host, no need to exchange + return self.external_provider_headers + else: + # Token is from a different host, need to exchange + return self._try_token_exchange_or_fallback(access_token, token_type) + + except Exception as e: + logger.error(f"Failed to process token: {str(e)}") + # Fall back to original headers in case of error + return self.external_provider_headers + + return get_headers + + def _init_oidc_discovery(self): + """Initialize OIDC discovery to find token endpoint.""" + if self.token_endpoint is not None: + return + + try: + # Use the existing OIDC discovery mechanism + use_azure_auth = infer_cloud_from_host(self.hostname) == "azure" + self.idp_endpoints = get_oauth_endpoints(self.hostname, use_azure_auth) + + if self.idp_endpoints: + # Get the OpenID configuration URL + openid_config_url = self.idp_endpoints.get_openid_config_url(self.hostname) + + # Fetch the OpenID configuration + response = requests.get(openid_config_url) + if response.status_code == 200: + self.openid_config = response.json() + # Extract token endpoint from OpenID config + self.token_endpoint = self.openid_config.get("token_endpoint") + logger.info(f"Discovered token endpoint: {self.token_endpoint}") + else: + logger.warning(f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}") + + # Fallback to default token endpoint if discovery fails + if not self.token_endpoint: + self.token_endpoint = f"{self.hostname}oidc/v1/token" + logger.info(f"Using default token endpoint: {self.token_endpoint}") + + except Exception as e: + logger.warning(f"OIDC discovery failed: {str(e)}. 
Using default token endpoint.") + self.token_endpoint = f"{self.hostname}oidc/v1/token" + + def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]: + """Extract token type and token value from authorization header.""" + auth_header = headers.get("Authorization") + if not auth_header: + raise ValueError("No Authorization header found") + + parts = auth_header.split(" ", 1) + if len(parts) != 2: + raise ValueError(f"Invalid Authorization header format: {auth_header}") + + return parts[0], parts[1] + + def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: + """Parse JWT token claims without validation.""" + try: + # Split the token + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT format") + + # Get the payload part (second part) + payload = parts[1] + + # Add padding if needed + padding = '=' * (4 - len(payload) % 4) + payload += padding + + # Decode and parse JSON + decoded = base64.b64decode(payload) + return json.loads(decoded) + except Exception as e: + logger.error(f"Failed to parse JWT: {str(e)}") + raise + + def _detect_idp_from_claims(self, token_claims: Dict[str, Any]) -> str: + """ + Detect the identity provider type from token claims. + + This can be used to adjust token exchange parameters based on the IdP. 
+ """ + issuer = token_claims.get("iss", "") + + if "login.microsoftonline.com" in issuer or "sts.windows.net" in issuer: + return "azure" + elif "token.actions.githubusercontent.com" in issuer: + return "github" + elif "accounts.google.com" in issuer: + return "google" + elif "cognito-idp" in issuer and "amazonaws.com" in issuer: + return "aws" + else: + return "unknown" + + def _is_same_host(self, url1: str, url2: str) -> bool: + """Check if two URLs have the same host.""" + try: + host1 = urlparse(url1).netloc + host2 = urlparse(url2).netloc + # If host1 is empty, it's not a valid URL, so we return False + if not host1: + return False + return host1 == host2 + except Exception as e: + logger.error(f"Failed to parse URLs: {str(e)}") + return False + + def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: + """ + Attempt to refresh an expired token. + + For most OAuth implementations, refreshing involves a new token exchange + with the latest external token. + + Args: + access_token: The original external access token + token_type: The token type (Bearer, etc.) 
+ + Returns: + The headers with the fresh token + """ + try: + logger.info("Refreshing expired token via new token exchange") + # For most federation implementations, refresh is just a new token exchange + token_claims = self._parse_jwt_claims(access_token) + idp_type = self._detect_idp_from_claims(token_claims) + + # Perform a new token exchange + refreshed_token = self._exchange_token(access_token, idp_type) + + # Update the stored token + self.last_exchanged_token = refreshed_token + self.last_external_token = access_token + + # Create new headers with the refreshed token + headers = dict(self.external_provider_headers) + headers["Authorization"] = f"{refreshed_token.token_type} {refreshed_token.access_token}" + return headers + except Exception as e: + logger.error(f"Token refresh failed, falling back to original token: {str(e)}") + # If refresh fails, fall back to the original headers + return self.external_provider_headers + + def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> Dict[str, str]: + """Try to exchange the token or fall back to the original token.""" + try: + # Parse the token to get claims for IdP-specific adjustments + token_claims = self._parse_jwt_claims(access_token) + idp_type = self._detect_idp_from_claims(token_claims) + + # Exchange the token + exchanged_token = self._exchange_token(access_token, idp_type) + + # Store the exchanged token for potential refresh later + self.last_exchanged_token = exchanged_token + self.last_external_token = access_token + + # Create new headers with the exchanged token + headers = dict(self.external_provider_headers) + headers["Authorization"] = f"{exchanged_token.token_type} {exchanged_token.access_token}" + return headers + except Exception as e: + logger.error(f"Token exchange failed, falling back to using external token: {str(e)}") + # Fall back to original headers + return self.external_provider_headers + + def _exchange_token(self, access_token: str, idp_type: str = 
"unknown") -> Token: + """ + Exchange an external token for a Databricks token. + + Args: + access_token: The external token to exchange + idp_type: The detected identity provider type (azure, github, etc.) + + Returns: + A Token object containing the exchanged token + """ + if not self.token_endpoint: + self._init_oidc_discovery() + + # Create request parameters + params = dict(TOKEN_EXCHANGE_PARAMS) + params["subject_token"] = access_token + + # Add client ID if available + if self.identity_federation_client_id: + params["client_id"] = self.identity_federation_client_id + + # Make IdP-specific adjustments + if idp_type == "azure": + # For Azure AD, add special handling if needed + pass + elif idp_type == "github": + # For GitHub Actions, add special handling if needed + pass + + # Set up headers + headers = { + "Accept": "*/*", + "Content-Type": "application/x-www-form-urlencoded" + } + + try: + # Make the token exchange request + response = requests.post(self.token_endpoint, data=params, headers=headers) + response.raise_for_status() + + # Parse the response + resp_data = response.json() + + # Create a token from the response + token = Token( + access_token=resp_data.get("access_token"), + token_type=resp_data.get("token_type", "Bearer"), + refresh_token=resp_data.get("refresh_token", ""), + ) + + # Set expiry time from the response's expires_in field if available + # This is the standard OAuth approach + if "expires_in" in resp_data and resp_data["expires_in"]: + try: + # Calculate expiry by adding expires_in seconds to current time + expires_in_seconds = int(resp_data["expires_in"]) + token.expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in_seconds) + logger.debug(f"Token expiry set from expires_in: {token.expiry}") + except (ValueError, TypeError) as e: + logger.warning(f"Could not parse expires_in from response: {str(e)}") + + # If expires_in wasn't available, try to parse expiry from the token JWT + if token.expiry == 
datetime.now(tz=timezone.utc): + try: + token_claims = self._parse_jwt_claims(token.access_token) + exp_time = token_claims.get("exp") + if exp_time: + token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc) + logger.debug(f"Token expiry set from JWT exp claim: {token.expiry}") + except Exception as e: + logger.warning(f"Could not parse expiry from token: {str(e)}") + + return token + except RequestException as e: + logger.error(f"Failed to perform token exchange: {str(e)}") + raise + + +class SimpleCredentialsProvider(CredentialsProvider): + """A simple credentials provider that returns fixed headers.""" + + def __init__(self, token: str, token_type: str = "Bearer", auth_type_value: str = "token"): + self.token = token + self.token_type = token_type + self._auth_type = auth_type_value + + def auth_type(self) -> str: + return self._auth_type + + def __call__(self, *args, **kwargs) -> HeaderFactory: + def get_headers() -> Dict[str, str]: + return {"Authorization": f"{self.token_type} {self.token}"} + return get_headers + + +def create_token_federation_provider(token: str, hostname: str, + identity_federation_client_id: Optional[str] = None, + token_type: str = "Bearer") -> DatabricksTokenFederationProvider: + """ + Create a token federation provider using a simple token. 
+ + Args: + token: The token to use + hostname: The Databricks hostname + identity_federation_client_id: Optional client ID for identity federation + token_type: The token type (default: "Bearer") + + Returns: + A DatabricksTokenFederationProvider + """ + provider = SimpleCredentialsProvider(token, token_type) + return DatabricksTokenFederationProvider(provider, hostname, identity_federation_client_id) \ No newline at end of file From aedb3bf60ad46da77dc50c143c3524a498626aa1 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 05:46:14 +0000 Subject: [PATCH 02/46] update vars --- .github/workflows/token-federation-test.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 98ce336f6..bdb95753b 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -108,9 +108,9 @@ jobs: sys.exit(1) # Get Databricks connection parameters - host = os.environ.get("DATABRICKS_HOST") - http_path = os.environ.get("DATABRICKS_HTTP_PATH") - identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") + host = os.environ.get("DATABRICKS_HOST_FOR_TF") + http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") + identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID_FOR_TF") if not host or not http_path: print("Missing Databricks connection parameters") @@ -158,8 +158,8 @@ jobs: - name: Test token federation with GitHub OIDC token env: - DATABRICKS_HOST: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST }} - DATABRICKS_HTTP_PATH: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH }} + DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} + DATABRICKS_HTTP_PATH_FOR_TF: ${{ 
github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} run: | From d06672c8af1e535a1b7a4425eb157dc8d216256f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 05:56:18 +0000 Subject: [PATCH 03/46] mod --- .github/workflows/token-federation-test.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index bdb95753b..fda0133f7 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -38,9 +38,16 @@ permissions: jobs: test-token-federation: - runs-on: ubuntu-latest + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest steps: + - name: Debug OIDC Claims + uses: github/actions-oidc-debugger@main + with: + audience: '${{ github.server_url }}/${{ github.repository_owner }}' + - name: Checkout code uses: actions/checkout@v4 From 9aff81123bf2f22cbdd3c62d2c8938c1af3ae083 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 06:16:54 +0000 Subject: [PATCH 04/46] debugging patch --- .github/workflows/token-federation-test.yml | 39 +++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index fda0133f7..655bf6236 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -62,6 +62,45 @@ jobs: pip install -e . 
pip install pyarrow + - name: Create debugging patch script + run: | + cat > patch_for_debugging.py << 'EOF' + #!/usr/bin/env python3 + + def patch_code(): + with open('src/databricks/sql/auth/token_federation.py', 'r') as f: + content = f.read() + + # Add verbose request debugging + modified = content.replace( + 'try:\n # Make the token exchange request', + 'try:\n import urllib.parse\n # Debug full request\n print(f"Token endpoint: {self.token_endpoint}")\n print(f"Request parameters: {urllib.parse.urlencode(params)}")\n print(f"Request headers: {headers}")\n # Make the token exchange request' + ) + + # Add verbose response debugging + modified = modified.replace( + 'response = requests.post(self.token_endpoint, data=params, headers=headers)', + 'response = requests.post(self.token_endpoint, data=params, headers=headers)\n print(f"Response status: {response.status_code}")\n print(f"Response headers: {dict(response.headers)}")\n print(f"Response body: {response.text}")' + ) + + # Improve error handling + modified = modified.replace( + 'except RequestException as e:', + 'except RequestException as e:\n if hasattr(e, "response") and e.response:\n print(f"Error response status: {e.response.status_code}")\n print(f"Error response headers: {dict(e.response.headers)}")\n print(f"Error response text: {e.response.text}")' + ) + + with open('src/databricks/sql/auth/token_federation.py', 'w') as f: + f.write(modified) + + if __name__ == "__main__": + patch_code() + EOF + + chmod +x patch_for_debugging.py + + - name: Apply debugging patches to token_federation.py + run: python patch_for_debugging.py + - name: Get GitHub OIDC token id: get-id-token uses: actions/github-script@v7 From 299b5ae967eb923be4a2d141a2f413ece191d381 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 06:35:50 +0000 Subject: [PATCH 05/46] mod --- .github/workflows/token-federation-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 655bf6236..fc7ee9840 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -106,7 +106,7 @@ jobs: uses: actions/github-script@v7 with: script: | - const token = await core.getIDToken('https://github.com') + const token = await core.getIDToken('https://github.com/databricks') core.setSecret(token) core.setOutput('token', token) From 10a501686b41ffbddd2b8eb25ec201c47174e6de Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 06:52:04 +0000 Subject: [PATCH 06/46] debug --- .github/workflows/token-federation-test.yml | 288 +++++++++++++++++++- 1 file changed, 277 insertions(+), 11 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index fc7ee9840..1ef333816 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -71,10 +71,16 @@ jobs: with open('src/databricks/sql/auth/token_federation.py', 'r') as f: content = f.read() - # Add verbose request debugging + # Add token debugging modified = content.replace( + 'def _exchange_token(self, token, force_refresh=False):', + 'def _exchange_token(self, token, force_refresh=False):\n # Debug token info\n import jwt\n try:\n decoded = jwt.decode(token, options={"verify_signature": False})\n print(f"Token issuer: {decoded.get(\'iss\')}")\n print(f"Token subject: {decoded.get(\'sub\')}")\n print(f"Token audience: {decoded.get(\'aud\') if isinstance(decoded.get(\'aud\'), str) else decoded.get(\'aud\', [])[0] if decoded.get(\'aud\') else \'\'}")\n except Exception as e:\n print(f"Unable to decode token: {str(e)}")' + ) + + # Add verbose request debugging + modified = modified.replace( 'try:\n # Make the token exchange request', - 'try:\n import urllib.parse\n # Debug full request\n print(f"Token endpoint: {self.token_endpoint}")\n print(f"Request 
parameters: {urllib.parse.urlencode(params)}")\n print(f"Request headers: {headers}")\n # Make the token exchange request' + 'try:\n import urllib.parse\n # Debug full request\n print(f"Connecting to Databricks at {self.host}")\n print(f"Token endpoint: {self.token_endpoint}")\n print(f"Request parameters: {urllib.parse.urlencode(params)}")\n print(f"Request headers: {headers}")\n # Make the token exchange request' ) # Add verbose response debugging @@ -86,7 +92,7 @@ jobs: # Improve error handling modified = modified.replace( 'except RequestException as e:', - 'except RequestException as e:\n if hasattr(e, "response") and e.response:\n print(f"Error response status: {e.response.status_code}")\n print(f"Error response headers: {dict(e.response.headers)}")\n print(f"Error response text: {e.response.text}")' + 'except RequestException as e:\n print(f"Failed to perform token exchange: {str(e)}")\n if hasattr(e, "response") and e.response:\n print(f"Error response status: {e.response.status_code}")\n print(f"Error response headers: {dict(e.response.headers)}")\n print(f"Error response text: {e.response.text}")' ) with open('src/databricks/sql/auth/token_federation.py', 'w') as f: @@ -98,9 +104,73 @@ jobs: chmod +x patch_for_debugging.py + - name: Install PyJWT for token debugging + run: pip install pyjwt + - name: Apply debugging patches to token_federation.py run: python patch_for_debugging.py + - name: Create audience fix patch script + run: | + cat > patch_for_audience_fix.py << 'EOF' + #!/usr/bin/env python3 + + def patch_code(): + with open('src/databricks/sql/auth/token_federation.py', 'r') as f: + content = f.read() + + # Fix audience handling + modified = content.replace( + 'def _exchange_token(self, token, force_refresh=False):', + '''def _exchange_token(self, token, force_refresh=False): + # Additional handling for different audience formats + import jwt + try: + # Try both standard and alternative audience formats + audience_tried = False + + def 
try_with_audience(token, audience): + nonlocal audience_tried + if audience_tried: + return None + + audience_tried = True + decoded = jwt.decode(token, options={"verify_signature": False}) + aud = decoded.get("aud") + + # Check if aud is a list and convert to string if needed + if isinstance(aud, list) and len(aud) > 0: + aud = aud[0] + + # Print audience for debugging + print(f"Original token audience: {aud}") + + if aud != audience: + print(f"WARNING: Token audience '{aud}' doesn't match expected audience '{audience}'") + # We won't modify the token as that would invalidate the signature + + return None + + # We're just collecting debugging info, not modifying the token + try_with_audience(token, "https://github.com/databricks") + + except Exception as e: + print(f"Audience debug error: {str(e)}") +''' + ) + + with open('src/databricks/sql/auth/token_federation.py', 'w') as f: + f.write(modified) + + if __name__ == "__main__": + patch_code() + EOF + + chmod +x patch_for_audience_fix.py + + - name: Apply audience fix patches + run: python patch_for_audience_fix.py + - name: Get GitHub OIDC token id: get-id-token uses: actions/github-script@v7 @@ -110,6 +180,106 @@ jobs: core.setSecret(token) core.setOutput('token', token) + - name: Decode and display OIDC token claims + env: + OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} + run: | + echo "Decoding GitHub OIDC token claims..." 
+ python -c ' + import sys, base64, json + + token = """$OIDC_TOKEN""" + + # Parse the token + try: + header, payload, signature = token.split(".") + + # Add padding if needed + payload_padding = payload + "=" * (-len(payload) % 4) + + # Decode the payload + decoded_payload = base64.b64decode(payload_padding).decode("utf-8") + claims = json.loads(decoded_payload) + + # Print important claims + print("\n=== GITHUB OIDC TOKEN CLAIMS ===") + print(f"Issuer (iss): {claims.get(\"iss\")}") + print(f"Subject (sub): {claims.get(\"sub\")}") + print(f"Audience (aud): {claims.get(\"aud\")}") + print(f"Repository: {claims.get(\"repository\")}") + print(f"Repository owner: {claims.get(\"repository_owner\")}") + print(f"Event name: {claims.get(\"event_name\")}") + print(f"Ref: {claims.get(\"ref\")}") + print(f"Workflow ref: {claims.get(\"workflow_ref\")}") + print("\n=== FULL CLAIMS ===") + print(json.dumps(claims, indent=2)) + print("===========================\n") + except Exception as e: + print(f"Failed to decode token: {str(e)}") + ' + + - name: Debug token exchange with curl + env: + DATABRICKS_HOST: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} + IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID_FOR_TF }} + OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} + run: | + echo "Attempting direct token exchange with curl..." 
+ echo "Host: $DATABRICKS_HOST" + echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID" + + # Debug token claims before making the request + echo "Token claims:" + python3 -c " + import base64, json, sys + token = \"$OIDC_TOKEN\" + parts = token.split('.') + if len(parts) >= 2: + padding = '=' * (4 - len(parts[1]) % 4) + decoded_bytes = base64.b64decode(parts[1] + padding) + decoded_str = decoded_bytes.decode('utf-8') + claims = json.loads(decoded_str) + print(f\"Issuer: {claims.get('iss', 'unknown')}\") + print(f\"Subject: {claims.get('sub', 'unknown')}\") + print(f\"Audience: {claims.get('aud', 'unknown')}\") + else: + print('Invalid token format') + " + + # Create a properly URL-encoded request + echo "Creating token exchange request..." + curl_data=$(cat <&1) + + # Extract and display results + echo "Response:" + echo "$response" + + # Extract HTTP status if possible + status_code=$(echo "$response" | grep -o "< HTTP/[0-9.]* [0-9]*" | grep -o "[0-9]*$" || echo "unknown") + echo "HTTP Status Code: $status_code" + + # Don't fail the workflow if curl fails + exit 0 + - name: Create test script run: | cat > test_github_token_federation.py << 'EOF' @@ -127,7 +297,9 @@ jobs: import sys import json import base64 + import requests from databricks import sql + import time def decode_jwt(token): """Decode and return the claims from a JWT token.""" @@ -137,6 +309,7 @@ jobs: raise ValueError("Invalid JWT format") payload = parts[1] + # Add padding if needed padding = '=' * (4 - len(payload) % 4) payload += padding @@ -146,6 +319,55 @@ jobs: print(f"Failed to decode token: {str(e)}") return None + def test_direct_token_exchange(host, token, client_id, audience=None): + """Directly test token exchange with the Databricks API.""" + try: + url = f"https://{host}/oidc/v1/token" + data = { + "client_id": client_id, + "subject_token": token, + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "scope": 
"sql", + "return_original_token_if_authenticated": "true" + } + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + + print(f"Testing direct token exchange with {url}") + print(f"Request parameters: {data}") + + # Add debugging info + claims = decode_jwt(token) + if claims: + print(f"Token issuer: {claims.get('iss', 'unknown')}") + print(f"Token subject: {claims.get('sub', 'unknown')}") + print(f"Token audience: {claims.get('aud', 'unknown')}") + + # If audience was specified in policy but doesn't match token + if audience and audience != claims.get('aud'): + print(f"WARNING: Expected audience '{audience}' doesn't match token audience '{claims.get('aud')}'") + + response = requests.post(url, data=data, headers=headers) + + print(f"Status code: {response.status_code}") + print(f"Response headers: {dict(response.headers)}") + print(f"Response content: {response.text}") + + if response.status_code == 200: + try: + return json.loads(response.text).get("access_token") + except json.JSONDecodeError: + print("Failed to parse response JSON") + return None + return None + except Exception as e: + print(f"Direct token exchange failed: {str(e)}") + return None + def main(): # Get GitHub OIDC token github_token = os.environ.get("OIDC_TOKEN") @@ -164,20 +386,63 @@ jobs: claims = decode_jwt(github_token) if claims: + print("\n=== GitHub OIDC Token Claims ===") print(f"Token issuer: {claims.get('iss', 'unknown')}") print(f"Token subject: {claims.get('sub', 'unknown')}") print(f"Token audience: {claims.get('aud', 'unknown')}") + print(f"Token expiration: {claims.get('exp', 'unknown')}") + print(f"Repository: {claims.get('repository', 'unknown')}") + print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") + print(f"Event name: {claims.get('event_name', 'unknown')}") + print("===============================\n") + + # Try token exchange with several possible audience values + audience_values = [ + 
"https://github.com/databricks", # Standard audience for GitHub tokens + "https://github.com", # Alternative audience + None # No audience + ] + + # Direct token exchange test + access_token = None + for audience in audience_values: + print(f"\n=== Testing Direct Token Exchange (audience={audience}) ===") + result = test_direct_token_exchange(host, github_token, identity_federation_client_id, audience) + if result: + print("Direct token exchange successful!") + access_token = result + token_claims = decode_jwt(result) + if token_claims: + print(f"Databricks token subject: {token_claims.get('sub', 'unknown')}") + break + print(f"Token exchange failed with audience={audience}") + # Add a small delay between attempts + time.sleep(1) + + if not access_token: + print("All token exchange attempts failed") + print("=====================================\n") + else: + print("=====================================\n") try: # Connect to Databricks using token federation + print(f"\n=== Testing Connection via Connector ===") print(f"Connecting to Databricks at {host}{http_path}") - with sql.connect( - server_hostname=host, - http_path=http_path, - access_token=github_token, - auth_type="token-federation", - identity_federation_client_id=identity_federation_client_id - ) as connection: + print(f"Using client ID: {identity_federation_client_id}") + + connection_params = { + "server_hostname": host, + "http_path": http_path, + "access_token": github_token, + "auth_type": "token-federation", + "identity_federation_client_id": identity_federation_client_id, + } + + print("Connection parameters:") + print(json.dumps({k: v if k != 'access_token' else '***' for k, v in connection_params.items()}, indent=2)) + + with sql.connect(**connection_params) as connection: print("Connection established successfully") # Execute a simple query @@ -195,6 +460,7 @@ jobs: return True except Exception as e: print(f"Error connecting to Databricks: {str(e)}") + 
print("===================================\n") sys.exit(1) if __name__ == "__main__": @@ -206,7 +472,7 @@ jobs: env: DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} - IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} + IDENTITY_FEDERATION_CLIENT_ID_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID_FOR_TF }} OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} run: | python test_github_token_federation.py From 3bb9b3dcaf0b14ee0d3b034dd28cb0a9ec7da8b2 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 06:57:35 +0000 Subject: [PATCH 07/46] debug --- .github/workflows/token-federation-test.yml | 74 ++++++--------------- 1 file changed, 21 insertions(+), 53 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 1ef333816..8c17afe33 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -122,41 +122,7 @@ jobs: # Fix audience handling modified = content.replace( 'def _exchange_token(self, token, force_refresh=False):', - '''def _exchange_token(self, token, force_refresh=False): - # Additional handling for different audience formats - import jwt - try: - # Try both standard and alternative audience formats - audience_tried = False - - def try_with_audience(token, audience): - nonlocal audience_tried - if audience_tried: - return None - - audience_tried = True - decoded = jwt.decode(token, options={"verify_signature": False}) - aud = decoded.get("aud") - - # Check if aud is a list and convert to string if needed - if 
isinstance(aud, list) and len(aud) > 0: - aud = aud[0] - - # Print audience for debugging - print(f"Original token audience: {aud}") - - if aud != audience: - print(f"WARNING: Token audience '{aud}' doesn't match expected audience '{audience}'") - # We won't modify the token as that would invalidate the signature - - return None - - # We're just collecting debugging info, not modifying the token - try_with_audience(token, "https://github.com/databricks") - - except Exception as e: - print(f"Audience debug error: {str(e)}") -''' + 'def _exchange_token(self, token, force_refresh=False):\\n # Additional handling for different audience formats\\n import jwt\\n try:\\n # Try both standard and alternative audience formats\\n audience_tried = False\\n \\n def try_with_audience(token, audience):\\n nonlocal audience_tried\\n if audience_tried:\\n return None\\n \\n audience_tried = True\\n decoded = jwt.decode(token, options={\"verify_signature\": False})\\n aud = decoded.get(\"aud\")\\n \\n # Check if aud is a list and convert to string if needed\\n if isinstance(aud, list) and len(aud) > 0:\\n aud = aud[0]\\n \\n # Print audience for debugging\\n print(f\"Original token audience: {aud}\")\\n \\n if aud != audience:\\n print(f\"WARNING: Token audience \\\'{aud}\\\' doesn\\\'t match expected audience \\\'{audience}\\\'\")\\n # We won\\\'t modify the token as that would invalidate the signature\\n \\n return None\\n \\n # We\\\'re just collecting debugging info, not modifying the token\\n try_with_audience(token, \"https://github.com/databricks\")\\n \\n except Exception as e:\\n print(f\"Audience debug error: {str(e)}\")' ) with open('src/databricks/sql/auth/token_federation.py', 'w') as f: @@ -233,17 +199,17 @@ jobs: python3 -c " import base64, json, sys token = \"$OIDC_TOKEN\" - parts = token.split('.') + parts = token.split(\".\") if len(parts) >= 2: - padding = '=' * (4 - len(parts[1]) % 4) + padding = \"=\" * (4 - len(parts[1]) % 4) decoded_bytes = 
base64.b64decode(parts[1] + padding) - decoded_str = decoded_bytes.decode('utf-8') + decoded_str = decoded_bytes.decode(\"utf-8\") claims = json.loads(decoded_str) - print(f\"Issuer: {claims.get('iss', 'unknown')}\") - print(f\"Subject: {claims.get('sub', 'unknown')}\") - print(f\"Audience: {claims.get('aud', 'unknown')}\") + print(f\"Token issuer: {claims.get('iss', 'unknown')}\") + print(f\"Token subject: {claims.get('sub', 'unknown')}\") + print(f\"Token audience: {claims.get('aud', 'unknown')}\") else: - print('Invalid token format') + print(\"Invalid token format\") " # Create a properly URL-encoded request @@ -343,13 +309,15 @@ EOF # Add debugging info claims = decode_jwt(token) if claims: - print(f"Token issuer: {claims.get('iss', 'unknown')}") - print(f"Token subject: {claims.get('sub', 'unknown')}") - print(f"Token audience: {claims.get('aud', 'unknown')}") + print(f"Token issuer: {claims.get(\'iss\', \'unknown\')}") + print(f"Token subject: {claims.get(\'sub\', \'unknown\')}") + print(f"Token audience: {claims.get(\'aud\', \'unknown\')}") # If audience was specified in policy but doesn't match token if audience and audience != claims.get('aud'): - print(f"WARNING: Expected audience '{audience}' doesn't match token audience '{claims.get('aud')}'") + print("WARNING: Expected audience and token audience don't match") + print(f"Expected: {audience}") + print(f"Actual: {claims.get('aud')}") response = requests.post(url, data=data, headers=headers) @@ -387,13 +355,13 @@ EOF claims = decode_jwt(github_token) if claims: print("\n=== GitHub OIDC Token Claims ===") - print(f"Token issuer: {claims.get('iss', 'unknown')}") - print(f"Token subject: {claims.get('sub', 'unknown')}") - print(f"Token audience: {claims.get('aud', 'unknown')}") - print(f"Token expiration: {claims.get('exp', 'unknown')}") - print(f"Repository: {claims.get('repository', 'unknown')}") - print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") - print(f"Event name: 
{claims.get('event_name', 'unknown')}") + print(f"Token issuer: {claims.get(\'iss\', \'unknown\')}") + print(f"Token subject: {claims.get(\'sub\', \'unknown\')}") + print(f"Token audience: {claims.get(\'aud\', \'unknown\')}") + print(f"Token expiration: {claims.get(\'exp\', \'unknown\')}") + print(f"Repository: {claims.get(\'repository\', \'unknown\')}") + print(f"Workflow ref: {claims.get(\'workflow_ref\', \'unknown\')}") + print(f"Event name: {claims.get(\'event_name\', \'unknown\')}") print("===============================\n") # Try token exchange with several possible audience values From 708c13bfc23ff0d146acf064dce2a033e47fee17 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 06:58:58 +0000 Subject: [PATCH 08/46] debug --- .github/workflows/token-federation-test.yml | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 8c17afe33..a7432c924 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -214,15 +214,18 @@ jobs: # Create a properly URL-encoded request echo "Creating token exchange request..." 
- curl_data=$(cat < Date: Wed, 7 May 2025 10:04:06 +0000 Subject: [PATCH 09/46] fix --- .github/workflows/token-federation-test.yml | 38 ++++++++++----------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index a7432c924..a42b0f463 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -169,14 +169,14 @@ jobs: # Print important claims print("\n=== GITHUB OIDC TOKEN CLAIMS ===") - print(f"Issuer (iss): {claims.get(\"iss\")}") - print(f"Subject (sub): {claims.get(\"sub\")}") - print(f"Audience (aud): {claims.get(\"aud\")}") - print(f"Repository: {claims.get(\"repository\")}") - print(f"Repository owner: {claims.get(\"repository_owner\")}") - print(f"Event name: {claims.get(\"event_name\")}") - print(f"Ref: {claims.get(\"ref\")}") - print(f"Workflow ref: {claims.get(\"workflow_ref\")}") + print(f"Issuer (iss): {claims.get('iss')}") + print(f"Subject (sub): {claims.get('sub')}") + print(f"Audience (aud): {claims.get('aud')}") + print(f"Repository: {claims.get('repository')}") + print(f"Repository owner: {claims.get('repository_owner')}") + print(f"Event name: {claims.get('event_name')}") + print(f"Ref: {claims.get('ref')}") + print(f"Workflow ref: {claims.get('workflow_ref')}") print("\n=== FULL CLAIMS ===") print(json.dumps(claims, indent=2)) print("===========================\n") @@ -312,9 +312,9 @@ jobs: # Add debugging info claims = decode_jwt(token) if claims: - print(f"Token issuer: {claims.get(\'iss\', \'unknown\')}") - print(f"Token subject: {claims.get(\'sub\', \'unknown\')}") - print(f"Token audience: {claims.get(\'aud\', \'unknown\')}") + print(f"Token issuer: {claims.get('iss', 'unknown')}") + print(f"Token subject: {claims.get('sub', 'unknown')}") + print(f"Token audience: {claims.get('aud', 'unknown')}") # If audience was specified in policy but doesn't match token if audience and audience != 
claims.get('aud'): @@ -358,13 +358,13 @@ jobs: claims = decode_jwt(github_token) if claims: print("\n=== GitHub OIDC Token Claims ===") - print(f"Token issuer: {claims.get(\'iss\', \'unknown\')}") - print(f"Token subject: {claims.get(\'sub\', \'unknown\')}") - print(f"Token audience: {claims.get(\'aud\', \'unknown\')}") - print(f"Token expiration: {claims.get(\'exp\', \'unknown\')}") - print(f"Repository: {claims.get(\'repository\', \'unknown\')}") - print(f"Workflow ref: {claims.get(\'workflow_ref\', \'unknown\')}") - print(f"Event name: {claims.get(\'event_name\', \'unknown\')}") + print(f"Token issuer: {claims.get('iss')}") + print(f"Token subject: {claims.get('sub')}") + print(f"Token audience: {claims.get('aud')}") + print(f"Token expiration: {claims.get('exp', 'unknown')}") + print(f"Repository: {claims.get('repository', 'unknown')}") + print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") + print(f"Event name: {claims.get('event_name', 'unknown')}") print("===============================\n") # Try token exchange with several possible audience values @@ -443,7 +443,7 @@ jobs: env: DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} - IDENTITY_FEDERATION_CLIENT_ID_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID_FOR_TF }} + IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} run: | python test_github_token_federation.py From 00e015c30de8859486fd53eba6a1319ec92031bf Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 10:10:03 +0000 Subject: [PATCH 10/46] fix --- 
.github/workflows/token-federation-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index a42b0f463..dd7acd65b 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -122,7 +122,7 @@ jobs: # Fix audience handling modified = content.replace( 'def _exchange_token(self, token, force_refresh=False):', - 'def _exchange_token(self, token, force_refresh=False):\\n # Additional handling for different audience formats\\n import jwt\\n try:\\n # Try both standard and alternative audience formats\\n audience_tried = False\\n \\n def try_with_audience(token, audience):\\n nonlocal audience_tried\\n if audience_tried:\\n return None\\n \\n audience_tried = True\\n decoded = jwt.decode(token, options={\"verify_signature\": False})\\n aud = decoded.get(\"aud\")\\n \\n # Check if aud is a list and convert to string if needed\\n if isinstance(aud, list) and len(aud) > 0:\\n aud = aud[0]\\n \\n # Print audience for debugging\\n print(f\"Original token audience: {aud}\")\\n \\n if aud != audience:\\n print(f\"WARNING: Token audience \\\'{aud}\\\' doesn\\\'t match expected audience \\\'{audience}\\\'\")\\n # We won\\\'t modify the token as that would invalidate the signature\\n \\n return None\\n \\n # We\\\'re just collecting debugging info, not modifying the token\\n try_with_audience(token, \"https://github.com/databricks\")\\n \\n except Exception as e:\\n print(f\"Audience debug error: {str(e)}\")' + 'def _exchange_token(self, token, force_refresh=False):\\n # Additional handling for different audience formats\\n import jwt\\n try:\\n # Try both standard and alternative audience formats\\n audience_tried = False\\n \\n def try_with_audience(token, audience):\\n nonlocal audience_tried\\n if audience_tried:\\n return None\\n \\n audience_tried = True\\n decoded = jwt.decode(token, options={\"verify_signature\": 
False})\\n aud = decoded.get(\"aud\")\\n \\n # Check if aud is a list and convert to string if needed\\n if isinstance(aud, list) and len(aud) > 0:\\n aud = aud[0]\\n \\n # Print audience for debugging\\n print(f\"Original token audience: {aud}\")\\n \\n if aud != audience:\\n print(f\"WARNING: Token audience \'{aud}\' doesn\'t match expected audience \'{audience}\'\")\\n # We won\'t modify the token as that would invalidate the signature\\n \\n return None\\n \\n # We\'re just collecting debugging info, not modifying the token\\n try_with_audience(token, \"https://github.com/databricks\")\\n \\n except Exception as e:\\n print(f\"Audience debug error: {str(e)}\")' ) with open('src/databricks/sql/auth/token_federation.py', 'w') as f: From d538b750a57eef0b6afaf99d84e8ed085ff46c6b Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 10:13:39 +0000 Subject: [PATCH 11/46] fix --- .github/workflows/token-federation-test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index dd7acd65b..de48e25b8 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -186,8 +186,8 @@ jobs: - name: Debug token exchange with curl env: - DATABRICKS_HOST: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} - IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID_FOR_TF }} + DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} + IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} run: | echo "Attempting direct token 
exchange with curl..." @@ -232,7 +232,7 @@ jobs: # Make the request with detailed info echo "Sending request..." - response=$(curl -v -s -X POST "https://$DATABRICKS_HOST/oidc/v1/token" \ + response=$(curl -v -s -X POST "https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" \ --data-raw "$curl_data" \ -H "Content-Type: application/x-www-form-urlencoded" \ -H "Accept: application/json" \ From 4b48ac93401a8b82b6f911b814c9e66ad7f512c8 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 7 May 2025 10:25:10 +0000 Subject: [PATCH 12/46] fix --- .github/workflows/token-federation-test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index de48e25b8..e302dcf14 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -191,7 +191,7 @@ jobs: OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} run: | echo "Attempting direct token exchange with curl..." 
- echo "Host: $DATABRICKS_HOST" + echo "Host: $DATABRICKS_HOST_FOR_TF" echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID" # Debug token claims before making the request @@ -227,7 +227,7 @@ jobs: curl_data=$(eval echo "$curl_data") # Print request details (except the token) - echo "Request URL: https://$DATABRICKS_HOST/oidc/v1/token" + echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" echo "Request data: $(echo "$curl_data" | sed 's/subject_token=.*&/subject_token=REDACTED&/')" # Make the request with detailed info @@ -349,7 +349,7 @@ jobs: # Get Databricks connection parameters host = os.environ.get("DATABRICKS_HOST_FOR_TF") http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") - identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID_FOR_TF") + identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") if not host or not http_path: print("Missing Databricks connection parameters") From e8d4a483eea2eb9f3c4fac0e2c900191ada07c3f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 07:57:59 +0000 Subject: [PATCH 13/46] debug --- .github/workflows/token-federation-test.yml | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index e302dcf14..a9cdbd1d4 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -214,17 +214,9 @@ jobs: # Create a properly URL-encoded request echo "Creating token exchange request..." 
- curl_data=$(cat << 'EOF' - client_id=$IDENTITY_FEDERATION_CLIENT_ID&\ - subject_token=$OIDC_TOKEN&\ - subject_token_type=urn:ietf:params:oauth:token-type:jwt&\ - grant_type=urn:ietf:params:oauth:grant-type:token-exchange&\ - scope=sql - EOF - ) - - # Substitute environment variables in the curl data - curl_data=$(eval echo "$curl_data") + # URL encode the token + encoded_token=$(echo -n "$OIDC_TOKEN" | python3 -c 'import sys, urllib.parse; print(urllib.parse.quote(sys.stdin.read(), safe=""))') + curl_data="client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=$encoded_token&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql" # Print request details (except the token) echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" From 5b74b60f58a3086eea9f6348dd2a2358c8561a68 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 08:02:40 +0000 Subject: [PATCH 14/46] debug --- .github/workflows/token-federation-test.yml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index a9cdbd1d4..10e059b28 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -214,18 +214,19 @@ jobs: # Create a properly URL-encoded request echo "Creating token exchange request..." 
- # URL encode the token - encoded_token=$(echo -n "$OIDC_TOKEN" | python3 -c 'import sys, urllib.parse; print(urllib.parse.quote(sys.stdin.read(), safe=""))') - curl_data="client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=$encoded_token&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql" # Print request details (except the token) echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" - echo "Request data: $(echo "$curl_data" | sed 's/subject_token=.*&/subject_token=REDACTED&/')" + echo "Request data: client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=REDACTED&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql" # Make the request with detailed info echo "Sending request..." response=$(curl -v -s -X POST "https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" \ - --data-raw "$curl_data" \ + --data-urlencode "client_id=$IDENTITY_FEDERATION_CLIENT_ID" \ + --data-urlencode "subject_token=$OIDC_TOKEN" \ + --data-urlencode "subject_token_type=urn:ietf:params:oauth:token-type:jwt" \ + --data-urlencode "grant_type=urn:ietf:params:oauth:grant-type:token-exchange" \ + --data-urlencode "scope=sql" \ -H "Content-Type: application/x-www-form-urlencoded" \ -H "Accept: application/json" \ 2>&1) From edc6027008bbf4bbd9e878f2de6823b23364eac9 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 08:33:15 +0000 Subject: [PATCH 15/46] debug --- .github/workflows/token-federation-test.yml | 38 +++++++++++++++++---- src/databricks/sql/auth/token_federation.py | 17 +++++++-- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 10e059b28..84029f60f 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -214,19 +214,26 @@ jobs: # Create a properly URL-encoded 
request echo "Creating token exchange request..." + curl_data=$(cat << 'EOF' + client_id=$IDENTITY_FEDERATION_CLIENT_ID&\ + subject_token=$OIDC_TOKEN&\ + subject_token_type=urn:ietf:params:oauth:token-type:jwt&\ + grant_type=urn:ietf:params:oauth:grant-type:token-exchange&\ + scope=sql + EOF + ) + + # Substitute environment variables in the curl data + curl_data=$(eval echo "$curl_data") # Print request details (except the token) echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" - echo "Request data: client_id=$IDENTITY_FEDERATION_CLIENT_ID&subject_token=REDACTED&subject_token_type=urn:ietf:params:oauth:token-type:jwt&grant_type=urn:ietf:params:oauth:grant-type:token-exchange&scope=sql" + echo "Request data: $(echo "$curl_data" | sed 's/subject_token=.*&/subject_token=REDACTED&/')" # Make the request with detailed info echo "Sending request..." response=$(curl -v -s -X POST "https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" \ - --data-urlencode "client_id=$IDENTITY_FEDERATION_CLIENT_ID" \ - --data-urlencode "subject_token=$OIDC_TOKEN" \ - --data-urlencode "subject_token_type=urn:ietf:params:oauth:token-type:jwt" \ - --data-urlencode "grant_type=urn:ietf:params:oauth:grant-type:token-exchange" \ - --data-urlencode "scope=sql" \ + --data-raw "$curl_data" \ -H "Content-Type: application/x-www-form-urlencoded" \ -H "Accept: application/json" \ 2>&1) @@ -239,6 +246,13 @@ jobs: status_code=$(echo "$response" | grep -o "< HTTP/[0-9.]* [0-9]*" | grep -o "[0-9]*$" || echo "unknown") echo "HTTP Status Code: $status_code" + # Try to extract and pretty-print the JSON response body if present + response_body=$(echo "$response" | sed -n -e '/^{/,/^}/p' || echo "") + if [ ! 
-z "$response_body" ]; then + echo "Response body (formatted):" + echo "$response_body" | python3 -m json.tool || echo "$response_body" + fi + # Don't fail the workflow if curl fails exit 0 @@ -315,6 +329,18 @@ jobs: print(f"Expected: {audience}") print(f"Actual: {claims.get('aud')}") + # Enable more verbose HTTP debugging + import http.client as http_client + http_client.HTTPConnection.debuglevel = 1 + + # Log requests library debug info + import logging + logging.basicConfig() + logging.getLogger().setLevel(logging.DEBUG) + requests_log = logging.getLogger("requests.packages.urllib3") + requests_log.setLevel(logging.DEBUG) + requests_log.propagate = True + response = requests.post(url, data=data, headers=headers) print(f"Status code: {response.status_code}") diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index c20dd0eb1..45fadcb1b 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -153,12 +153,25 @@ def _init_oidc_discovery(self): # Fallback to default token endpoint if discovery fails if not self.token_endpoint: - self.token_endpoint = f"{self.hostname}oidc/v1/token" + # Make sure hostname has proper format with https:// prefix and trailing slash + hostname = self.hostname + if not hostname.startswith('https://'): + hostname = f'https://{hostname}' + if not hostname.endswith('/'): + hostname = f'{hostname}/' + self.token_endpoint = f"{hostname}oidc/v1/token" logger.info(f"Using default token endpoint: {self.token_endpoint}") except Exception as e: logger.warning(f"OIDC discovery failed: {str(e)}. 
Using default token endpoint.") - self.token_endpoint = f"{self.hostname}oidc/v1/token" + # Make sure hostname has proper format with https:// prefix and trailing slash + hostname = self.hostname + if not hostname.startswith('https://'): + hostname = f'https://{hostname}' + if not hostname.endswith('/'): + hostname = f'{hostname}/' + self.token_endpoint = f"{hostname}oidc/v1/token" + logger.info(f"Using default token endpoint after error: {self.token_endpoint}") def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]: """Extract token type and token value from authorization header.""" From 3613cb07a16efe7b7308abdc26c8ffa7c97fe544 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 08:39:53 +0000 Subject: [PATCH 16/46] debug --- src/databricks/sql/auth/token_federation.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 45fadcb1b..4d95ec6b3 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -81,6 +81,16 @@ def auth_type(self) -> str: """Return the auth type from the underlying credentials provider.""" return self.credentials_provider.auth_type() + @property + def host(self) -> str: + """ + Alias for hostname to maintain compatibility with code expecting a host attribute. + + Returns: + str: The hostname value + """ + return self.hostname + def __call__(self, *args, **kwargs) -> HeaderFactory: """ Configure and return a HeaderFactory that provides authentication headers. 
From e87b52d32ceb0be2da00f9c40f4b7e089030c3e5 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 08:53:46 +0000 Subject: [PATCH 17/46] readability --- .github/workflows/token-federation-test.yml | 284 +------------------- 1 file changed, 3 insertions(+), 281 deletions(-) diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 84029f60f..3b5fbaf4a 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -62,81 +62,6 @@ jobs: pip install -e . pip install pyarrow - - name: Create debugging patch script - run: | - cat > patch_for_debugging.py << 'EOF' - #!/usr/bin/env python3 - - def patch_code(): - with open('src/databricks/sql/auth/token_federation.py', 'r') as f: - content = f.read() - - # Add token debugging - modified = content.replace( - 'def _exchange_token(self, token, force_refresh=False):', - 'def _exchange_token(self, token, force_refresh=False):\n # Debug token info\n import jwt\n try:\n decoded = jwt.decode(token, options={"verify_signature": False})\n print(f"Token issuer: {decoded.get(\'iss\')}")\n print(f"Token subject: {decoded.get(\'sub\')}")\n print(f"Token audience: {decoded.get(\'aud\') if isinstance(decoded.get(\'aud\'), str) else decoded.get(\'aud\', [])[0] if decoded.get(\'aud\') else \'\'}")\n except Exception as e:\n print(f"Unable to decode token: {str(e)}")' - ) - - # Add verbose request debugging - modified = modified.replace( - 'try:\n # Make the token exchange request', - 'try:\n import urllib.parse\n # Debug full request\n print(f"Connecting to Databricks at {self.host}")\n print(f"Token endpoint: {self.token_endpoint}")\n print(f"Request parameters: {urllib.parse.urlencode(params)}")\n print(f"Request headers: {headers}")\n # Make the token exchange request' - ) - - # Add verbose response debugging - modified = modified.replace( - 'response = requests.post(self.token_endpoint, data=params, headers=headers)', - 
'response = requests.post(self.token_endpoint, data=params, headers=headers)\n print(f"Response status: {response.status_code}")\n print(f"Response headers: {dict(response.headers)}")\n print(f"Response body: {response.text}")' - ) - - # Improve error handling - modified = modified.replace( - 'except RequestException as e:', - 'except RequestException as e:\n print(f"Failed to perform token exchange: {str(e)}")\n if hasattr(e, "response") and e.response:\n print(f"Error response status: {e.response.status_code}")\n print(f"Error response headers: {dict(e.response.headers)}")\n print(f"Error response text: {e.response.text}")' - ) - - with open('src/databricks/sql/auth/token_federation.py', 'w') as f: - f.write(modified) - - if __name__ == "__main__": - patch_code() - EOF - - chmod +x patch_for_debugging.py - - - name: Install PyJWT for token debugging - run: pip install pyjwt - - - name: Apply debugging patches to token_federation.py - run: python patch_for_debugging.py - - - name: Create audience fix patch script - run: | - cat > patch_for_audience_fix.py << 'EOF' - #!/usr/bin/env python3 - - def patch_code(): - with open('src/databricks/sql/auth/token_federation.py', 'r') as f: - content = f.read() - - # Fix audience handling - modified = content.replace( - 'def _exchange_token(self, token, force_refresh=False):', - 'def _exchange_token(self, token, force_refresh=False):\\n # Additional handling for different audience formats\\n import jwt\\n try:\\n # Try both standard and alternative audience formats\\n audience_tried = False\\n \\n def try_with_audience(token, audience):\\n nonlocal audience_tried\\n if audience_tried:\\n return None\\n \\n audience_tried = True\\n decoded = jwt.decode(token, options={\"verify_signature\": False})\\n aud = decoded.get(\"aud\")\\n \\n # Check if aud is a list and convert to string if needed\\n if isinstance(aud, list) and len(aud) > 0:\\n aud = aud[0]\\n \\n # Print audience for debugging\\n print(f\"Original token audience: 
{aud}\")\\n \\n if aud != audience:\\n print(f\"WARNING: Token audience \'{aud}\' doesn\'t match expected audience \'{audience}\'\")\\n # We won\'t modify the token as that would invalidate the signature\\n \\n return None\\n \\n # We\'re just collecting debugging info, not modifying the token\\n try_with_audience(token, \"https://github.com/databricks\")\\n \\n except Exception as e:\\n print(f\"Audience debug error: {str(e)}\")' - ) - - with open('src/databricks/sql/auth/token_federation.py', 'w') as f: - f.write(modified) - - if __name__ == "__main__": - patch_code() - EOF - - chmod +x patch_for_audience_fix.py - - - name: Apply audience fix patches - run: python patch_for_audience_fix.py - - name: Get GitHub OIDC token id: get-id-token uses: actions/github-script@v7 @@ -146,116 +71,6 @@ jobs: core.setSecret(token) core.setOutput('token', token) - - name: Decode and display OIDC token claims - env: - OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} - run: | - echo "Decoding GitHub OIDC token claims..." 
- python -c ' - import sys, base64, json - - token = """$OIDC_TOKEN""" - - # Parse the token - try: - header, payload, signature = token.split(".") - - # Add padding if needed - payload_padding = payload + "=" * (-len(payload) % 4) - - # Decode the payload - decoded_payload = base64.b64decode(payload_padding).decode("utf-8") - claims = json.loads(decoded_payload) - - # Print important claims - print("\n=== GITHUB OIDC TOKEN CLAIMS ===") - print(f"Issuer (iss): {claims.get('iss')}") - print(f"Subject (sub): {claims.get('sub')}") - print(f"Audience (aud): {claims.get('aud')}") - print(f"Repository: {claims.get('repository')}") - print(f"Repository owner: {claims.get('repository_owner')}") - print(f"Event name: {claims.get('event_name')}") - print(f"Ref: {claims.get('ref')}") - print(f"Workflow ref: {claims.get('workflow_ref')}") - print("\n=== FULL CLAIMS ===") - print(json.dumps(claims, indent=2)) - print("===========================\n") - except Exception as e: - print(f"Failed to decode token: {str(e)}") - ' - - - name: Debug token exchange with curl - env: - DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} - IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} - OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} - run: | - echo "Attempting direct token exchange with curl..." 
- echo "Host: $DATABRICKS_HOST_FOR_TF" - echo "Client ID: $IDENTITY_FEDERATION_CLIENT_ID" - - # Debug token claims before making the request - echo "Token claims:" - python3 -c " - import base64, json, sys - token = \"$OIDC_TOKEN\" - parts = token.split(\".\") - if len(parts) >= 2: - padding = \"=\" * (4 - len(parts[1]) % 4) - decoded_bytes = base64.b64decode(parts[1] + padding) - decoded_str = decoded_bytes.decode(\"utf-8\") - claims = json.loads(decoded_str) - print(f\"Token issuer: {claims.get('iss', 'unknown')}\") - print(f\"Token subject: {claims.get('sub', 'unknown')}\") - print(f\"Token audience: {claims.get('aud', 'unknown')}\") - else: - print(\"Invalid token format\") - " - - # Create a properly URL-encoded request - echo "Creating token exchange request..." - curl_data=$(cat << 'EOF' - client_id=$IDENTITY_FEDERATION_CLIENT_ID&\ - subject_token=$OIDC_TOKEN&\ - subject_token_type=urn:ietf:params:oauth:token-type:jwt&\ - grant_type=urn:ietf:params:oauth:grant-type:token-exchange&\ - scope=sql - EOF - ) - - # Substitute environment variables in the curl data - curl_data=$(eval echo "$curl_data") - - # Print request details (except the token) - echo "Request URL: https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" - echo "Request data: $(echo "$curl_data" | sed 's/subject_token=.*&/subject_token=REDACTED&/')" - - # Make the request with detailed info - echo "Sending request..." 
- response=$(curl -v -s -X POST "https://$DATABRICKS_HOST_FOR_TF/oidc/v1/token" \ - --data-raw "$curl_data" \ - -H "Content-Type: application/x-www-form-urlencoded" \ - -H "Accept: application/json" \ - 2>&1) - - # Extract and display results - echo "Response:" - echo "$response" - - # Extract HTTP status if possible - status_code=$(echo "$response" | grep -o "< HTTP/[0-9.]* [0-9]*" | grep -o "[0-9]*$" || echo "unknown") - echo "HTTP Status Code: $status_code" - - # Try to extract and pretty-print the JSON response body if present - response_body=$(echo "$response" | sed -n -e '/^{/,/^}/p' || echo "") - if [ ! -z "$response_body" ]; then - echo "Response body (formatted):" - echo "$response_body" | python3 -m json.tool || echo "$response_body" - fi - - # Don't fail the workflow if curl fails - exit 0 - - name: Create test script run: | cat > test_github_token_federation.py << 'EOF' @@ -264,7 +79,7 @@ jobs: """ Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. - This script demonstrates how to use the Databricks SQL connector with token federation + This script tests the Databricks SQL connector with token federation using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse, runs a simple query, and shows the connected user. 
""" @@ -273,9 +88,7 @@ jobs: import sys import json import base64 - import requests from databricks import sql - import time def decode_jwt(token): """Decode and return the claims from a JWT token.""" @@ -295,69 +108,6 @@ jobs: print(f"Failed to decode token: {str(e)}") return None - def test_direct_token_exchange(host, token, client_id, audience=None): - """Directly test token exchange with the Databricks API.""" - try: - url = f"https://{host}/oidc/v1/token" - data = { - "client_id": client_id, - "subject_token": token, - "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "scope": "sql", - "return_original_token_if_authenticated": "true" - } - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json" - } - - print(f"Testing direct token exchange with {url}") - print(f"Request parameters: {data}") - - # Add debugging info - claims = decode_jwt(token) - if claims: - print(f"Token issuer: {claims.get('iss', 'unknown')}") - print(f"Token subject: {claims.get('sub', 'unknown')}") - print(f"Token audience: {claims.get('aud', 'unknown')}") - - # If audience was specified in policy but doesn't match token - if audience and audience != claims.get('aud'): - print("WARNING: Expected audience and token audience don't match") - print(f"Expected: {audience}") - print(f"Actual: {claims.get('aud')}") - - # Enable more verbose HTTP debugging - import http.client as http_client - http_client.HTTPConnection.debuglevel = 1 - - # Log requests library debug info - import logging - logging.basicConfig() - logging.getLogger().setLevel(logging.DEBUG) - requests_log = logging.getLogger("requests.packages.urllib3") - requests_log.setLevel(logging.DEBUG) - requests_log.propagate = True - - response = requests.post(url, data=data, headers=headers) - - print(f"Status code: {response.status_code}") - print(f"Response headers: {dict(response.headers)}") - print(f"Response 
content: {response.text}") - - if response.status_code == 200: - try: - return json.loads(response.text).get("access_token") - except json.JSONDecodeError: - print("Failed to parse response JSON") - return None - return None - except Exception as e: - print(f"Direct token exchange failed: {str(e)}") - return None - def main(): # Get GitHub OIDC token github_token = os.environ.get("OIDC_TOKEN") @@ -374,6 +124,7 @@ jobs: print("Missing Databricks connection parameters") sys.exit(1) + # Display token claims for debugging claims = decode_jwt(github_token) if claims: print("\n=== GitHub OIDC Token Claims ===") @@ -386,38 +137,9 @@ jobs: print(f"Event name: {claims.get('event_name', 'unknown')}") print("===============================\n") - # Try token exchange with several possible audience values - audience_values = [ - "https://github.com/databricks", # Standard audience for GitHub tokens - "https://github.com", # Alternative audience - None # No audience - ] - - # Direct token exchange test - access_token = None - for audience in audience_values: - print(f"\n=== Testing Direct Token Exchange (audience={audience}) ===") - result = test_direct_token_exchange(host, github_token, identity_federation_client_id, audience) - if result: - print("Direct token exchange successful!") - access_token = result - token_claims = decode_jwt(result) - if token_claims: - print(f"Databricks token subject: {token_claims.get('sub', 'unknown')}") - break - print(f"Token exchange failed with audience={audience}") - # Add a small delay between attempts - time.sleep(1) - - if not access_token: - print("All token exchange attempts failed") - print("=====================================\n") - else: - print("=====================================\n") - try: # Connect to Databricks using token federation - print(f"\n=== Testing Connection via Connector ===") + print(f"=== Testing Connection via Connector ===") print(f"Connecting to Databricks at {host}{http_path}") print(f"Using client ID: 
{identity_federation_client_id}") From 929191bdc5972f1a1e752d90a1dbb2b697441ccd Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 09:22:31 +0000 Subject: [PATCH 18/46] separate py script --- .github/workflows/token-federation-test.yml | 117 +------------------- tests/token_federation/github_oidc_test.py | 103 +++++++++++++++++ 2 files changed, 105 insertions(+), 115 deletions(-) create mode 100755 tests/token_federation/github_oidc_test.py diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 3b5fbaf4a..353606c77 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -28,6 +28,7 @@ on: - 'src/databricks/sql/auth/token_federation.py' - 'src/databricks/sql/auth/auth.py' - 'examples/token_federation_*.py' + - 'tests/token_federation/github_oidc_test.py' branches: - main @@ -43,11 +44,6 @@ jobs: labels: linux-ubuntu-latest steps: - - name: Debug OIDC Claims - uses: github/actions-oidc-debugger@main - with: - audience: '${{ github.server_url }}/${{ github.repository_owner }}' - - name: Checkout code uses: actions/checkout@v4 @@ -71,115 +67,6 @@ jobs: core.setSecret(token) core.setOutput('token', token) - - name: Create test script - run: | - cat > test_github_token_federation.py << 'EOF' - #!/usr/bin/env python3 - - """ - Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. - - This script tests the Databricks SQL connector with token federation - using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse, - runs a simple query, and shows the connected user. 
- """ - - import os - import sys - import json - import base64 - from databricks import sql - - def decode_jwt(token): - """Decode and return the claims from a JWT token.""" - try: - parts = token.split(".") - if len(parts) != 3: - raise ValueError("Invalid JWT format") - - payload = parts[1] - # Add padding if needed - padding = '=' * (4 - len(payload) % 4) - payload += padding - - decoded = base64.b64decode(payload) - return json.loads(decoded) - except Exception as e: - print(f"Failed to decode token: {str(e)}") - return None - - def main(): - # Get GitHub OIDC token - github_token = os.environ.get("OIDC_TOKEN") - if not github_token: - print("GitHub OIDC token not available") - sys.exit(1) - - # Get Databricks connection parameters - host = os.environ.get("DATABRICKS_HOST_FOR_TF") - http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") - identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") - - if not host or not http_path: - print("Missing Databricks connection parameters") - sys.exit(1) - - # Display token claims for debugging - claims = decode_jwt(github_token) - if claims: - print("\n=== GitHub OIDC Token Claims ===") - print(f"Token issuer: {claims.get('iss')}") - print(f"Token subject: {claims.get('sub')}") - print(f"Token audience: {claims.get('aud')}") - print(f"Token expiration: {claims.get('exp', 'unknown')}") - print(f"Repository: {claims.get('repository', 'unknown')}") - print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") - print(f"Event name: {claims.get('event_name', 'unknown')}") - print("===============================\n") - - try: - # Connect to Databricks using token federation - print(f"=== Testing Connection via Connector ===") - print(f"Connecting to Databricks at {host}{http_path}") - print(f"Using client ID: {identity_federation_client_id}") - - connection_params = { - "server_hostname": host, - "http_path": http_path, - "access_token": github_token, - "auth_type": "token-federation", - 
"identity_federation_client_id": identity_federation_client_id, - } - - print("Connection parameters:") - print(json.dumps({k: v if k != 'access_token' else '***' for k, v in connection_params.items()}, indent=2)) - - with sql.connect(**connection_params) as connection: - print("Connection established successfully") - - # Execute a simple query - cursor = connection.cursor() - cursor.execute("SELECT 1 + 1 as result") - result = cursor.fetchall() - print(f"Query result: {result[0][0]}") - - # Show current user - cursor.execute("SELECT current_user() as user") - result = cursor.fetchall() - print(f"Connected as user: {result[0][0]}") - - print("Token federation test successful!") - return True - except Exception as e: - print(f"Error connecting to Databricks: {str(e)}") - print("===================================\n") - sys.exit(1) - - if __name__ == "__main__": - main() - EOF - chmod +x test_github_token_federation.py - - name: Test token federation with GitHub OIDC token env: DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} @@ -187,4 +74,4 @@ jobs: IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} run: | - python test_github_token_federation.py + python tests/token_federation/github_oidc_test.py diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py new file mode 100755 index 000000000..e413d42f0 --- /dev/null +++ b/tests/token_federation/github_oidc_test.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +""" +Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. + +This script tests the Databricks SQL connector with token federation +using a GitHub Actions OIDC token. 
It connects to a Databricks SQL warehouse, +runs a simple query, and shows the connected user. +""" + +import os +import sys +import json +import base64 +from databricks import sql + + +def decode_jwt(token): + """Decode and return the claims from a JWT token.""" + try: + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT format") + + payload = parts[1] + # Add padding if needed + padding = '=' * (4 - len(payload) % 4) + payload += padding + + decoded = base64.b64decode(payload) + return json.loads(decoded) + except Exception as e: + print(f"Failed to decode token: {str(e)}") + return None + + +def main(): + # Get GitHub OIDC token + github_token = os.environ.get("OIDC_TOKEN") + if not github_token: + print("GitHub OIDC token not available") + sys.exit(1) + + # Get Databricks connection parameters + host = os.environ.get("DATABRICKS_HOST_FOR_TF") + http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") + identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") + + if not host or not http_path: + print("Missing Databricks connection parameters") + sys.exit(1) + + # Display token claims for debugging + claims = decode_jwt(github_token) + if claims: + print("\n=== GitHub OIDC Token Claims ===") + print(f"Token issuer: {claims.get('iss')}") + print(f"Token subject: {claims.get('sub')}") + print(f"Token audience: {claims.get('aud')}") + print(f"Token expiration: {claims.get('exp', 'unknown')}") + print(f"Repository: {claims.get('repository', 'unknown')}") + print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") + print(f"Event name: {claims.get('event_name', 'unknown')}") + print("===============================\n") + + try: + # Connect to Databricks using token federation + print(f"=== Testing Connection via Connector ===") + print(f"Connecting to Databricks at {host}{http_path}") + print(f"Using client ID: {identity_federation_client_id}") + + connection_params = { + "server_hostname": host, + "http_path": 
http_path, + "access_token": github_token, + "auth_type": "token-federation", + "identity_federation_client_id": identity_federation_client_id, + } + + with sql.connect(**connection_params) as connection: + print("Connection established successfully") + + # Execute a simple query + cursor = connection.cursor() + cursor.execute("SELECT 1 + 1 as result") + result = cursor.fetchall() + print(f"Query result: {result[0][0]}") + + # Show current user + cursor.execute("SELECT current_user() as user") + result = cursor.fetchall() + print(f"Connected as user: {result[0][0]}") + + print("Token federation test successful!") + return True + except Exception as e: + print(f"Error connecting to Databricks: {str(e)}") + print("===================================\n") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file From 82d0be25daf48beb412f0ae65a53fbd790b3592f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 09:27:54 +0000 Subject: [PATCH 19/46] addresses codecheck errors --- src/databricks/sql/auth/token_federation.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 4d95ec6b3..0f4688211 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -69,13 +69,13 @@ def __init__(self, credentials_provider: CredentialsProvider, hostname: str, self.credentials_provider = credentials_provider self.hostname = hostname self.identity_federation_client_id = identity_federation_client_id - self.external_provider_headers = {} + self.external_provider_headers: Dict[str, str] = {} self.token = None - self.token_endpoint = None + self.token_endpoint: Optional[str] = None self.idp_endpoints = None self.openid_config = None - self.last_exchanged_token = None - self.last_external_token = None + self.last_exchanged_token: Optional[Token] = None + self.last_external_token: 
Optional[str] = None def auth_type(self) -> str: """Return the auth type from the underlying credentials provider.""" @@ -322,6 +322,10 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token if not self.token_endpoint: self._init_oidc_discovery() + # Ensure token_endpoint is set + if not self.token_endpoint: + raise ValueError("Token endpoint could not be determined") + # Create request parameters params = dict(TOKEN_EXCHANGE_PARAMS) params["subject_token"] = access_token From 1e6075044a6378a01bd62cb6572bf4718a877245 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 15:01:41 +0000 Subject: [PATCH 20/46] adds unit test --- tests/unit/test_token_federation.py | 138 ++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 tests/unit/test_token_federation.py diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py new file mode 100644 index 000000000..9ade2a5bd --- /dev/null +++ b/tests/unit/test_token_federation.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +""" +Unit tests for token federation functionality in the Databricks SQL connector. 
+""" + +import unittest +from unittest.mock import patch, MagicMock +import json +from datetime import datetime, timezone, timedelta + +from databricks.sql.auth.token_federation import ( + Token, + DatabricksTokenFederationProvider, + SimpleCredentialsProvider, + create_token_federation_provider +) + + +class TestToken(unittest.TestCase): + """Tests for the Token class.""" + + def test_token_initialization(self): + """Test Token initialization.""" + token = Token("access_token_value", "Bearer", "refresh_token_value") + self.assertEqual(token.access_token, "access_token_value") + self.assertEqual(token.token_type, "Bearer") + self.assertEqual(token.refresh_token, "refresh_token_value") + + def test_token_is_expired(self): + """Test Token is_expired method.""" + # Token with expiry in the past + past = datetime.now(tz=timezone.utc) - timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=past) + self.assertTrue(token.is_expired()) + + # Token with expiry in the future + future = datetime.now(tz=timezone.utc) + timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=future) + self.assertFalse(token.is_expired()) + + def test_token_needs_refresh(self): + """Test Token needs_refresh method.""" + # Token with expiry in the past + past = datetime.now(tz=timezone.utc) - timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=past) + self.assertTrue(token.needs_refresh()) + + # Token with expiry in the near future (within refresh buffer) + near_future = datetime.now(tz=timezone.utc) + timedelta(minutes=4) + token = Token("access_token", "Bearer", expiry=near_future) + self.assertTrue(token.needs_refresh()) + + # Token with expiry far in the future + far_future = datetime.now(tz=timezone.utc) + timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=far_future) + self.assertFalse(token.needs_refresh()) + + +class TestSimpleCredentialsProvider(unittest.TestCase): + """Tests for the SimpleCredentialsProvider class.""" + + def 
test_simple_credentials_provider(self): + """Test SimpleCredentialsProvider.""" + provider = SimpleCredentialsProvider("token_value", "Bearer", "custom_auth_type") + self.assertEqual(provider.auth_type(), "custom_auth_type") + + header_factory = provider() + headers = header_factory() + self.assertEqual(headers, {"Authorization": "Bearer token_value"}) + + +class TestTokenFederationProvider(unittest.TestCase): + """Tests for the DatabricksTokenFederationProvider class.""" + + def test_host_property(self): + """Test the host property of DatabricksTokenFederationProvider.""" + creds_provider = SimpleCredentialsProvider("token") + federation_provider = DatabricksTokenFederationProvider( + creds_provider, "example.com", "client_id" + ) + self.assertEqual(federation_provider.host, "example.com") + self.assertEqual(federation_provider.hostname, "example.com") + + @patch('databricks.sql.auth.token_federation.requests.get') + @patch('databricks.sql.auth.token_federation.get_oauth_endpoints') + def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get): + """Test _init_oidc_discovery method.""" + # Mock the get_oauth_endpoints function + mock_endpoints = MagicMock() + mock_endpoints.get_openid_config_url.return_value = "https://example.com/openid-config" + mock_get_endpoints.return_value = mock_endpoints + + # Mock the requests.get response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"token_endpoint": "https://example.com/token"} + mock_requests_get.return_value = mock_response + + # Create the provider + creds_provider = SimpleCredentialsProvider("token") + federation_provider = DatabricksTokenFederationProvider( + creds_provider, "example.com", "client_id" + ) + + # Call the method + federation_provider._init_oidc_discovery() + + # Check if the token endpoint was set correctly + self.assertEqual(federation_provider.token_endpoint, "https://example.com/token") + + # Test fallback when discovery fails + 
mock_requests_get.side_effect = Exception("Connection error") + federation_provider.token_endpoint = None + federation_provider._init_oidc_discovery() + self.assertEqual(federation_provider.token_endpoint, "https://example.com/oidc/v1/token") + + +class TestTokenFederationFactory(unittest.TestCase): + """Tests for the token federation factory function.""" + + def test_create_token_federation_provider(self): + """Test create_token_federation_provider function.""" + provider = create_token_federation_provider( + "token_value", "example.com", "client_id", "Bearer" + ) + + self.assertIsInstance(provider, DatabricksTokenFederationProvider) + self.assertEqual(provider.hostname, "example.com") + self.assertEqual(provider.identity_federation_client_id, "client_id") + + # Test that the underlying credentials provider was set up correctly + self.assertEqual(provider.credentials_provider.token, "token_value") + self.assertEqual(provider.credentials_provider.token_type, "Bearer") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From de484119fbecc2b5f708f623f3f3b8e2e1bc323f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 15:12:11 +0000 Subject: [PATCH 21/46] Fix: Apply Black formatting to auth and token_federation modules --- src/databricks/sql/auth/auth.py | 33 ++- src/databricks/sql/auth/token_federation.py | 267 ++++++++++++-------- 2 files changed, 179 insertions(+), 121 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 635563ce0..060c3bfa3 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -51,36 +51,41 @@ def get_auth_provider(cfg: ClientContext): # If token federation is enabled and credentials provider is provided, # wrap the credentials provider with DatabricksTokenFederationProvider if cfg.auth_type == AuthType.TOKEN_FEDERATION.value: - from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider + from 
databricks.sql.auth.token_federation import ( + DatabricksTokenFederationProvider, + ) + federation_provider = DatabricksTokenFederationProvider( cfg.credentials_provider, cfg.hostname, - cfg.identity_federation_client_id + cfg.identity_federation_client_id, ) return ExternalAuthProvider(federation_provider) - + # If access token is provided with token federation, create a SimpleCredentialsProvider elif cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: - from databricks.sql.auth.token_federation import create_token_federation_provider + from databricks.sql.auth.token_federation import ( + create_token_federation_provider, + ) + federation_provider = create_token_federation_provider( - cfg.access_token, - cfg.hostname, - cfg.identity_federation_client_id + cfg.access_token, cfg.hostname, cfg.identity_federation_client_id ) return ExternalAuthProvider(federation_provider) - + return ExternalAuthProvider(cfg.credentials_provider) - + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: # If only access_token is provided with token federation, use create_token_federation_provider - from databricks.sql.auth.token_federation import create_token_federation_provider + from databricks.sql.auth.token_federation import ( + create_token_federation_provider, + ) + federation_provider = create_token_federation_provider( - cfg.access_token, - cfg.hostname, - cfg.identity_federation_client_id + cfg.access_token, cfg.hostname, cfg.identity_federation_client_id ) return ExternalAuthProvider(federation_provider) - + if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 0f4688211..7a45aadf8 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -10,7 +10,11 @@ from 
requests.exceptions import RequestException from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory -from databricks.sql.auth.endpoint import get_databricks_oidc_url, get_oauth_endpoints, infer_cloud_from_host +from databricks.sql.auth.endpoint import ( + get_databricks_oidc_url, + get_oauth_endpoints, + infer_cloud_from_host, +) logger = logging.getLogger(__name__) @@ -18,7 +22,7 @@ "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "scope": "sql", "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "return_original_token_if_authenticated": "true" + "return_original_token_if_authenticated": "true", } # Special client IDs for different IdPs @@ -27,24 +31,31 @@ # Buffer time in seconds before token expiry to trigger a refresh (5 minutes) TOKEN_REFRESH_BUFFER_SECONDS = 300 + class Token: """Represents an OAuth token with expiry information.""" - - def __init__(self, access_token: str, token_type: str, refresh_token: str = "", expiry: Optional[datetime] = None): + + def __init__( + self, + access_token: str, + token_type: str, + refresh_token: str = "", + expiry: Optional[datetime] = None, + ): self.access_token = access_token self.token_type = token_type self.refresh_token = refresh_token self.expiry = expiry or datetime.now(tz=timezone.utc) - + def is_expired(self) -> bool: """Check if the token is expired.""" return datetime.now(tz=timezone.utc) >= self.expiry - + def needs_refresh(self) -> bool: """Check if the token needs to be refreshed soon.""" buffer_time = timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS) return datetime.now(tz=timezone.utc) >= (self.expiry - buffer_time) - + def __str__(self) -> str: return f"{self.token_type} {self.access_token}" @@ -55,12 +66,16 @@ class DatabricksTokenFederationProvider(CredentialsProvider): for a Databricks InHouse Token. This class exchanges the access token if the issued token is not from the same host as the Databricks host. 
""" - - def __init__(self, credentials_provider: CredentialsProvider, hostname: str, - identity_federation_client_id: Optional[str] = None): + + def __init__( + self, + credentials_provider: CredentialsProvider, + hostname: str, + identity_federation_client_id: Optional[str] = None, + ): """ Initialize the token federation provider. - + Args: credentials_provider: The underlying credentials provider hostname: The Databricks hostname @@ -76,81 +91,90 @@ def __init__(self, credentials_provider: CredentialsProvider, hostname: str, self.openid_config = None self.last_exchanged_token: Optional[Token] = None self.last_external_token: Optional[str] = None - + def auth_type(self) -> str: """Return the auth type from the underlying credentials provider.""" return self.credentials_provider.auth_type() - + @property def host(self) -> str: """ Alias for hostname to maintain compatibility with code expecting a host attribute. - + Returns: str: The hostname value """ return self.hostname - + def __call__(self, *args, **kwargs) -> HeaderFactory: """ Configure and return a HeaderFactory that provides authentication headers. - + This is called by the ExternalAuthProvider to get headers for authentication. 
""" # First call the underlying credentials provider to get its headers header_factory = self.credentials_provider(*args, **kwargs) - + # Initialize OIDC discovery self._init_oidc_discovery() - + def get_headers() -> Dict[str, str]: # Get headers from the underlying provider self.external_provider_headers = header_factory() - + # Extract the token from the headers - token_info = self._extract_token_info_from_header(self.external_provider_headers) + token_info = self._extract_token_info_from_header( + self.external_provider_headers + ) token_type, access_token = token_info - + try: # Check if we need to refresh the token - if (self.last_exchanged_token and self.last_external_token == access_token and - self.last_exchanged_token.needs_refresh()): + if ( + self.last_exchanged_token + and self.last_external_token == access_token + and self.last_exchanged_token.needs_refresh() + ): # The token is approaching expiry, try to refresh logger.debug("Exchanged token approaching expiry, refreshing...") return self._refresh_token(access_token, token_type) - + # Parse the JWT to get claims token_claims = self._parse_jwt_claims(access_token) - + # Check if token needs to be exchanged if self._is_same_host(token_claims.get("iss", ""), self.hostname): # Token is from the same host, no need to exchange return self.external_provider_headers else: # Token is from a different host, need to exchange - return self._try_token_exchange_or_fallback(access_token, token_type) - + return self._try_token_exchange_or_fallback( + access_token, token_type + ) + except Exception as e: logger.error(f"Failed to process token: {str(e)}") # Fall back to original headers in case of error return self.external_provider_headers - + return get_headers - + def _init_oidc_discovery(self): """Initialize OIDC discovery to find token endpoint.""" if self.token_endpoint is not None: return - + try: # Use the existing OIDC discovery mechanism use_azure_auth = infer_cloud_from_host(self.hostname) == "azure" 
self.idp_endpoints = get_oauth_endpoints(self.hostname, use_azure_auth) - + if self.idp_endpoints: # Get the OpenID configuration URL - openid_config_url = self.idp_endpoints.get_openid_config_url(self.hostname) - + openid_config_url = self.idp_endpoints.get_openid_config_url( + self.hostname + ) + # Fetch the OpenID configuration response = requests.get(openid_config_url) if response.status_code == 200: @@ -159,42 +183,50 @@ def _init_oidc_discovery(self): self.token_endpoint = self.openid_config.get("token_endpoint") logger.info(f"Discovered token endpoint: {self.token_endpoint}") else: - logger.warning(f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}") - + logger.warning( + f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}" + ) + # Fallback to default token endpoint if discovery fails if not self.token_endpoint: # Make sure hostname has proper format with https:// prefix and trailing slash hostname = self.hostname - if not hostname.startswith('https://'): - hostname = f'https://{hostname}' - if not hostname.endswith('/'): - hostname = f'{hostname}/' + if not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" self.token_endpoint = f"{hostname}oidc/v1/token" logger.info(f"Using default token endpoint: {self.token_endpoint}") - + except Exception as e: - logger.warning(f"OIDC discovery failed: {str(e)}. Using default token endpoint.") + logger.warning( + f"OIDC discovery failed: {str(e)}. Using default token endpoint." 
+ ) # Make sure hostname has proper format with https:// prefix and trailing slash hostname = self.hostname - if not hostname.startswith('https://'): - hostname = f'https://{hostname}' - if not hostname.endswith('/'): - hostname = f'{hostname}/' + if not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" self.token_endpoint = f"{hostname}oidc/v1/token" - logger.info(f"Using default token endpoint after error: {self.token_endpoint}") - - def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]: + logger.info( + f"Using default token endpoint after error: {self.token_endpoint}" + ) + + def _extract_token_info_from_header( + self, headers: Dict[str, str] + ) -> Tuple[str, str]: """Extract token type and token value from authorization header.""" auth_header = headers.get("Authorization") if not auth_header: raise ValueError("No Authorization header found") - + parts = auth_header.split(" ", 1) if len(parts) != 2: raise ValueError(f"Invalid Authorization header format: {auth_header}") - + return parts[0], parts[1] - + def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: """Parse JWT token claims without validation.""" try: @@ -202,29 +234,29 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: parts = token.split(".") if len(parts) != 3: raise ValueError("Invalid JWT format") - + # Get the payload part (second part) payload = parts[1] - + # Add padding if needed - padding = '=' * (4 - len(payload) % 4) + padding = "=" * (4 - len(payload) % 4) payload += padding - + # Decode and parse JSON decoded = base64.b64decode(payload) return json.loads(decoded) except Exception as e: logger.error(f"Failed to parse JWT: {str(e)}") raise - + def _detect_idp_from_claims(self, token_claims: Dict[str, Any]) -> str: """ Detect the identity provider type from token claims. - + This can be used to adjust token exchange parameters based on the IdP. 
""" issuer = token_claims.get("iss", "") - + if "login.microsoftonline.com" in issuer or "sts.windows.net" in issuer: return "azure" elif "token.actions.githubusercontent.com" in issuer: @@ -235,7 +267,7 @@ def _detect_idp_from_claims(self, token_claims: Dict[str, Any]) -> str: return "aws" else: return "unknown" - + def _is_same_host(self, url1: str, url2: str) -> bool: """Check if two URLs have the same host.""" try: @@ -248,18 +280,18 @@ def _is_same_host(self, url1: str, url2: str) -> bool: except Exception as e: logger.error(f"Failed to parse URLs: {str(e)}") return False - + def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: """ Attempt to refresh an expired token. - + For most OAuth implementations, refreshing involves a new token exchange with the latest external token. - + Args: access_token: The original external access token token_type: The token type (Bearer, etc.) - + Returns: The headers with the fresh token """ @@ -268,72 +300,82 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: # For most federation implementations, refresh is just a new token exchange token_claims = self._parse_jwt_claims(access_token) idp_type = self._detect_idp_from_claims(token_claims) - + # Perform a new token exchange refreshed_token = self._exchange_token(access_token, idp_type) - - # Update the stored token + + # Update the stored token self.last_exchanged_token = refreshed_token self.last_external_token = access_token - + # Create new headers with the refreshed token headers = dict(self.external_provider_headers) - headers["Authorization"] = f"{refreshed_token.token_type} {refreshed_token.access_token}" + headers[ + "Authorization" + ] = f"{refreshed_token.token_type} {refreshed_token.access_token}" return headers except Exception as e: - logger.error(f"Token refresh failed, falling back to original token: {str(e)}") + logger.error( + f"Token refresh failed, falling back to original token: {str(e)}" + ) # If 
refresh fails, fall back to the original headers return self.external_provider_headers - - def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> Dict[str, str]: + + def _try_token_exchange_or_fallback( + self, access_token: str, token_type: str + ) -> Dict[str, str]: """Try to exchange the token or fall back to the original token.""" try: # Parse the token to get claims for IdP-specific adjustments token_claims = self._parse_jwt_claims(access_token) idp_type = self._detect_idp_from_claims(token_claims) - + # Exchange the token exchanged_token = self._exchange_token(access_token, idp_type) - + # Store the exchanged token for potential refresh later self.last_exchanged_token = exchanged_token self.last_external_token = access_token - + # Create new headers with the exchanged token headers = dict(self.external_provider_headers) - headers["Authorization"] = f"{exchanged_token.token_type} {exchanged_token.access_token}" + headers[ + "Authorization" + ] = f"{exchanged_token.token_type} {exchanged_token.access_token}" return headers except Exception as e: - logger.error(f"Token exchange failed, falling back to using external token: {str(e)}") + logger.error( + f"Token exchange failed, falling back to using external token: {str(e)}" + ) # Fall back to original headers return self.external_provider_headers - + def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token: """ Exchange an external token for a Databricks token. - + Args: access_token: The external token to exchange idp_type: The detected identity provider type (azure, github, etc.) 
- + Returns: A Token object containing the exchanged token """ if not self.token_endpoint: self._init_oidc_discovery() - + # Ensure token_endpoint is set if not self.token_endpoint: raise ValueError("Token endpoint could not be determined") - + # Create request parameters params = dict(TOKEN_EXCHANGE_PARAMS) params["subject_token"] = access_token - + # Add client ID if available if self.identity_federation_client_id: params["client_id"] = self.identity_federation_client_id - + # Make IdP-specific adjustments if idp_type == "azure": # For Azure AD, add special handling if needed @@ -341,39 +383,40 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token elif idp_type == "github": # For GitHub Actions, add special handling if needed pass - + # Set up headers - headers = { - "Accept": "*/*", - "Content-Type": "application/x-www-form-urlencoded" - } - + headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"} + try: # Make the token exchange request response = requests.post(self.token_endpoint, data=params, headers=headers) response.raise_for_status() - + # Parse the response resp_data = response.json() - + # Create a token from the response token = Token( access_token=resp_data.get("access_token"), token_type=resp_data.get("token_type", "Bearer"), refresh_token=resp_data.get("refresh_token", ""), ) - + # Set expiry time from the response's expires_in field if available # This is the standard OAuth approach if "expires_in" in resp_data and resp_data["expires_in"]: try: # Calculate expiry by adding expires_in seconds to current time expires_in_seconds = int(resp_data["expires_in"]) - token.expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in_seconds) + token.expiry = datetime.now(tz=timezone.utc) + timedelta( + seconds=expires_in_seconds + ) logger.debug(f"Token expiry set from expires_in: {token.expiry}") except (ValueError, TypeError) as e: - logger.warning(f"Could not parse expires_in from response: 
{str(e)}") - + logger.warning( + f"Could not parse expires_in from response: {str(e)}" + ) + # If expires_in wasn't available, try to parse expiry from the token JWT if token.expiry == datetime.now(tz=timezone.utc): try: @@ -381,10 +424,12 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token exp_time = token_claims.get("exp") if exp_time: token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc) - logger.debug(f"Token expiry set from JWT exp claim: {token.expiry}") + logger.debug( + f"Token expiry set from JWT exp claim: {token.expiry}" + ) except Exception as e: logger.warning(f"Could not parse expiry from token: {str(e)}") - + return token except RequestException as e: logger.error(f"Failed to perform token exchange: {str(e)}") @@ -393,35 +438,43 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token class SimpleCredentialsProvider(CredentialsProvider): """A simple credentials provider that returns fixed headers.""" - - def __init__(self, token: str, token_type: str = "Bearer", auth_type_value: str = "token"): + + def __init__( + self, token: str, token_type: str = "Bearer", auth_type_value: str = "token" + ): self.token = token self.token_type = token_type self._auth_type = auth_type_value - + def auth_type(self) -> str: return self._auth_type - + def __call__(self, *args, **kwargs) -> HeaderFactory: def get_headers() -> Dict[str, str]: return {"Authorization": f"{self.token_type} {self.token}"} + return get_headers -def create_token_federation_provider(token: str, hostname: str, - identity_federation_client_id: Optional[str] = None, - token_type: str = "Bearer") -> DatabricksTokenFederationProvider: +def create_token_federation_provider( + token: str, + hostname: str, + identity_federation_client_id: Optional[str] = None, + token_type: str = "Bearer", +) -> DatabricksTokenFederationProvider: """ Create a token federation provider using a simple token. 
- + Args: token: The token to use hostname: The Databricks hostname identity_federation_client_id: Optional client ID for identity federation token_type: The token type (default: "Bearer") - + Returns: A DatabricksTokenFederationProvider """ provider = SimpleCredentialsProvider(token, token_type) - return DatabricksTokenFederationProvider(provider, hostname, identity_federation_client_id) \ No newline at end of file + return DatabricksTokenFederationProvider( + provider, hostname, identity_federation_client_id + ) From d54ba9384dad2f9c97fd880a041a02fb8d21a6ed Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 8 May 2025 16:45:12 +0000 Subject: [PATCH 22/46] Enhance token federation refresh to get fresh external tokens --- src/databricks/sql/auth/token_federation.py | 45 +++++++-- tests/unit/test_token_federation.py | 57 +++++++++++ tests/unit/test_token_federation_jdbc.py | 105 ++++++++++++++++++++ 3 files changed, 196 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_token_federation_jdbc.py diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 7a45aadf8..f9ea18b4c 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -283,33 +283,56 @@ def _is_same_host(self, url1: str, url2: str) -> bool: def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: """ - Attempt to refresh an expired token. + Attempt to refresh an expired token by first getting a fresh external token + and then exchanging it for a new Databricks token. - For most OAuth implementations, refreshing involves a new token exchange - with the latest external token. + This implementation follows the JDBC driver approach by first requesting + a fresh token from the underlying credentials provider before performing + the token exchange. 
Args: - access_token: The original external access token + access_token: The original external access token (will be replaced) token_type: The token type (Bearer, etc.) Returns: The headers with the fresh token """ try: - logger.info("Refreshing expired token via new token exchange") - # For most federation implementations, refresh is just a new token exchange - token_claims = self._parse_jwt_claims(access_token) + logger.info("Refreshing expired token by getting a new external token") + + # ENHANCEMENT: Get a fresh token from the underlying credentials provider + # instead of reusing the same access_token + fresh_headers = self.credentials_provider()() + + # Extract the fresh token from the headers + auth_header = fresh_headers.get("Authorization", "") + if not auth_header: + logger.error("No Authorization header in fresh headers") + return self.external_provider_headers + + parts = auth_header.split(" ", 1) + if len(parts) != 2: + logger.error(f"Invalid Authorization header format: {auth_header}") + return self.external_provider_headers + + fresh_token_type = parts[0] + fresh_access_token = parts[1] + + logger.debug("Got fresh external token") + + # Now process the fresh token + token_claims = self._parse_jwt_claims(fresh_access_token) idp_type = self._detect_idp_from_claims(token_claims) - # Perform a new token exchange - refreshed_token = self._exchange_token(access_token, idp_type) + # Perform a new token exchange with the fresh token + refreshed_token = self._exchange_token(fresh_access_token, idp_type) # Update the stored token self.last_exchanged_token = refreshed_token - self.last_external_token = access_token + self.last_external_token = fresh_access_token # Create new headers with the refreshed token - headers = dict(self.external_provider_headers) + headers = dict(fresh_headers) # Use the fresh headers as base headers[ "Authorization" ] = f"{refreshed_token.token_type} {refreshed_token.access_token}" diff --git a/tests/unit/test_token_federation.py 
b/tests/unit/test_token_federation.py index 9ade2a5bd..f04915e2f 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -114,6 +114,63 @@ def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get): federation_provider.token_endpoint = None federation_provider._init_oidc_discovery() self.assertEqual(federation_provider.token_endpoint, "https://example.com/oidc/v1/token") + + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims') + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token') + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host') + def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt): + """Test token refresh functionality for approaching expiry.""" + # Set up mocks + mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"} + mock_is_same_host.return_value = False + + # Create a simple credentials provider that returns a fixed token + external_token = "test_token" + creds_provider = SimpleCredentialsProvider(external_token) + + # Set up the token federation provider + federation_provider = DatabricksTokenFederationProvider( + creds_provider, "example.com", "client_id" + ) + + # Mock the token exchange to return a known token + future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) + mock_exchange_token.return_value = Token( + "exchanged_token_1", "Bearer", expiry=future_time + ) + + # First call to get initial headers and token - this should trigger an exchange + headers_factory = federation_provider() + headers = headers_factory() + + # Verify the exchange happened + mock_exchange_token.assert_called_with(external_token, "azure") + self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1") + + # Reset the mocks to track the next call + mock_exchange_token.reset_mock() + + # Now simulate an approaching 
expiry + near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4) + federation_provider.last_exchanged_token = Token( + "exchanged_token_1", "Bearer", expiry=near_expiry + ) + federation_provider.last_external_token = external_token + + # Set up the mock to return a different token for the refresh + mock_exchange_token.return_value = Token( + "exchanged_token_2", "Bearer", expiry=future_time + ) + + # Make a second call which should trigger refresh + headers = headers_factory() + + # Verify the token was exchanged with the SAME external token (current implementation) + # This is different from the JDBC driver approach which gets a fresh token + mock_exchange_token.assert_called_once_with(external_token, "azure") + + # Verify the headers contain the new token + self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2") class TestTokenFederationFactory(unittest.TestCase): diff --git a/tests/unit/test_token_federation_jdbc.py b/tests/unit/test_token_federation_jdbc.py new file mode 100644 index 000000000..2c53e456b --- /dev/null +++ b/tests/unit/test_token_federation_jdbc.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +""" +Unit tests for the JDBC-style token refresh in Databricks SQL connector. + +This test verifies that the token federation implementation follows the JDBC driver's approach +of getting a fresh external token before exchanging it for a Databricks token during refresh. +""" + +import unittest +from unittest.mock import patch, MagicMock +from datetime import datetime, timezone, timedelta + +from databricks.sql.auth.token_federation import ( + DatabricksTokenFederationProvider, + Token +) + + +class RefreshingCredentialsProvider: + """ + A credentials provider that returns different tokens on each call. + This simulates providers like Azure AD that can refresh their tokens. 
+ """ + + def __init__(self): + self.call_count = 0 + + def auth_type(self): + return "bearer" + + def __call__(self, *args, **kwargs): + def get_headers(): + self.call_count += 1 + # Return a different token each time to simulate fresh tokens + return {"Authorization": f"Bearer fresh_token_{self.call_count}"} + return get_headers + + +class TestJdbcStyleTokenRefresh(unittest.TestCase): + """Tests for the JDBC-style token refresh implementation.""" + + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims') + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token') + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host') + def test_refresh_gets_fresh_token(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt): + """Test that token refresh first gets a fresh external token.""" + # Set up mocks + mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"} + mock_is_same_host.return_value = False + + # Create a credentials provider that returns different tokens on each call + refreshing_provider = RefreshingCredentialsProvider() + + # Set up the token federation provider + federation_provider = DatabricksTokenFederationProvider( + refreshing_provider, "example.com", "client_id" + ) + + # Set up mock for token exchange + future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) + mock_exchange_token.return_value = Token( + "exchanged_token_1", "Bearer", expiry=future_time + ) + + # First call to get initial headers and token + headers_factory = federation_provider() + headers = headers_factory() + + # Verify the first exchange happened + mock_exchange_token.assert_called_with("fresh_token_1", "azure") + self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1") + self.assertEqual(refreshing_provider.call_count, 1) + + # Reset the mock to track the next call + mock_exchange_token.reset_mock() + + # 
Now simulate an approaching expiry + near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4) + federation_provider.last_exchanged_token = Token( + "exchanged_token_1", "Bearer", expiry=near_expiry + ) + federation_provider.last_external_token = "fresh_token_1" + + # Set up the mock to return a different token for the refresh + mock_exchange_token.return_value = Token( + "exchanged_token_2", "Bearer", expiry=future_time + ) + + # Make a second call which should trigger refresh + headers = headers_factory() + + # With JDBC-style implementation: + # 1. Should call credentials provider again to get fresh token + self.assertEqual(refreshing_provider.call_count, 2) + + # 2. Should exchange the FRESH token (fresh_token_2), not the stored one (fresh_token_1) + mock_exchange_token.assert_called_once_with("fresh_token_2", "azure") + + # 3. Should return headers with the new Databricks token + self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From aa2d1b907172735f2b1257b22d1a545b057342a5 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 05:51:04 +0000 Subject: [PATCH 23/46] refresh --- tests/unit/test_token_federation.py | 28 +++--- tests/unit/test_token_federation_jdbc.py | 105 ----------------------- 2 files changed, 18 insertions(+), 115 deletions(-) delete mode 100644 tests/unit/test_token_federation_jdbc.py diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index f04915e2f..f92c4e1ee 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -124,13 +124,21 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_ mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"} mock_is_same_host.return_value = False - # Create a simple credentials provider that returns a fixed token - external_token = "test_token" - creds_provider = 
SimpleCredentialsProvider(external_token) + # Create a mock credentials provider that can return different tokens + mock_creds_provider = MagicMock() + # Initial token factory + initial_header_factory = MagicMock() + initial_header_factory.return_value = {"Authorization": "Bearer initial_token"} + # Fresh token factory for refresh + fresh_header_factory = MagicMock() + fresh_header_factory.return_value = {"Authorization": "Bearer fresh_token"} + + # Configure the mock to return different header factories on consecutive calls + mock_creds_provider.side_effect = [initial_header_factory, fresh_header_factory] # Set up the token federation provider federation_provider = DatabricksTokenFederationProvider( - creds_provider, "example.com", "client_id" + mock_creds_provider, "example.com", "client_id" ) # Mock the token exchange to return a known token @@ -143,8 +151,8 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_ headers_factory = federation_provider() headers = headers_factory() - # Verify the exchange happened - mock_exchange_token.assert_called_with(external_token, "azure") + # Verify the exchange happened with the initial token + mock_exchange_token.assert_called_with("initial_token", "azure") self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1") # Reset the mocks to track the next call @@ -155,7 +163,7 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_ federation_provider.last_exchanged_token = Token( "exchanged_token_1", "Bearer", expiry=near_expiry ) - federation_provider.last_external_token = external_token + federation_provider.last_external_token = "initial_token" # Set up the mock to return a different token for the refresh mock_exchange_token.return_value = Token( @@ -165,9 +173,9 @@ def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_ # Make a second call which should trigger refresh headers = headers_factory() - # Verify the token was exchanged 
with the SAME external token (current implementation) - # This is different from the JDBC driver approach which gets a fresh token - mock_exchange_token.assert_called_once_with(external_token, "azure") + # Verify a fresh token was requested from the credentials provider + # and the exchange was performed with the fresh token + mock_exchange_token.assert_called_once_with("fresh_token", "azure") # Verify the headers contain the new token self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2") diff --git a/tests/unit/test_token_federation_jdbc.py b/tests/unit/test_token_federation_jdbc.py deleted file mode 100644 index 2c53e456b..000000000 --- a/tests/unit/test_token_federation_jdbc.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python3 - -""" -Unit tests for the JDBC-style token refresh in Databricks SQL connector. - -This test verifies that the token federation implementation follows the JDBC driver's approach -of getting a fresh external token before exchanging it for a Databricks token during refresh. -""" - -import unittest -from unittest.mock import patch, MagicMock -from datetime import datetime, timezone, timedelta - -from databricks.sql.auth.token_federation import ( - DatabricksTokenFederationProvider, - Token -) - - -class RefreshingCredentialsProvider: - """ - A credentials provider that returns different tokens on each call. - This simulates providers like Azure AD that can refresh their tokens. 
- """ - - def __init__(self): - self.call_count = 0 - - def auth_type(self): - return "bearer" - - def __call__(self, *args, **kwargs): - def get_headers(): - self.call_count += 1 - # Return a different token each time to simulate fresh tokens - return {"Authorization": f"Bearer fresh_token_{self.call_count}"} - return get_headers - - -class TestJdbcStyleTokenRefresh(unittest.TestCase): - """Tests for the JDBC-style token refresh implementation.""" - - @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims') - @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token') - @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host') - def test_refresh_gets_fresh_token(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt): - """Test that token refresh first gets a fresh external token.""" - # Set up mocks - mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"} - mock_is_same_host.return_value = False - - # Create a credentials provider that returns different tokens on each call - refreshing_provider = RefreshingCredentialsProvider() - - # Set up the token federation provider - federation_provider = DatabricksTokenFederationProvider( - refreshing_provider, "example.com", "client_id" - ) - - # Set up mock for token exchange - future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) - mock_exchange_token.return_value = Token( - "exchanged_token_1", "Bearer", expiry=future_time - ) - - # First call to get initial headers and token - headers_factory = federation_provider() - headers = headers_factory() - - # Verify the first exchange happened - mock_exchange_token.assert_called_with("fresh_token_1", "azure") - self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1") - self.assertEqual(refreshing_provider.call_count, 1) - - # Reset the mock to track the next call - mock_exchange_token.reset_mock() - - # 
Now simulate an approaching expiry - near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4) - federation_provider.last_exchanged_token = Token( - "exchanged_token_1", "Bearer", expiry=near_expiry - ) - federation_provider.last_external_token = "fresh_token_1" - - # Set up the mock to return a different token for the refresh - mock_exchange_token.return_value = Token( - "exchanged_token_2", "Bearer", expiry=future_time - ) - - # Make a second call which should trigger refresh - headers = headers_factory() - - # With JDBC-style implementation: - # 1. Should call credentials provider again to get fresh token - self.assertEqual(refreshing_provider.call_count, 2) - - # 2. Should exchange the FRESH token (fresh_token_2), not the stored one (fresh_token_1) - mock_exchange_token.assert_called_once_with("fresh_token_2", "azure") - - # 3. Should return headers with the new Databricks token - self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2") - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file From 34413f371c3121f79f5e771a45416d5b58d2a551 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 06:08:15 +0000 Subject: [PATCH 24/46] fmt --- src/databricks/sql/auth/token_federation.py | 95 ++++++--------------- 1 file changed, 27 insertions(+), 68 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index f9ea18b4c..8ff613fd9 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -1,9 +1,8 @@ import base64 import json import logging -import urllib.parse from datetime import datetime, timezone, timedelta -from typing import Dict, Optional, Any, Tuple, List, Union +from typing import Dict, Optional, Any, Tuple from urllib.parse import urlparse import requests @@ -11,7 +10,6 @@ from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory from databricks.sql.auth.endpoint 
import ( - get_databricks_oidc_url, get_oauth_endpoints, infer_cloud_from_host, ) @@ -25,11 +23,7 @@ "return_original_token_if_authenticated": "true", } -# Special client IDs for different IdPs -AZURE_AD_MULTI_TENANT_APP_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d" - -# Buffer time in seconds before token expiry to trigger a refresh (5 minutes) -TOKEN_REFRESH_BUFFER_SECONDS = 300 +TOKEN_REFRESH_BUFFER_SECONDS = 10 class Token: @@ -85,7 +79,6 @@ def __init__( self.hostname = hostname self.identity_federation_client_id = identity_federation_client_id self.external_provider_headers: Dict[str, str] = {} - self.token = None self.token_endpoint: Optional[str] = None self.idp_endpoints = None self.openid_config = None @@ -123,9 +116,7 @@ def get_headers() -> Dict[str, str]: self.external_provider_headers = header_factory() # Extract the token from the headers - token_info = self._extract_token_info_from_header( - self.external_provider_headers - ) + token_info = self._extract_token_info_from_header(self.external_provider_headers) token_type, access_token = token_info try: @@ -148,10 +139,7 @@ def get_headers() -> Dict[str, str]: return self.external_provider_headers else: # Token is from a different host, need to exchange - return self._try_token_exchange_or_fallback( - access_token, token_type - ) - + return self._try_token_exchange_or_fallback(access_token, token_type) except Exception as e: logger.error(f"Failed to process token: {str(e)}") # Fall back to original headers in case of error @@ -171,10 +159,8 @@ def _init_oidc_discovery(self): if self.idp_endpoints: # Get the OpenID configuration URL - openid_config_url = self.idp_endpoints.get_openid_config_url( - self.hostname - ) - + openid_config_url = self.idp_endpoints.get_openid_config_url(self.hostname) + # Fetch the OpenID configuration response = requests.get(openid_config_url) if response.status_code == 200: @@ -189,33 +175,26 @@ def _init_oidc_discovery(self): # Fallback to default token endpoint if discovery 
fails if not self.token_endpoint: - # Make sure hostname has proper format with https:// prefix and trailing slash - hostname = self.hostname - if not hostname.startswith("https://"): - hostname = f"https://{hostname}" - if not hostname.endswith("/"): - hostname = f"{hostname}/" + hostname = self._format_hostname(self.hostname) self.token_endpoint = f"{hostname}oidc/v1/token" logger.info(f"Using default token endpoint: {self.token_endpoint}") - except Exception as e: logger.warning( f"OIDC discovery failed: {str(e)}. Using default token endpoint." ) - # Make sure hostname has proper format with https:// prefix and trailing slash - hostname = self.hostname - if not hostname.startswith("https://"): - hostname = f"https://{hostname}" - if not hostname.endswith("/"): - hostname = f"{hostname}/" + hostname = self._format_hostname(self.hostname) self.token_endpoint = f"{hostname}oidc/v1/token" - logger.info( - f"Using default token endpoint after error: {self.token_endpoint}" - ) + logger.info(f"Using default token endpoint after error: {self.token_endpoint}") - def _extract_token_info_from_header( - self, headers: Dict[str, str] - ) -> Tuple[str, str]: + def _format_hostname(self, hostname: str) -> str: + """Format hostname to ensure it has proper https:// prefix and trailing slash.""" + if not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" + return hostname + + def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]: """Extract token type and token value from authorization header.""" auth_header = headers.get("Authorization") if not auth_header: @@ -286,10 +265,6 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: Attempt to refresh an expired token by first getting a fresh external token and then exchanging it for a new Databricks token. 
- This implementation follows the JDBC driver approach by first requesting - a fresh token from the underlying credentials provider before performing - the token exchange. - Args: access_token: The original external access token (will be replaced) token_type: The token type (Bearer, etc.) @@ -300,7 +275,7 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: try: logger.info("Refreshing expired token by getting a new external token") - # ENHANCEMENT: Get a fresh token from the underlying credentials provider + # Get a fresh token from the underlying credentials provider # instead of reusing the same access_token fresh_headers = self.credentials_provider()() @@ -333,20 +308,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: # Create new headers with the refreshed token headers = dict(fresh_headers) # Use the fresh headers as base - headers[ - "Authorization" - ] = f"{refreshed_token.token_type} {refreshed_token.access_token}" + headers["Authorization"] = f"{refreshed_token.token_type} {refreshed_token.access_token}" return headers except Exception as e: - logger.error( - f"Token refresh failed, falling back to original token: {str(e)}" - ) + logger.error(f"Token refresh failed, falling back to original token: {str(e)}") # If refresh fails, fall back to the original headers return self.external_provider_headers - def _try_token_exchange_or_fallback( - self, access_token: str, token_type: str - ) -> Dict[str, str]: + def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> Dict[str, str]: """Try to exchange the token or fall back to the original token.""" try: # Parse the token to get claims for IdP-specific adjustments @@ -362,14 +331,10 @@ def _try_token_exchange_or_fallback( # Create new headers with the exchanged token headers = dict(self.external_provider_headers) - headers[ - "Authorization" - ] = f"{exchanged_token.token_type} {exchanged_token.access_token}" + 
headers["Authorization"] = f"{exchanged_token.token_type} {exchanged_token.access_token}" return headers except Exception as e: - logger.error( - f"Token exchange failed, falling back to using external token: {str(e)}" - ) + logger.error(f"Token exchange failed, falling back to using external token: {str(e)}") # Fall back to original headers return self.external_provider_headers @@ -431,14 +396,10 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token try: # Calculate expiry by adding expires_in seconds to current time expires_in_seconds = int(resp_data["expires_in"]) - token.expiry = datetime.now(tz=timezone.utc) + timedelta( - seconds=expires_in_seconds - ) + token.expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in_seconds) logger.debug(f"Token expiry set from expires_in: {token.expiry}") except (ValueError, TypeError) as e: - logger.warning( - f"Could not parse expires_in from response: {str(e)}" - ) + logger.warning(f"Could not parse expires_in from response: {str(e)}") # If expires_in wasn't available, try to parse expiry from the token JWT if token.expiry == datetime.now(tz=timezone.utc): @@ -447,9 +408,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token exp_time = token_claims.get("exp") if exp_time: token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc) - logger.debug( - f"Token expiry set from JWT exp claim: {token.expiry}" - ) + logger.debug(f"Token expiry set from JWT exp claim: {token.expiry}") except Exception as e: logger.warning(f"Could not parse expiry from token: {str(e)}") From a93dd4b049c874e968f1cafdfc2131d55ec56db5 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 06:19:58 +0000 Subject: [PATCH 25/46] clean up --- .github/workflows/token-federation-test.yml | 35 ++--- tests/token_federation/github_oidc_test.py | 141 ++++++++++++++------ tests/unit/test_token_federation.py | 33 +++-- 3 files changed, 143 insertions(+), 66 deletions(-) 
diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml index 353606c77..74b936089 100644 --- a/.github/workflows/token-federation-test.yml +++ b/.github/workflows/token-federation-test.yml @@ -1,8 +1,6 @@ name: Token Federation Test -# This workflow tests token federation functionality with GitHub Actions OIDC tokens -# in the databricks-sql-python connector to ensure CI/CD functionality - +# Tests token federation functionality with GitHub Actions OIDC tokens on: # Manual trigger with required inputs workflow_dispatch: @@ -17,31 +15,34 @@ on: description: 'Identity federation client ID' required: true - # Automatically run on PR that changes token federation files + # Run on PRs that might affect token federation pull_request: - branches: - - main + branches: [main] + paths: + - 'src/databricks/sql/auth/**' + - 'examples/token_federation_*.py' + - 'tests/token_federation/**' + - '.github/workflows/token-federation-test.yml' # Run on push to main that affects token federation push: + branches: [main] paths: - - 'src/databricks/sql/auth/token_federation.py' - - 'src/databricks/sql/auth/auth.py' + - 'src/databricks/sql/auth/**' - 'examples/token_federation_*.py' - - 'tests/token_federation/github_oidc_test.py' - branches: - - main + - 'tests/token_federation/**' + - '.github/workflows/token-federation-test.yml' permissions: - # Required for GitHub OIDC token - id-token: write + id-token: write # Required for GitHub OIDC token contents: read jobs: test-token-federation: + name: Test Token Federation runs-on: - group: databricks-protected-runner-group - labels: linux-ubuntu-latest + group: databricks-protected-runner-group + labels: linux-ubuntu-latest steps: - name: Checkout code @@ -51,6 +52,7 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.9' + cache: 'pip' - name: Install dependencies run: | @@ -73,5 +75,4 @@ jobs: DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && 
inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} - run: | - python tests/token_federation/github_oidc_test.py + run: python tests/token_federation/github_oidc_test.py diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py index e413d42f0..79fc40b34 100755 --- a/tests/token_federation/github_oidc_test.py +++ b/tests/token_federation/github_oidc_test.py @@ -12,11 +12,27 @@ import sys import json import base64 +import logging from databricks import sql +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + def decode_jwt(token): - """Decode and return the claims from a JWT token.""" + """ + Decode and return the claims from a JWT token. + + Args: + token: The JWT token string + + Returns: + dict: The decoded token claims or None if decoding fails + """ try: parts = token.split(".") if len(parts) != 3: @@ -30,72 +46,121 @@ def decode_jwt(token): decoded = base64.b64decode(payload) return json.loads(decoded) except Exception as e: - print(f"Failed to decode token: {str(e)}") + logger.error(f"Failed to decode token: {str(e)}") return None -def main(): - # Get GitHub OIDC token +def get_environment_variables(): + """ + Get required environment variables for the test. 
+ + Returns: + tuple: (github_token, host, http_path, identity_federation_client_id) + + Raises: + SystemExit: If any required environment variable is missing + """ github_token = os.environ.get("OIDC_TOKEN") if not github_token: - print("GitHub OIDC token not available") + logger.error("GitHub OIDC token not available") sys.exit(1) - # Get Databricks connection parameters host = os.environ.get("DATABRICKS_HOST_FOR_TF") http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") if not host or not http_path: - print("Missing Databricks connection parameters") + logger.error("Missing Databricks connection parameters") sys.exit(1) - # Display token claims for debugging - claims = decode_jwt(github_token) - if claims: - print("\n=== GitHub OIDC Token Claims ===") - print(f"Token issuer: {claims.get('iss')}") - print(f"Token subject: {claims.get('sub')}") - print(f"Token audience: {claims.get('aud')}") - print(f"Token expiration: {claims.get('exp', 'unknown')}") - print(f"Repository: {claims.get('repository', 'unknown')}") - print(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") - print(f"Event name: {claims.get('event_name', 'unknown')}") - print("===============================\n") - - try: - # Connect to Databricks using token federation - print(f"=== Testing Connection via Connector ===") - print(f"Connecting to Databricks at {host}{http_path}") - print(f"Using client ID: {identity_federation_client_id}") + return github_token, host, http_path, identity_federation_client_id + + +def display_token_info(claims): + """Display token claims for debugging.""" + if not claims: + logger.warning("No token claims available to display") + return - connection_params = { - "server_hostname": host, - "http_path": http_path, - "access_token": github_token, - "auth_type": "token-federation", - "identity_federation_client_id": identity_federation_client_id, - } + logger.info("=== GitHub OIDC 
Token Claims ===") + logger.info(f"Token issuer: {claims.get('iss')}") + logger.info(f"Token subject: {claims.get('sub')}") + logger.info(f"Token audience: {claims.get('aud')}") + logger.info(f"Token expiration: {claims.get('exp', 'unknown')}") + logger.info(f"Repository: {claims.get('repository', 'unknown')}") + logger.info(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") + logger.info(f"Event name: {claims.get('event_name', 'unknown')}") + logger.info("===============================") + + +def test_databricks_connection(host, http_path, github_token, identity_federation_client_id): + """ + Test connection to Databricks using token federation. + + Args: + host: Databricks host + http_path: Databricks HTTP path + github_token: GitHub OIDC token + identity_federation_client_id: Identity federation client ID + Returns: + bool: True if the test is successful, False otherwise + """ + logger.info("=== Testing Connection via Connector ===") + logger.info(f"Connecting to Databricks at {host}{http_path}") + logger.info(f"Using client ID: {identity_federation_client_id}") + + connection_params = { + "server_hostname": host, + "http_path": http_path, + "access_token": github_token, + "auth_type": "token-federation", + "identity_federation_client_id": identity_federation_client_id, + } + + try: with sql.connect(**connection_params) as connection: - print("Connection established successfully") + logger.info("Connection established successfully") # Execute a simple query cursor = connection.cursor() cursor.execute("SELECT 1 + 1 as result") result = cursor.fetchall() - print(f"Query result: {result[0][0]}") + logger.info(f"Query result: {result[0][0]}") # Show current user cursor.execute("SELECT current_user() as user") result = cursor.fetchall() - print(f"Connected as user: {result[0][0]}") + logger.info(f"Connected as user: {result[0][0]}") - print("Token federation test successful!") + logger.info("Token federation test successful!") return True except Exception as 
e: - print(f"Error connecting to Databricks: {str(e)}") - print("===================================\n") + logger.error(f"Error connecting to Databricks: {str(e)}") + return False + + +def main(): + """Main entry point for the test script.""" + try: + # Get environment variables + github_token, host, http_path, identity_federation_client_id = get_environment_variables() + + # Display token claims + claims = decode_jwt(github_token) + display_token_info(claims) + + # Test Databricks connection + success = test_databricks_connection( + host, http_path, github_token, identity_federation_client_id + ) + + if not success: + logger.error("Token federation test failed") + sys.exit(1) + + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") sys.exit(1) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index f92c4e1ee..2a0ad6fbc 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -13,7 +13,8 @@ Token, DatabricksTokenFederationProvider, SimpleCredentialsProvider, - create_token_federation_provider + create_token_federation_provider, + TOKEN_REFRESH_BUFFER_SECONDS ) @@ -47,12 +48,12 @@ def test_token_needs_refresh(self): self.assertTrue(token.needs_refresh()) # Token with expiry in the near future (within refresh buffer) - near_future = datetime.now(tz=timezone.utc) + timedelta(minutes=4) + near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60) token = Token("access_token", "Bearer", expiry=near_future) self.assertTrue(token.needs_refresh()) # Token with expiry far in the future - far_future = datetime.now(tz=timezone.utc) + timedelta(hours=1) + far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 60) token = Token("access_token", "Bearer", expiry=far_future) self.assertFalse(token.needs_refresh()) @@ -118,22 +119,30 @@ def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get): 
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims') @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token') @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host') - def test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_jwt): + @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._detect_idp_from_claims') + def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_token, mock_parse_jwt): """Test token refresh functionality for approaching expiry.""" # Set up mocks mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"} mock_is_same_host.return_value = False + mock_detect_idp.return_value = "azure" - # Create a mock credentials provider that can return different tokens + # Create mock credentials provider that can return different tokens for different calls mock_creds_provider = MagicMock() - # Initial token factory + + # First call returns initial_token, second call returns fresh_token + initial_headers = {"Authorization": "Bearer initial_token"} + fresh_headers = {"Authorization": "Bearer fresh_token"} + + # Set up initial header factory initial_header_factory = MagicMock() - initial_header_factory.return_value = {"Authorization": "Bearer initial_token"} - # Fresh token factory for refresh + initial_header_factory.return_value = initial_headers + + # Set up fresh header factory for second call fresh_header_factory = MagicMock() - fresh_header_factory.return_value = {"Authorization": "Bearer fresh_token"} + fresh_header_factory.return_value = fresh_headers - # Configure the mock to return different header factories on consecutive calls + # Configure the mock to return factories mock_creds_provider.side_effect = [initial_header_factory, fresh_header_factory] # Set up the token federation provider @@ -157,9 +166,11 @@ def 
test_token_refresh(self, mock_is_same_host, mock_exchange_token, mock_parse_ # Reset the mocks to track the next call mock_exchange_token.reset_mock() + mock_creds_provider.reset_mock() + mock_creds_provider.return_value = fresh_header_factory # Now simulate an approaching expiry - near_expiry = datetime.now(tz=timezone.utc) + timedelta(minutes=4) + near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60) federation_provider.last_exchanged_token = Token( "exchanged_token_1", "Bearer", expiry=near_expiry ) From 76df22ee274bbf4726c954347559f9ef95d88694 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 06:36:53 +0000 Subject: [PATCH 26/46] update and add todo for future work --- poetry.lock | 2 +- pyproject.toml | 2 +- src/databricks/sql/auth/auth.py | 10 ++++ src/databricks/sql/auth/token_federation.py | 54 +++++++++++++++------ 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/poetry.lock b/poetry.lock index 5d6a0891e..678804586 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1176,4 +1176,4 @@ pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "118b7702637d44a7fee4107b471528b14c436bdb01d3618676bc50bbebc6ab65" +content-hash = "aa36901ed7501adeeba5384352904ba06a34d298e400e926201e0fd57f6b6678" diff --git a/pyproject.toml b/pyproject.toml index d40255a24..7d326b2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ PyJWT = ">=2.0.0" [tool.poetry.extras] pyarrow = ["pyarrow"] -[tool.poetry.dev-dependencies] +[tool.poetry.group.dev.dependencies] pytest = "^7.1.2" mypy = "^1.10.1" pylint = ">=2.12.0" diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 060c3bfa3..6a1e89fe5 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -13,6 +13,8 @@ class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" AZURE_OAUTH = "azure-oauth" + # TODO: Token federation should be a 
feature that works with different auth types, + # not an auth type itself. This will be refactored in a future release. TOKEN_FEDERATION = "token-federation" # other supported types (access_token) can be inferred # we can add more types as needed later @@ -47,6 +49,10 @@ def __init__( def get_auth_provider(cfg: ClientContext): + # TODO: In a future refactoring, token federation should be a feature that wraps + # any auth provider, not a separate auth type. The code below treats it as an auth type + # for backward compatibility, but this approach will be revised. + if cfg.credentials_provider: # If token federation is enabled and credentials provider is provided, # wrap the credentials provider with DatabricksTokenFederationProvider @@ -153,6 +159,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): "Please use OAuth or access token instead." ) + # TODO: Future refactoring needed: + # - Add a use_token_federation flag that can be combined with any auth type + # - Remove TOKEN_FEDERATION as an auth_type and properly handle the underlying auth type + # - Maintain backward compatibility during transition cfg = ClientContext( hostname=normalize_host_name(hostname), auth_type=auth_type, diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 8ff613fd9..a0035e680 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -116,7 +116,9 @@ def get_headers() -> Dict[str, str]: self.external_provider_headers = header_factory() # Extract the token from the headers - token_info = self._extract_token_info_from_header(self.external_provider_headers) + token_info = self._extract_token_info_from_header( + self.external_provider_headers + ) token_type, access_token = token_info try: @@ -139,7 +141,9 @@ def get_headers() -> Dict[str, str]: return self.external_provider_headers else: # Token is from a different host, need to exchange - return 
self._try_token_exchange_or_fallback(access_token, token_type) + return self._try_token_exchange_or_fallback( + access_token, token_type + ) except Exception as e: logger.error(f"Failed to process token: {str(e)}") # Fall back to original headers in case of error @@ -159,8 +163,10 @@ def _init_oidc_discovery(self): if self.idp_endpoints: # Get the OpenID configuration URL - openid_config_url = self.idp_endpoints.get_openid_config_url(self.hostname) - + openid_config_url = self.idp_endpoints.get_openid_config_url( + self.hostname + ) + # Fetch the OpenID configuration response = requests.get(openid_config_url) if response.status_code == 200: @@ -184,7 +190,9 @@ def _init_oidc_discovery(self): ) hostname = self._format_hostname(self.hostname) self.token_endpoint = f"{hostname}oidc/v1/token" - logger.info(f"Using default token endpoint after error: {self.token_endpoint}") + logger.info( + f"Using default token endpoint after error: {self.token_endpoint}" + ) def _format_hostname(self, hostname: str) -> str: """Format hostname to ensure it has proper https:// prefix and trailing slash.""" @@ -194,7 +202,9 @@ def _format_hostname(self, hostname: str) -> str: hostname = f"{hostname}/" return hostname - def _extract_token_info_from_header(self, headers: Dict[str, str]) -> Tuple[str, str]: + def _extract_token_info_from_header( + self, headers: Dict[str, str] + ) -> Tuple[str, str]: """Extract token type and token value from authorization header.""" auth_header = headers.get("Authorization") if not auth_header: @@ -308,14 +318,20 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: # Create new headers with the refreshed token headers = dict(fresh_headers) # Use the fresh headers as base - headers["Authorization"] = f"{refreshed_token.token_type} {refreshed_token.access_token}" + headers[ + "Authorization" + ] = f"{refreshed_token.token_type} {refreshed_token.access_token}" return headers except Exception as e: - logger.error(f"Token refresh 
failed, falling back to original token: {str(e)}") + logger.error( + f"Token refresh failed, falling back to original token: {str(e)}" + ) # If refresh fails, fall back to the original headers return self.external_provider_headers - def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> Dict[str, str]: + def _try_token_exchange_or_fallback( + self, access_token: str, token_type: str + ) -> Dict[str, str]: """Try to exchange the token or fall back to the original token.""" try: # Parse the token to get claims for IdP-specific adjustments @@ -331,10 +347,14 @@ def _try_token_exchange_or_fallback(self, access_token: str, token_type: str) -> # Create new headers with the exchanged token headers = dict(self.external_provider_headers) - headers["Authorization"] = f"{exchanged_token.token_type} {exchanged_token.access_token}" + headers[ + "Authorization" + ] = f"{exchanged_token.token_type} {exchanged_token.access_token}" return headers except Exception as e: - logger.error(f"Token exchange failed, falling back to using external token: {str(e)}") + logger.error( + f"Token exchange failed, falling back to using external token: {str(e)}" + ) # Fall back to original headers return self.external_provider_headers @@ -396,10 +416,14 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token try: # Calculate expiry by adding expires_in seconds to current time expires_in_seconds = int(resp_data["expires_in"]) - token.expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in_seconds) + token.expiry = datetime.now(tz=timezone.utc) + timedelta( + seconds=expires_in_seconds + ) logger.debug(f"Token expiry set from expires_in: {token.expiry}") except (ValueError, TypeError) as e: - logger.warning(f"Could not parse expires_in from response: {str(e)}") + logger.warning( + f"Could not parse expires_in from response: {str(e)}" + ) # If expires_in wasn't available, try to parse expiry from the token JWT if token.expiry == 
datetime.now(tz=timezone.utc): @@ -408,7 +432,9 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token exp_time = token_claims.get("exp") if exp_time: token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc) - logger.debug(f"Token expiry set from JWT exp claim: {token.expiry}") + logger.debug( + f"Token expiry set from JWT exp claim: {token.expiry}" + ) except Exception as e: logger.warning(f"Could not parse expiry from token: {str(e)}") From c37cd0190c20f376967aec30ac1f796be7e3373f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 06:48:43 +0000 Subject: [PATCH 27/46] refactoring --- src/databricks/sql/auth/auth.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 6a1e89fe5..47a43db13 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -14,7 +14,7 @@ class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" AZURE_OAUTH = "azure-oauth" # TODO: Token federation should be a feature that works with different auth types, - # not an auth type itself. This will be refactored in a future release. + # not an auth type itself. This will be refactored in a future change. 
TOKEN_FEDERATION = "token-federation" # other supported types (access_token) can be inferred # we can add more types as needed later @@ -68,19 +68,10 @@ def get_auth_provider(cfg: ClientContext): ) return ExternalAuthProvider(federation_provider) - # If access token is provided with token federation, create a SimpleCredentialsProvider - elif cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: - from databricks.sql.auth.token_federation import ( - create_token_federation_provider, - ) - - federation_provider = create_token_federation_provider( - cfg.access_token, cfg.hostname, cfg.identity_federation_client_id - ) - return ExternalAuthProvider(federation_provider) - + # If not token federation, just use the credentials provider directly return ExternalAuthProvider(cfg.credentials_provider) + # If we don't have a credentials provider but have token federation auth type with access token if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: # If only access_token is provided with token federation, use create_token_federation_provider from databricks.sql.auth.token_federation import ( From f2d45162a860ec5ce2dc485931f9922d54856301 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 07:02:42 +0000 Subject: [PATCH 28/46] update test --- tests/unit/test_token_federation.py | 33 +++++++++++++---------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 2a0ad6fbc..126b7d888 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -41,19 +41,19 @@ def test_token_is_expired(self): self.assertFalse(token.is_expired()) def test_token_needs_refresh(self): - """Test Token needs_refresh method.""" + """Test Token needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS.""" # Token with expiry in the past past = datetime.now(tz=timezone.utc) - timedelta(hours=1) token = 
Token("access_token", "Bearer", expiry=past) self.assertTrue(token.needs_refresh()) # Token with expiry in the near future (within refresh buffer) - near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60) + near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1) token = Token("access_token", "Bearer", expiry=near_future) self.assertTrue(token.needs_refresh()) # Token with expiry far in the future - far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 60) + far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10) token = Token("access_token", "Bearer", expiry=far_future) self.assertFalse(token.needs_refresh()) @@ -127,23 +127,19 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t mock_is_same_host.return_value = False mock_detect_idp.return_value = "azure" - # Create mock credentials provider that can return different tokens for different calls - mock_creds_provider = MagicMock() - - # First call returns initial_token, second call returns fresh_token + # Create the initial header factory initial_headers = {"Authorization": "Bearer initial_token"} - fresh_headers = {"Authorization": "Bearer fresh_token"} - - # Set up initial header factory initial_header_factory = MagicMock() initial_header_factory.return_value = initial_headers - # Set up fresh header factory for second call + # Create the fresh header factory for later use + fresh_headers = {"Authorization": "Bearer fresh_token"} fresh_header_factory = MagicMock() fresh_header_factory.return_value = fresh_headers - # Configure the mock to return factories - mock_creds_provider.side_effect = [initial_header_factory, fresh_header_factory] + # Create the credentials provider that will return the header factory + mock_creds_provider = MagicMock() + mock_creds_provider.return_value = initial_header_factory # Set up the token 
federation provider federation_provider = DatabricksTokenFederationProvider( @@ -166,16 +162,18 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t # Reset the mocks to track the next call mock_exchange_token.reset_mock() - mock_creds_provider.reset_mock() - mock_creds_provider.return_value = fresh_header_factory # Now simulate an approaching expiry - near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 60) + near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1) federation_provider.last_exchanged_token = Token( "exchanged_token_1", "Bearer", expiry=near_expiry ) federation_provider.last_external_token = "initial_token" + # For the refresh call, we need the credentials provider to return a fresh token + # Update the mock to return fresh_header_factory for the second call + mock_creds_provider.return_value = fresh_header_factory + # Set up the mock to return a different token for the refresh mock_exchange_token.return_value = Token( "exchanged_token_2", "Bearer", expiry=future_time @@ -184,8 +182,7 @@ def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_t # Make a second call which should trigger refresh headers = headers_factory() - # Verify a fresh token was requested from the credentials provider - # and the exchange was performed with the fresh token + # Verify the exchange was performed with the fresh token mock_exchange_token.assert_called_once_with("fresh_token", "azure") # Verify the headers contain the new token From aeeca66dfe4f6d39be1469fae78a2cfdf26636a6 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 07:04:46 +0000 Subject: [PATCH 29/46] fmt --- src/databricks/sql/auth/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 47a43db13..c679879f2 100755 --- a/src/databricks/sql/auth/auth.py +++ 
b/src/databricks/sql/auth/auth.py @@ -52,7 +52,7 @@ def get_auth_provider(cfg: ClientContext): # TODO: In a future refactoring, token federation should be a feature that wraps # any auth provider, not a separate auth type. The code below treats it as an auth type # for backward compatibility, but this approach will be revised. - + if cfg.credentials_provider: # If token federation is enabled and credentials provider is provided, # wrap the credentials provider with DatabricksTokenFederationProvider From ae286499a68909c6496030d69a97f4814947c017 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 9 May 2025 07:10:54 +0000 Subject: [PATCH 30/46] remove idp detection --- src/databricks/sql/auth/token_federation.py | 76 +++++------- tests/unit/test_token_federation.py | 130 ++++++++++++-------- 2 files changed, 108 insertions(+), 98 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index a0035e680..7f3f147de 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -39,7 +39,15 @@ def __init__( self.access_token = access_token self.token_type = token_type self.refresh_token = refresh_token - self.expiry = expiry or datetime.now(tz=timezone.utc) + + # Ensure expiry is timezone-aware + if expiry is None: + self.expiry = datetime.now(tz=timezone.utc) + elif expiry.tzinfo is None: + # Convert naive datetime to aware datetime + self.expiry = expiry.replace(tzinfo=timezone.utc) + else: + self.expiry = expiry def is_expired(self) -> bool: """Check if the token is expired.""" @@ -129,7 +137,9 @@ def get_headers() -> Dict[str, str]: and self.last_exchanged_token.needs_refresh() ): # The token is approaching expiry, try to refresh - logger.debug("Exchanged token approaching expiry, refreshing...") + logger.info( + "Exchanged token approaching expiry, refreshing with fresh external token..." 
+ ) return self._refresh_token(access_token, token_type) # Parse the JWT to get claims @@ -138,14 +148,16 @@ def get_headers() -> Dict[str, str]: # Check if token needs to be exchanged if self._is_same_host(token_claims.get("iss", ""), self.hostname): # Token is from the same host, no need to exchange + logger.debug("Token from same host, no exchange needed") return self.external_provider_headers else: # Token is from a different host, need to exchange + logger.debug("Token from different host, attempting exchange") return self._try_token_exchange_or_fallback( access_token, token_type ) except Exception as e: - logger.error(f"Failed to process token: {str(e)}") + logger.error(f"Error processing token: {str(e)}") # Fall back to original headers in case of error return self.external_provider_headers @@ -238,25 +250,6 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: logger.error(f"Failed to parse JWT: {str(e)}") raise - def _detect_idp_from_claims(self, token_claims: Dict[str, Any]) -> str: - """ - Detect the identity provider type from token claims. - - This can be used to adjust token exchange parameters based on the IdP. 
- """ - issuer = token_claims.get("iss", "") - - if "login.microsoftonline.com" in issuer or "sts.windows.net" in issuer: - return "azure" - elif "token.actions.githubusercontent.com" in issuer: - return "github" - elif "accounts.google.com" in issuer: - return "google" - elif "cognito-idp" in issuer and "amazonaws.com" in issuer: - return "aws" - else: - return "unknown" - def _is_same_host(self, url1: str, url2: str) -> bool: """Check if two URLs have the same host.""" try: @@ -283,7 +276,9 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: The headers with the fresh token """ try: - logger.info("Refreshing expired token by getting a new external token") + logger.info( + "Refreshing token using proactive approach (getting fresh external token first)" + ) # Get a fresh token from the underlying credentials provider # instead of reusing the same access_token @@ -303,14 +298,14 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: fresh_token_type = parts[0] fresh_access_token = parts[1] - logger.debug("Got fresh external token") - - # Now process the fresh token - token_claims = self._parse_jwt_claims(fresh_access_token) - idp_type = self._detect_idp_from_claims(token_claims) + # Check if we got the same token back + if fresh_access_token == access_token: + logger.warning( + "Credentials provider returned the same token during refresh" + ) # Perform a new token exchange with the fresh token - refreshed_token = self._exchange_token(fresh_access_token, idp_type) + refreshed_token = self._exchange_token(fresh_access_token) # Update the stored token self.last_exchanged_token = refreshed_token @@ -321,6 +316,10 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: headers[ "Authorization" ] = f"{refreshed_token.token_type} {refreshed_token.access_token}" + + logger.info( + f"Successfully refreshed token, new expiry: {refreshed_token.expiry}" + ) return headers except Exception as 
e: logger.error( @@ -334,12 +333,8 @@ def _try_token_exchange_or_fallback( ) -> Dict[str, str]: """Try to exchange the token or fall back to the original token.""" try: - # Parse the token to get claims for IdP-specific adjustments - token_claims = self._parse_jwt_claims(access_token) - idp_type = self._detect_idp_from_claims(token_claims) - # Exchange the token - exchanged_token = self._exchange_token(access_token, idp_type) + exchanged_token = self._exchange_token(access_token) # Store the exchanged token for potential refresh later self.last_exchanged_token = exchanged_token @@ -358,13 +353,12 @@ def _try_token_exchange_or_fallback( # Fall back to original headers return self.external_provider_headers - def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token: + def _exchange_token(self, access_token: str) -> Token: """ Exchange an external token for a Databricks token. Args: access_token: The external token to exchange - idp_type: The detected identity provider type (azure, github, etc.) 
Returns: A Token object containing the exchanged token @@ -384,14 +378,6 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token if self.identity_federation_client_id: params["client_id"] = self.identity_federation_client_id - # Make IdP-specific adjustments - if idp_type == "azure": - # For Azure AD, add special handling if needed - pass - elif idp_type == "github": - # For GitHub Actions, add special handling if needed - pass - # Set up headers headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"} @@ -441,7 +427,7 @@ def _exchange_token(self, access_token: str, idp_type: str = "unknown") -> Token return token except RequestException as e: logger.error(f"Failed to perform token exchange: {str(e)}") - raise + raise ValueError(f"Request error during token exchange: {str(e)}") class SimpleCredentialsProvider(CredentialsProvider): diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 126b7d888..78ffc9e2a 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -14,7 +14,7 @@ DatabricksTokenFederationProvider, SimpleCredentialsProvider, create_token_federation_provider, - TOKEN_REFRESH_BUFFER_SECONDS + TOKEN_REFRESH_BUFFER_SECONDS, ) @@ -27,45 +27,51 @@ def test_token_initialization(self): self.assertEqual(token.access_token, "access_token_value") self.assertEqual(token.token_type, "Bearer") self.assertEqual(token.refresh_token, "refresh_token_value") - + def test_token_is_expired(self): """Test Token is_expired method.""" # Token with expiry in the past past = datetime.now(tz=timezone.utc) - timedelta(hours=1) token = Token("access_token", "Bearer", expiry=past) self.assertTrue(token.is_expired()) - + # Token with expiry in the future future = datetime.now(tz=timezone.utc) + timedelta(hours=1) token = Token("access_token", "Bearer", expiry=future) self.assertFalse(token.is_expired()) - + def test_token_needs_refresh(self): """Test Token 
needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS.""" # Token with expiry in the past past = datetime.now(tz=timezone.utc) - timedelta(hours=1) token = Token("access_token", "Bearer", expiry=past) self.assertTrue(token.needs_refresh()) - + # Token with expiry in the near future (within refresh buffer) - near_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1) + near_future = datetime.now(tz=timezone.utc) + timedelta( + seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 + ) token = Token("access_token", "Bearer", expiry=near_future) self.assertTrue(token.needs_refresh()) - + # Token with expiry far in the future - far_future = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10) + far_future = datetime.now(tz=timezone.utc) + timedelta( + seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10 + ) token = Token("access_token", "Bearer", expiry=far_future) self.assertFalse(token.needs_refresh()) class TestSimpleCredentialsProvider(unittest.TestCase): """Tests for the SimpleCredentialsProvider class.""" - + def test_simple_credentials_provider(self): """Test SimpleCredentialsProvider.""" - provider = SimpleCredentialsProvider("token_value", "Bearer", "custom_auth_type") + provider = SimpleCredentialsProvider( + "token_value", "Bearer", "custom_auth_type" + ) self.assertEqual(provider.auth_type(), "custom_auth_type") - + header_factory = provider() headers = header_factory() self.assertEqual(headers, {"Authorization": "Bearer token_value"}) @@ -73,7 +79,7 @@ def test_simple_credentials_provider(self): class TestTokenFederationProvider(unittest.TestCase): """Tests for the DatabricksTokenFederationProvider class.""" - + def test_host_property(self): """Test the host property of DatabricksTokenFederationProvider.""" creds_provider = SimpleCredentialsProvider("token") @@ -82,130 +88,148 @@ def test_host_property(self): ) self.assertEqual(federation_provider.host, "example.com") 
self.assertEqual(federation_provider.hostname, "example.com") - - @patch('databricks.sql.auth.token_federation.requests.get') - @patch('databricks.sql.auth.token_federation.get_oauth_endpoints') + + @patch("databricks.sql.auth.token_federation.requests.get") + @patch("databricks.sql.auth.token_federation.get_oauth_endpoints") def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get): """Test _init_oidc_discovery method.""" # Mock the get_oauth_endpoints function mock_endpoints = MagicMock() - mock_endpoints.get_openid_config_url.return_value = "https://example.com/openid-config" + mock_endpoints.get_openid_config_url.return_value = ( + "https://example.com/openid-config" + ) mock_get_endpoints.return_value = mock_endpoints - + # Mock the requests.get response mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json.return_value = {"token_endpoint": "https://example.com/token"} + mock_response.json.return_value = { + "token_endpoint": "https://example.com/token" + } mock_requests_get.return_value = mock_response - + # Create the provider creds_provider = SimpleCredentialsProvider("token") federation_provider = DatabricksTokenFederationProvider( creds_provider, "example.com", "client_id" ) - + # Call the method federation_provider._init_oidc_discovery() - + # Check if the token endpoint was set correctly - self.assertEqual(federation_provider.token_endpoint, "https://example.com/token") - + self.assertEqual( + federation_provider.token_endpoint, "https://example.com/token" + ) + # Test fallback when discovery fails mock_requests_get.side_effect = Exception("Connection error") federation_provider.token_endpoint = None federation_provider._init_oidc_discovery() - self.assertEqual(federation_provider.token_endpoint, "https://example.com/oidc/v1/token") - - @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims') - 
@patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token') - @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host') - @patch('databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._detect_idp_from_claims') - def test_token_refresh(self, mock_detect_idp, mock_is_same_host, mock_exchange_token, mock_parse_jwt): + self.assertEqual( + federation_provider.token_endpoint, "https://example.com/oidc/v1/token" + ) + + @patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" + ) + @patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" + ) + @patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + ) + def test_token_refresh( + self, mock_is_same_host, mock_exchange_token, mock_parse_jwt + ): """Test token refresh functionality for approaching expiry.""" # Set up mocks - mock_parse_jwt.return_value = {"iss": "https://login.microsoftonline.com/tenant"} + mock_parse_jwt.return_value = { + "iss": "https://login.microsoftonline.com/tenant" + } mock_is_same_host.return_value = False - mock_detect_idp.return_value = "azure" - + # Create the initial header factory initial_headers = {"Authorization": "Bearer initial_token"} initial_header_factory = MagicMock() initial_header_factory.return_value = initial_headers - + # Create the fresh header factory for later use fresh_headers = {"Authorization": "Bearer fresh_token"} fresh_header_factory = MagicMock() fresh_header_factory.return_value = fresh_headers - + # Create the credentials provider that will return the header factory mock_creds_provider = MagicMock() mock_creds_provider.return_value = initial_header_factory - + # Set up the token federation provider federation_provider = DatabricksTokenFederationProvider( mock_creds_provider, "example.com", "client_id" ) - + # Mock the token exchange to return a known 
token future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) mock_exchange_token.return_value = Token( "exchanged_token_1", "Bearer", expiry=future_time ) - + # First call to get initial headers and token - this should trigger an exchange headers_factory = federation_provider() headers = headers_factory() - + # Verify the exchange happened with the initial token - mock_exchange_token.assert_called_with("initial_token", "azure") + mock_exchange_token.assert_called_with("initial_token") self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1") - + # Reset the mocks to track the next call mock_exchange_token.reset_mock() - + # Now simulate an approaching expiry - near_expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1) + near_expiry = datetime.now(tz=timezone.utc) + timedelta( + seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 + ) federation_provider.last_exchanged_token = Token( "exchanged_token_1", "Bearer", expiry=near_expiry ) federation_provider.last_external_token = "initial_token" - + # For the refresh call, we need the credentials provider to return a fresh token # Update the mock to return fresh_header_factory for the second call mock_creds_provider.return_value = fresh_header_factory - + # Set up the mock to return a different token for the refresh mock_exchange_token.return_value = Token( "exchanged_token_2", "Bearer", expiry=future_time ) - + # Make a second call which should trigger refresh headers = headers_factory() - + # Verify the exchange was performed with the fresh token - mock_exchange_token.assert_called_once_with("fresh_token", "azure") - + mock_exchange_token.assert_called_once_with("fresh_token") + # Verify the headers contain the new token self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2") class TestTokenFederationFactory(unittest.TestCase): """Tests for the token federation factory function.""" - + def test_create_token_federation_provider(self): """Test 
create_token_federation_provider function.""" provider = create_token_federation_provider( "token_value", "example.com", "client_id", "Bearer" ) - + self.assertIsInstance(provider, DatabricksTokenFederationProvider) self.assertEqual(provider.hostname, "example.com") self.assertEqual(provider.identity_federation_client_id, "client_id") - + # Test that the underlying credentials provider was set up correctly self.assertEqual(provider.credentials_provider.token, "token_value") self.assertEqual(provider.credentials_provider.token_type, "Bearer") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 541e82fdd9b8e4e9814b67e4a48915faadca785b Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Sun, 11 May 2025 06:36:49 +0000 Subject: [PATCH 31/46] fmt --- src/databricks/sql/auth/auth.py | 62 ++- src/databricks/sql/auth/token_federation.py | 324 ++++++--------- tests/token_federation/github_oidc_test.py | 37 +- tests/unit/test_token_federation.py | 436 ++++++++++---------- 4 files changed, 421 insertions(+), 438 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index c679879f2..f1f5543f7 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -15,6 +15,7 @@ class AuthType(Enum): AZURE_OAUTH = "azure-oauth" # TODO: Token federation should be a feature that works with different auth types, # not an auth type itself. This will be refactored in a future change. + # We will add a use_token_federation flag that can be used with any auth type. TOKEN_FEDERATION = "token-federation" # other supported types (access_token) can be inferred # we can add more types as needed later @@ -49,10 +50,28 @@ def __init__( def get_auth_provider(cfg: ClientContext): - # TODO: In a future refactoring, token federation should be a feature that wraps - # any auth provider, not a separate auth type. 
The code below treats it as an auth type - # for backward compatibility, but this approach will be revised. - + """ + Get an appropriate auth provider based on the provided configuration. + + Token Federation Support: + ----------------------- + Currently, token federation is implemented as a separate auth type, but the goal is to + refactor it as a feature that can work with any auth type. The current implementation + is maintained for backward compatibility while the refactoring is planned. + + Future refactoring will introduce a `use_token_federation` flag that can be combined + with any auth type to enable token federation. + + Args: + cfg: The client context containing configuration parameters + + Returns: + An appropriate AuthProvider instance + + Raises: + RuntimeError: If no valid authentication settings are provided + """ + # If credentials_provider is explicitly provided if cfg.credentials_provider: # If token federation is enabled and credentials provider is provided, # wrap the credentials provider with DatabricksTokenFederationProvider @@ -73,13 +92,15 @@ def get_auth_provider(cfg: ClientContext): # If we don't have a credentials provider but have token federation auth type with access token if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: - # If only access_token is provided with token federation, use create_token_federation_provider + # Create a simple credentials provider and wrap it with token federation provider from databricks.sql.auth.token_federation import ( - create_token_federation_provider, + DatabricksTokenFederationProvider, + SimpleCredentialsProvider, ) - federation_provider = create_token_federation_provider( - cfg.access_token, cfg.hostname, cfg.identity_federation_client_id + simple_provider = SimpleCredentialsProvider(cfg.access_token) + federation_provider = DatabricksTokenFederationProvider( + simple_provider, cfg.hostname, cfg.identity_federation_client_id ) return 
ExternalAuthProvider(federation_provider) @@ -140,6 +161,27 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): def get_python_sql_connector_auth_provider(hostname: str, **kwargs): + """ + Get an auth provider for the Python SQL connector. + + This function is the main entry point for authentication in the SQL connector. + It processes the parameters and creates an appropriate auth provider. + + TODO: Future refactoring needed: + 1. Add a use_token_federation flag that can be combined with any auth type + 2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility + 3. Create a token federation wrapper that can wrap any existing auth provider + + Args: + hostname: The Databricks server hostname + **kwargs: Additional configuration parameters + + Returns: + An appropriate AuthProvider instance + + Raises: + ValueError: If username/password authentication is attempted (no longer supported) + """ auth_type = kwargs.get("auth_type") (client_id, redirect_port_range) = get_client_id_and_redirect_port( auth_type == AuthType.AZURE_OAUTH.value @@ -150,10 +192,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): "Please use OAuth or access token instead." 
) - # TODO: Future refactoring needed: - # - Add a use_token_federation flag that can be combined with any auth type - # - Remove TOKEN_FEDERATION as an auth_type and properly handle the underlying auth type - # - Maintain backward compatibility during transition cfg = ClientContext( hostname=normalize_host_name(hostname), auth_type=auth_type, diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 7f3f147de..e92f9ccb5 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -64,8 +64,8 @@ def __str__(self) -> str: class DatabricksTokenFederationProvider(CredentialsProvider): """ - Implementation of the Credential Provider that exchanges the third party access token - for a Databricks InHouse Token. This class exchanges the access token if the issued token + Implementation of the Credential Provider that exchanges a third party access token + for a Databricks token. It exchanges the token only if the issued token is not from the same host as the Databricks host. 
""" @@ -88,8 +88,6 @@ def __init__( self.identity_federation_client_id = identity_federation_client_id self.external_provider_headers: Dict[str, str] = {} self.token_endpoint: Optional[str] = None - self.idp_endpoints = None - self.openid_config = None self.last_exchanged_token: Optional[Token] = None self.last_external_token: Optional[str] = None @@ -120,16 +118,15 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: self._init_oidc_discovery() def get_headers() -> Dict[str, str]: - # Get headers from the underlying provider - self.external_provider_headers = header_factory() + try: + # Get headers from the underlying provider + self.external_provider_headers = header_factory() - # Extract the token from the headers - token_info = self._extract_token_info_from_header( - self.external_provider_headers - ) - token_type, access_token = token_info + # Extract the token from the headers + token_type, access_token = self._extract_token_info_from_header( + self.external_provider_headers + ) - try: # Check if we need to refresh the token if ( self.last_exchanged_token @@ -171,40 +168,35 @@ def _init_oidc_discovery(self): try: # Use the existing OIDC discovery mechanism use_azure_auth = infer_cloud_from_host(self.hostname) == "azure" - self.idp_endpoints = get_oauth_endpoints(self.hostname, use_azure_auth) + idp_endpoints = get_oauth_endpoints(self.hostname, use_azure_auth) - if self.idp_endpoints: + if idp_endpoints: # Get the OpenID configuration URL - openid_config_url = self.idp_endpoints.get_openid_config_url( + openid_config_url = idp_endpoints.get_openid_config_url( self.hostname ) # Fetch the OpenID configuration response = requests.get(openid_config_url) if response.status_code == 200: - self.openid_config = response.json() + openid_config = response.json() # Extract token endpoint from OpenID config - self.token_endpoint = self.openid_config.get("token_endpoint") + self.token_endpoint = openid_config.get("token_endpoint") logger.info(f"Discovered token 
endpoint: {self.token_endpoint}") else: logger.warning( f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}" ) - - # Fallback to default token endpoint if discovery fails - if not self.token_endpoint: - hostname = self._format_hostname(self.hostname) - self.token_endpoint = f"{hostname}oidc/v1/token" - logger.info(f"Using default token endpoint: {self.token_endpoint}") except Exception as e: logger.warning( f"OIDC discovery failed: {str(e)}. Using default token endpoint." ) + + # Fallback to default token endpoint if discovery fails + if not self.token_endpoint: hostname = self._format_hostname(self.hostname) self.token_endpoint = f"{hostname}oidc/v1/token" - logger.info( - f"Using default token endpoint after error: {self.token_endpoint}" - ) + logger.info(f"Using default token endpoint: {self.token_endpoint}") def _format_hostname(self, hostname: str) -> str: """Format hostname to ensure it has proper https:// prefix and trailing slash.""" @@ -248,111 +240,107 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: return json.loads(decoded) except Exception as e: logger.error(f"Failed to parse JWT: {str(e)}") - raise + return {} def _is_same_host(self, url1: str, url2: str) -> bool: - """Check if two URLs have the same host.""" + """ + Check if two URLs have the same host. 
+ + Args: + url1: First URL + url2: Second URL + + Returns: + bool: True if the hosts match, False otherwise + """ try: - host1 = urlparse(url1).netloc - host2 = urlparse(url2).netloc - # If host1 is empty, it's not a valid URL, so we return False - if not host1: - return False - return host1 == host2 + # Parse the URLs + parsed1 = urlparse(url1) + parsed2 = urlparse(url2) + + # Compare the hostnames + return parsed1.netloc.lower() == parsed2.netloc.lower() except Exception as e: - logger.error(f"Failed to parse URLs: {str(e)}") + logger.warning(f"Error comparing hosts: {str(e)}") return False def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: """ - Attempt to refresh an expired token by first getting a fresh external token - and then exchanging it for a new Databricks token. + Refresh the exchanged token by getting a fresh external token. Args: - access_token: The original external access token (will be replaced) - token_type: The token type (Bearer, etc.) + access_token: The external access token + token_type: The token type (usually "Bearer") Returns: - The headers with the fresh token + Dict[str, str]: Headers with the refreshed token """ try: - logger.info( - "Refreshing token using proactive approach (getting fresh external token first)" - ) - - # Get a fresh token from the underlying credentials provider - # instead of reusing the same access_token - fresh_headers = self.credentials_provider()() - - # Extract the fresh token from the headers - auth_header = fresh_headers.get("Authorization", "") - if not auth_header: - logger.error("No Authorization header in fresh headers") - return self.external_provider_headers - - parts = auth_header.split(" ", 1) - if len(parts) != 2: - logger.error(f"Invalid Authorization header format: {auth_header}") - return self.external_provider_headers - - fresh_token_type = parts[0] - fresh_access_token = parts[1] - - # Check if we got the same token back - if fresh_access_token == access_token: - 
logger.warning( - "Credentials provider returned the same token during refresh" - ) - - # Perform a new token exchange with the fresh token - refreshed_token = self._exchange_token(fresh_access_token) - - # Update the stored token - self.last_exchanged_token = refreshed_token - self.last_external_token = fresh_access_token - - # Create new headers with the refreshed token - headers = dict(fresh_headers) # Use the fresh headers as base - headers[ - "Authorization" - ] = f"{refreshed_token.token_type} {refreshed_token.access_token}" + # Exchange the token for a new one + exchanged_token = self._exchange_token(access_token) + self.last_exchanged_token = exchanged_token + self.last_external_token = access_token - logger.info( - f"Successfully refreshed token, new expiry: {refreshed_token.expiry}" - ) - return headers + # Update the headers with the new token + return {"Authorization": f"{exchanged_token.token_type} {exchanged_token.access_token}"} except Exception as e: - logger.error( - f"Token refresh failed, falling back to original token: {str(e)}" - ) - # If refresh fails, fall back to the original headers + logger.error(f"Token refresh failed: {str(e)}, falling back to original token") return self.external_provider_headers def _try_token_exchange_or_fallback( self, access_token: str, token_type: str ) -> Dict[str, str]: - """Try to exchange the token or fall back to the original token.""" + """ + Attempt to exchange the token or fall back to the original token if exchange fails. 
+ + Args: + access_token: The external access token + token_type: The token type (usually "Bearer") + + Returns: + Dict[str, str]: Headers with either the exchanged token or the original token + """ try: - # Exchange the token exchanged_token = self._exchange_token(access_token) - - # Store the exchanged token for potential refresh later self.last_exchanged_token = exchanged_token self.last_external_token = access_token - # Create new headers with the exchanged token - headers = dict(self.external_provider_headers) - headers[ - "Authorization" - ] = f"{exchanged_token.token_type} {exchanged_token.access_token}" - return headers + return {"Authorization": f"{exchanged_token.token_type} {exchanged_token.access_token}"} except Exception as e: - logger.error( - f"Token exchange failed, falling back to using external token: {str(e)}" - ) - # Fall back to original headers + logger.warning(f"Token exchange failed: {str(e)}, falling back to original token") return self.external_provider_headers + def _send_token_exchange_request(self, token_exchange_data: Dict[str, str]) -> Dict[str, Any]: + """ + Send the token exchange request to the token endpoint. + + Args: + token_exchange_data: The data to send in the request + + Returns: + Dict[str, Any]: The parsed JSON response + + Raises: + Exception: If the request fails + """ + if not self.token_endpoint: + raise ValueError("Token endpoint not initialized") + + headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"} + + response = requests.post( + self.token_endpoint, + data=token_exchange_data, + headers=headers + ) + + if response.status_code != 200: + raise ValueError( + f"Token exchange failed with status code {response.status_code}: {response.text}" + ) + + return response.json() + def _exchange_token(self, access_token: str) -> Token: """ Exchange an external token for a Databricks token. 
@@ -361,114 +349,74 @@ def _exchange_token(self, access_token: str) -> Token: access_token: The external token to exchange Returns: - A Token object containing the exchanged token - """ - if not self.token_endpoint: - self._init_oidc_discovery() - - # Ensure token_endpoint is set - if not self.token_endpoint: - raise ValueError("Token endpoint could not be determined") + Token: The exchanged token with expiry information - # Create request parameters - params = dict(TOKEN_EXCHANGE_PARAMS) - params["subject_token"] = access_token + Raises: + Exception: If token exchange fails + """ + # Prepare the request data + token_exchange_data = dict(TOKEN_EXCHANGE_PARAMS) + token_exchange_data["subject_token"] = access_token - # Add client ID if available + # Add client_id if provided if self.identity_federation_client_id: - params["client_id"] = self.identity_federation_client_id - - # Set up headers - headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"} + token_exchange_data["client_id"] = self.identity_federation_client_id try: - # Make the token exchange request - response = requests.post(self.token_endpoint, data=params, headers=headers) - response.raise_for_status() - - # Parse the response - resp_data = response.json() - - # Create a token from the response - token = Token( - access_token=resp_data.get("access_token"), - token_type=resp_data.get("token_type", "Bearer"), - refresh_token=resp_data.get("refresh_token", ""), - ) - - # Set expiry time from the response's expires_in field if available - # This is the standard OAuth approach + # Send the token exchange request + resp_data = self._send_token_exchange_request(token_exchange_data) + + # Extract token information + new_access_token = resp_data.get("access_token") + if not new_access_token: + raise ValueError("No access token in exchange response") + + token_type = resp_data.get("token_type", "Bearer") + refresh_token = resp_data.get("refresh_token", "") + + # Parse expiry time from 
token claims if possible + expiry = datetime.now(tz=timezone.utc) + + # First try to get expiry from the response's expires_in field if "expires_in" in resp_data and resp_data["expires_in"]: try: - # Calculate expiry by adding expires_in seconds to current time - expires_in_seconds = int(resp_data["expires_in"]) - token.expiry = datetime.now(tz=timezone.utc) + timedelta( - seconds=expires_in_seconds - ) - logger.debug(f"Token expiry set from expires_in: {token.expiry}") + expires_in = int(resp_data["expires_in"]) + expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in) except (ValueError, TypeError) as e: - logger.warning( - f"Could not parse expires_in from response: {str(e)}" - ) - - # If expires_in wasn't available, try to parse expiry from the token JWT - if token.expiry == datetime.now(tz=timezone.utc): - try: - token_claims = self._parse_jwt_claims(token.access_token) - exp_time = token_claims.get("exp") - if exp_time: - token.expiry = datetime.fromtimestamp(exp_time, tz=timezone.utc) - logger.debug( - f"Token expiry set from JWT exp claim: {token.expiry}" - ) - except Exception as e: - logger.warning(f"Could not parse expiry from token: {str(e)}") - - return token - except RequestException as e: - logger.error(f"Failed to perform token exchange: {str(e)}") - raise ValueError(f"Request error during token exchange: {str(e)}") + logger.warning(f"Invalid expires_in value: {str(e)}") + + # If that didn't work, try to parse JWT claims for expiry + if expiry == datetime.now(tz=timezone.utc): + token_claims = self._parse_jwt_claims(new_access_token) + if "exp" in token_claims: + try: + exp_timestamp = int(token_claims["exp"]) + expiry = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc) + except (ValueError, TypeError) as e: + logger.warning(f"Invalid exp claim in token: {str(e)}") + + return Token(new_access_token, token_type, refresh_token, expiry) + + except Exception as e: + logger.error(f"Token exchange failed: {str(e)}") + raise class 
SimpleCredentialsProvider(CredentialsProvider): - """A simple credentials provider that returns fixed headers.""" + """A simple credentials provider that returns a fixed token.""" def __init__( self, token: str, token_type: str = "Bearer", auth_type_value: str = "token" ): self.token = token self.token_type = token_type - self._auth_type = auth_type_value + self.auth_type_value = auth_type_value def auth_type(self) -> str: - return self._auth_type + return self.auth_type_value def __call__(self, *args, **kwargs) -> HeaderFactory: def get_headers() -> Dict[str, str]: return {"Authorization": f"{self.token_type} {self.token}"} return get_headers - - -def create_token_federation_provider( - token: str, - hostname: str, - identity_federation_client_id: Optional[str] = None, - token_type: str = "Bearer", -) -> DatabricksTokenFederationProvider: - """ - Create a token federation provider using a simple token. - - Args: - token: The token to use - hostname: The Databricks hostname - identity_federation_client_id: Optional client ID for identity federation - token_type: The token type (default: "Bearer") - - Returns: - A DatabricksTokenFederationProvider - """ - provider = SimpleCredentialsProvider(token, token_type) - return DatabricksTokenFederationProvider( - provider, hostname, identity_federation_client_id - ) diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py index 79fc40b34..e1c65d632 100755 --- a/tests/token_federation/github_oidc_test.py +++ b/tests/token_federation/github_oidc_test.py @@ -14,6 +14,8 @@ import base64 import logging from databricks import sql +import jwt + logging.basicConfig( @@ -34,20 +36,10 @@ def decode_jwt(token): dict: The decoded token claims or None if decoding fails """ try: - parts = token.split(".") - if len(parts) != 3: - raise ValueError("Invalid JWT format") - - payload = parts[1] - # Add padding if needed - padding = '=' * (4 - len(payload) % 4) - payload += padding - - decoded = 
base64.b64decode(payload) - return json.loads(decoded) + return jwt.decode(token, options={"verify_signature": False}) except Exception as e: - logger.error(f"Failed to decode token: {str(e)}") - return None + logger.error(f"Failed to decode token with PyJWT: {str(e)}") + return {} def get_environment_variables(): @@ -56,23 +48,12 @@ def get_environment_variables(): Returns: tuple: (github_token, host, http_path, identity_federation_client_id) - - Raises: - SystemExit: If any required environment variable is missing """ github_token = os.environ.get("OIDC_TOKEN") - if not github_token: - logger.error("GitHub OIDC token not available") - sys.exit(1) - host = os.environ.get("DATABRICKS_HOST_FOR_TF") http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") - if not host or not http_path: - logger.error("Missing Databricks connection parameters") - sys.exit(1) - return github_token, host, http_path, identity_federation_client_id @@ -146,6 +127,14 @@ def main(): # Get environment variables github_token, host, http_path, identity_federation_client_id = get_environment_variables() + if not github_token: + logger.error("Missing GitHub OIDC token (OIDC_TOKEN)") + sys.exit(1) + + if not host or not http_path: + logger.error("Missing Databricks connection parameters (DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)") + sys.exit(1) + # Display token claims claims = decode_jwt(github_token) display_token_info(claims) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 78ffc9e2a..1ba550a6b 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -4,8 +4,8 @@ Unit tests for token federation functionality in the Databricks SQL connector. 
""" -import unittest -from unittest.mock import patch, MagicMock +import pytest +from unittest.mock import MagicMock, patch import json from datetime import datetime, timezone, timedelta @@ -18,218 +18,226 @@ ) -class TestToken(unittest.TestCase): - """Tests for the Token class.""" - - def test_token_initialization(self): - """Test Token initialization.""" - token = Token("access_token_value", "Bearer", "refresh_token_value") - self.assertEqual(token.access_token, "access_token_value") - self.assertEqual(token.token_type, "Bearer") - self.assertEqual(token.refresh_token, "refresh_token_value") - - def test_token_is_expired(self): - """Test Token is_expired method.""" - # Token with expiry in the past - past = datetime.now(tz=timezone.utc) - timedelta(hours=1) - token = Token("access_token", "Bearer", expiry=past) - self.assertTrue(token.is_expired()) - - # Token with expiry in the future - future = datetime.now(tz=timezone.utc) + timedelta(hours=1) - token = Token("access_token", "Bearer", expiry=future) - self.assertFalse(token.is_expired()) - - def test_token_needs_refresh(self): - """Test Token needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS.""" - # Token with expiry in the past - past = datetime.now(tz=timezone.utc) - timedelta(hours=1) - token = Token("access_token", "Bearer", expiry=past) - self.assertTrue(token.needs_refresh()) - - # Token with expiry in the near future (within refresh buffer) - near_future = datetime.now(tz=timezone.utc) + timedelta( - seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 - ) - token = Token("access_token", "Bearer", expiry=near_future) - self.assertTrue(token.needs_refresh()) - - # Token with expiry far in the future - far_future = datetime.now(tz=timezone.utc) + timedelta( - seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10 - ) - token = Token("access_token", "Bearer", expiry=far_future) - self.assertFalse(token.needs_refresh()) - - -class TestSimpleCredentialsProvider(unittest.TestCase): - """Tests for the 
SimpleCredentialsProvider class.""" - - def test_simple_credentials_provider(self): - """Test SimpleCredentialsProvider.""" - provider = SimpleCredentialsProvider( - "token_value", "Bearer", "custom_auth_type" - ) - self.assertEqual(provider.auth_type(), "custom_auth_type") - - header_factory = provider() - headers = header_factory() - self.assertEqual(headers, {"Authorization": "Bearer token_value"}) - - -class TestTokenFederationProvider(unittest.TestCase): - """Tests for the DatabricksTokenFederationProvider class.""" - - def test_host_property(self): - """Test the host property of DatabricksTokenFederationProvider.""" - creds_provider = SimpleCredentialsProvider("token") - federation_provider = DatabricksTokenFederationProvider( - creds_provider, "example.com", "client_id" - ) - self.assertEqual(federation_provider.host, "example.com") - self.assertEqual(federation_provider.hostname, "example.com") - - @patch("databricks.sql.auth.token_federation.requests.get") - @patch("databricks.sql.auth.token_federation.get_oauth_endpoints") - def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get): - """Test _init_oidc_discovery method.""" - # Mock the get_oauth_endpoints function - mock_endpoints = MagicMock() - mock_endpoints.get_openid_config_url.return_value = ( - "https://example.com/openid-config" - ) - mock_get_endpoints.return_value = mock_endpoints - - # Mock the requests.get response - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "token_endpoint": "https://example.com/token" - } - mock_requests_get.return_value = mock_response - - # Create the provider - creds_provider = SimpleCredentialsProvider("token") - federation_provider = DatabricksTokenFederationProvider( - creds_provider, "example.com", "client_id" - ) - - # Call the method - federation_provider._init_oidc_discovery() - - # Check if the token endpoint was set correctly - self.assertEqual( - federation_provider.token_endpoint, 
"https://example.com/token" - ) - - # Test fallback when discovery fails - mock_requests_get.side_effect = Exception("Connection error") - federation_provider.token_endpoint = None - federation_provider._init_oidc_discovery() - self.assertEqual( - federation_provider.token_endpoint, "https://example.com/oidc/v1/token" - ) - - @patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" +# Tests for Token class +def test_token_initialization(): + """Test Token initialization.""" + token = Token("access_token_value", "Bearer", "refresh_token_value") + assert token.access_token == "access_token_value" + assert token.token_type == "Bearer" + assert token.refresh_token == "refresh_token_value" + + +def test_token_is_expired(): + """Test Token is_expired method.""" + # Token with expiry in the past + past = datetime.now(tz=timezone.utc) - timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=past) + assert token.is_expired() + + # Token with expiry in the future + future = datetime.now(tz=timezone.utc) + timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=future) + assert not token.is_expired() + + +def test_token_needs_refresh(): + """Test Token needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS.""" + # Token with expiry in the past + past = datetime.now(tz=timezone.utc) - timedelta(hours=1) + token = Token("access_token", "Bearer", expiry=past) + assert token.needs_refresh() + + # Token with expiry in the near future (within refresh buffer) + near_future = datetime.now(tz=timezone.utc) + timedelta( + seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 + ) + token = Token("access_token", "Bearer", expiry=near_future) + assert token.needs_refresh() + + # Token with expiry far in the future + far_future = datetime.now(tz=timezone.utc) + timedelta( + seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10 + ) + token = Token("access_token", "Bearer", expiry=far_future) + assert not token.needs_refresh() + + +# Tests 
for SimpleCredentialsProvider +def test_simple_credentials_provider(): + """Test SimpleCredentialsProvider.""" + provider = SimpleCredentialsProvider( + "token_value", "Bearer", "custom_auth_type" + ) + assert provider.auth_type() == "custom_auth_type" + + header_factory = provider() + headers = header_factory() + assert headers == {"Authorization": "Bearer token_value"} + + +# Tests for DatabricksTokenFederationProvider +def test_host_property(): + """Test the host property of DatabricksTokenFederationProvider.""" + creds_provider = SimpleCredentialsProvider("token") + federation_provider = DatabricksTokenFederationProvider( + creds_provider, "example.com", "client_id" + ) + assert federation_provider.host == "example.com" + assert federation_provider.hostname == "example.com" + + +@pytest.fixture +def mock_request_get(): + with patch("databricks.sql.auth.token_federation.requests.get") as mock: + yield mock + + +@pytest.fixture +def mock_get_oauth_endpoints(): + with patch("databricks.sql.auth.token_federation.get_oauth_endpoints") as mock: + yield mock + + +def test_init_oidc_discovery(mock_request_get, mock_get_oauth_endpoints): + """Test _init_oidc_discovery method.""" + # Mock the get_oauth_endpoints function + mock_endpoints = MagicMock() + mock_endpoints.get_openid_config_url.return_value = ( + "https://example.com/openid-config" + ) + mock_get_oauth_endpoints.return_value = mock_endpoints + + # Mock the requests.get response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "token_endpoint": "https://example.com/token" + } + mock_request_get.return_value = mock_response + + # Create the provider + creds_provider = SimpleCredentialsProvider("token") + federation_provider = DatabricksTokenFederationProvider( + creds_provider, "example.com", "client_id" + ) + + # Call the method + federation_provider._init_oidc_discovery() + + # Check if the token endpoint was set correctly + assert 
federation_provider.token_endpoint == "https://example.com/token" + + # Test fallback when discovery fails + mock_request_get.side_effect = Exception("Connection error") + federation_provider.token_endpoint = None + federation_provider._init_oidc_discovery() + assert federation_provider.token_endpoint == "https://example.com/oidc/v1/token" + + +@pytest.fixture +def mock_parse_jwt_claims(): + with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims") as mock: + yield mock + + +@pytest.fixture +def mock_exchange_token(): + with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token") as mock: + yield mock + + +@pytest.fixture +def mock_is_same_host(): + with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host") as mock: + yield mock + + +def test_token_refresh(mock_parse_jwt_claims, mock_exchange_token, mock_is_same_host): + """Test token refresh functionality for approaching expiry.""" + # Set up mocks + mock_parse_jwt_claims.return_value = { + "iss": "https://login.microsoftonline.com/tenant" + } + mock_is_same_host.return_value = False + + # Create the initial header factory + initial_headers = {"Authorization": "Bearer initial_token"} + initial_header_factory = MagicMock() + initial_header_factory.return_value = initial_headers + + # Create the fresh header factory for later use + fresh_headers = {"Authorization": "Bearer fresh_token"} + fresh_header_factory = MagicMock() + fresh_header_factory.return_value = fresh_headers + + # Create the credentials provider that will return the header factory + mock_creds_provider = MagicMock() + mock_creds_provider.return_value = initial_header_factory + + # Set up the token federation provider + federation_provider = DatabricksTokenFederationProvider( + mock_creds_provider, "example.com", "client_id" + ) + + # Mock the token exchange to return a known token + future_time = datetime.now(tz=timezone.utc) + 
timedelta(hours=1) + mock_exchange_token.return_value = Token( + "exchanged_token_1", "Bearer", expiry=future_time + ) + + # First call to get initial headers and token - this should trigger an exchange + headers_factory = federation_provider() + headers = headers_factory() + + # Verify the exchange happened with the initial token + mock_exchange_token.assert_called_with("initial_token") + assert headers["Authorization"] == "Bearer exchanged_token_1" + + # Reset the mocks to track the next call + mock_exchange_token.reset_mock() + + # Now simulate an approaching expiry + near_expiry = datetime.now(tz=timezone.utc) + timedelta( + seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 ) - @patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" + federation_provider.last_exchanged_token = Token( + "exchanged_token_1", "Bearer", expiry=near_expiry + ) + federation_provider.last_external_token = "initial_token" + + # For the refresh call, we need the credentials provider to return a fresh token + # Update the mock to return fresh_header_factory for the second call + mock_creds_provider.return_value = fresh_header_factory + + # Set up the mock to return a different token for the refresh + mock_exchange_token.return_value = Token( + "exchanged_token_2", "Bearer", expiry=future_time ) - @patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + + # Make a second call which should trigger refresh + headers = headers_factory() + + # Verify the exchange was performed with the fresh token + mock_exchange_token.assert_called_once_with("fresh_token") + + # Verify the headers contain the new token + assert headers["Authorization"] == "Bearer exchanged_token_2" + + +def test_create_token_federation_provider(): + """Test creation of a federation provider with a simple token provider.""" + # Create a simple provider + simple_provider = SimpleCredentialsProvider("token_value", "Bearer") + + # Create a federation 
provider with the simple provider + federation_provider = DatabricksTokenFederationProvider( + simple_provider, "example.com", "client_id" ) - def test_token_refresh( - self, mock_is_same_host, mock_exchange_token, mock_parse_jwt - ): - """Test token refresh functionality for approaching expiry.""" - # Set up mocks - mock_parse_jwt.return_value = { - "iss": "https://login.microsoftonline.com/tenant" - } - mock_is_same_host.return_value = False - - # Create the initial header factory - initial_headers = {"Authorization": "Bearer initial_token"} - initial_header_factory = MagicMock() - initial_header_factory.return_value = initial_headers - - # Create the fresh header factory for later use - fresh_headers = {"Authorization": "Bearer fresh_token"} - fresh_header_factory = MagicMock() - fresh_header_factory.return_value = fresh_headers - - # Create the credentials provider that will return the header factory - mock_creds_provider = MagicMock() - mock_creds_provider.return_value = initial_header_factory - - # Set up the token federation provider - federation_provider = DatabricksTokenFederationProvider( - mock_creds_provider, "example.com", "client_id" - ) - - # Mock the token exchange to return a known token - future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) - mock_exchange_token.return_value = Token( - "exchanged_token_1", "Bearer", expiry=future_time - ) - - # First call to get initial headers and token - this should trigger an exchange - headers_factory = federation_provider() - headers = headers_factory() - - # Verify the exchange happened with the initial token - mock_exchange_token.assert_called_with("initial_token") - self.assertEqual(headers["Authorization"], "Bearer exchanged_token_1") - - # Reset the mocks to track the next call - mock_exchange_token.reset_mock() - - # Now simulate an approaching expiry - near_expiry = datetime.now(tz=timezone.utc) + timedelta( - seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 - ) - 
federation_provider.last_exchanged_token = Token( - "exchanged_token_1", "Bearer", expiry=near_expiry - ) - federation_provider.last_external_token = "initial_token" - - # For the refresh call, we need the credentials provider to return a fresh token - # Update the mock to return fresh_header_factory for the second call - mock_creds_provider.return_value = fresh_header_factory - - # Set up the mock to return a different token for the refresh - mock_exchange_token.return_value = Token( - "exchanged_token_2", "Bearer", expiry=future_time - ) - - # Make a second call which should trigger refresh - headers = headers_factory() - - # Verify the exchange was performed with the fresh token - mock_exchange_token.assert_called_once_with("fresh_token") - - # Verify the headers contain the new token - self.assertEqual(headers["Authorization"], "Bearer exchanged_token_2") - - -class TestTokenFederationFactory(unittest.TestCase): - """Tests for the token federation factory function.""" - - def test_create_token_federation_provider(self): - """Test create_token_federation_provider function.""" - provider = create_token_federation_provider( - "token_value", "example.com", "client_id", "Bearer" - ) - - self.assertIsInstance(provider, DatabricksTokenFederationProvider) - self.assertEqual(provider.hostname, "example.com") - self.assertEqual(provider.identity_federation_client_id, "client_id") - - # Test that the underlying credentials provider was set up correctly - self.assertEqual(provider.credentials_provider.token, "token_value") - self.assertEqual(provider.credentials_provider.token_type, "Bearer") - - -if __name__ == "__main__": - unittest.main() + + assert isinstance(federation_provider, DatabricksTokenFederationProvider) + assert federation_provider.hostname == "example.com" + assert federation_provider.identity_federation_client_id == "client_id" + + # Test that the underlying credentials provider was set up correctly + assert federation_provider.credentials_provider.token 
== "token_value" + assert federation_provider.credentials_provider.token_type == "Bearer" From 49eab2ad7c2d3e08f41eaa73e77af35d94c8e4ea Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Sun, 11 May 2025 08:17:55 +0000 Subject: [PATCH 32/46] fmt --- src/databricks/sql/auth/token_federation.py | 68 +++++++++++-------- tests/token_federation/github_oidc_test.py | 72 +++++++++++++++------ tests/unit/test_token_federation.py | 13 ++-- 3 files changed, 100 insertions(+), 53 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index e92f9ccb5..61d5033d7 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -150,9 +150,7 @@ def get_headers() -> Dict[str, str]: else: # Token is from a different host, need to exchange logger.debug("Token from different host, attempting exchange") - return self._try_token_exchange_or_fallback( - access_token, token_type - ) + return self._try_token_exchange_or_fallback(access_token, token_type) except Exception as e: logger.error(f"Error processing token: {str(e)}") # Fall back to original headers in case of error @@ -172,9 +170,7 @@ def _init_oidc_discovery(self): if idp_endpoints: # Get the OpenID configuration URL - openid_config_url = idp_endpoints.get_openid_config_url( - self.hostname - ) + openid_config_url = idp_endpoints.get_openid_config_url(self.hostname) # Fetch the OpenID configuration response = requests.get(openid_config_url) @@ -185,7 +181,8 @@ def _init_oidc_discovery(self): logger.info(f"Discovered token endpoint: {self.token_endpoint}") else: logger.warning( - f"Failed to fetch OpenID configuration from {openid_config_url}: {response.status_code}" + f"Failed to fetch OpenID configuration from {openid_config_url}: " + f"{response.status_code}" ) except Exception as e: logger.warning( @@ -282,9 +279,15 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: 
self.last_external_token = access_token # Update the headers with the new token - return {"Authorization": f"{exchanged_token.token_type} {exchanged_token.access_token}"} + return { + "Authorization": ( + f"{exchanged_token.token_type} {exchanged_token.access_token}" + ) + } except Exception as e: - logger.error(f"Token refresh failed: {str(e)}, falling back to original token") + logger.error( + f"Token refresh failed: {str(e)}, falling back to original token" + ) return self.external_provider_headers def _try_token_exchange_or_fallback( @@ -305,12 +308,20 @@ def _try_token_exchange_or_fallback( self.last_exchanged_token = exchanged_token self.last_external_token = access_token - return {"Authorization": f"{exchanged_token.token_type} {exchanged_token.access_token}"} + return { + "Authorization": ( + f"{exchanged_token.token_type} {exchanged_token.access_token}" + ) + } except Exception as e: - logger.warning(f"Token exchange failed: {str(e)}, falling back to original token") + logger.warning( + f"Token exchange failed: {str(e)}, falling back to original token" + ) return self.external_provider_headers - def _send_token_exchange_request(self, token_exchange_data: Dict[str, str]) -> Dict[str, Any]: + def _send_token_exchange_request( + self, token_exchange_data: Dict[str, str] + ) -> Dict[str, Any]: """ Send the token exchange request to the token endpoint. 
@@ -325,20 +336,19 @@ def _send_token_exchange_request(self, token_exchange_data: Dict[str, str]) -> D """ if not self.token_endpoint: raise ValueError("Token endpoint not initialized") - + headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"} - + response = requests.post( - self.token_endpoint, - data=token_exchange_data, - headers=headers + self.token_endpoint, data=token_exchange_data, headers=headers ) - + if response.status_code != 200: raise ValueError( - f"Token exchange failed with status code {response.status_code}: {response.text}" + f"Token exchange failed with status code {response.status_code}: " + f"{response.text}" ) - + return response.json() def _exchange_token(self, access_token: str) -> Token: @@ -365,26 +375,28 @@ def _exchange_token(self, access_token: str) -> Token: try: # Send the token exchange request resp_data = self._send_token_exchange_request(token_exchange_data) - + # Extract token information new_access_token = resp_data.get("access_token") if not new_access_token: raise ValueError("No access token in exchange response") - + token_type = resp_data.get("token_type", "Bearer") refresh_token = resp_data.get("refresh_token", "") - + # Parse expiry time from token claims if possible expiry = datetime.now(tz=timezone.utc) - + # First try to get expiry from the response's expires_in field if "expires_in" in resp_data and resp_data["expires_in"]: try: expires_in = int(resp_data["expires_in"]) - expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in) + expiry = datetime.now(tz=timezone.utc) + timedelta( + seconds=expires_in + ) except (ValueError, TypeError) as e: logger.warning(f"Invalid expires_in value: {str(e)}") - + # If that didn't work, try to parse JWT claims for expiry if expiry == datetime.now(tz=timezone.utc): token_claims = self._parse_jwt_claims(new_access_token) @@ -394,9 +406,9 @@ def _exchange_token(self, access_token: str) -> Token: expiry = datetime.fromtimestamp(exp_timestamp, 
tz=timezone.utc) except (ValueError, TypeError) as e: logger.warning(f"Invalid exp claim in token: {str(e)}") - + return Token(new_access_token, token_type, refresh_token, expiry) - + except Exception as e: logger.error(f"Token exchange failed: {str(e)}") raise diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py index e1c65d632..71c510c34 100755 --- a/tests/token_federation/github_oidc_test.py +++ b/tests/token_federation/github_oidc_test.py @@ -14,13 +14,17 @@ import base64 import logging from databricks import sql -import jwt +try: + import jwt + + HAS_JWT_LIBRARY = True +except ImportError: + HAS_JWT_LIBRARY = False logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s" + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -35,10 +39,29 @@ def decode_jwt(token): Returns: dict: The decoded token claims or None if decoding fails """ + if HAS_JWT_LIBRARY: + try: + # Using PyJWT library (preferred method) + # Note: we're not verifying the signature as this is just for debugging + return jwt.decode(token, options={"verify_signature": False}) + except Exception as e: + logger.error(f"Failed to decode token with PyJWT: {str(e)}") + + # Fallback to manual decoding try: - return jwt.decode(token, options={"verify_signature": False}) + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT format") + + payload = parts[1] + # Add padding if needed + padding = "=" * (4 - len(payload) % 4) + payload += padding + + decoded = base64.b64decode(payload) + return json.loads(decoded) except Exception as e: - logger.error(f"Failed to decode token with PyJWT: {str(e)}") + logger.error(f"Failed to decode token: {str(e)}") return {} @@ -53,7 +76,7 @@ def get_environment_variables(): host = os.environ.get("DATABRICKS_HOST_FOR_TF") http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") 
identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") - + return github_token, host, http_path, identity_federation_client_id @@ -62,7 +85,7 @@ def display_token_info(claims): if not claims: logger.warning("No token claims available to display") return - + logger.info("=== GitHub OIDC Token Claims ===") logger.info(f"Token issuer: {claims.get('iss')}") logger.info(f"Token subject: {claims.get('sub')}") @@ -74,7 +97,9 @@ def display_token_info(claims): logger.info("===============================") -def test_databricks_connection(host, http_path, github_token, identity_federation_client_id): +def test_databricks_connection( + host, http_path, github_token, identity_federation_client_id +): """ Test connection to Databricks using token federation. @@ -90,7 +115,7 @@ def test_databricks_connection(host, http_path, github_token, identity_federatio logger.info("=== Testing Connection via Connector ===") logger.info(f"Connecting to Databricks at {host}{http_path}") logger.info(f"Using client ID: {identity_federation_client_id}") - + connection_params = { "server_hostname": host, "http_path": http_path, @@ -98,22 +123,22 @@ def test_databricks_connection(host, http_path, github_token, identity_federatio "auth_type": "token-federation", "identity_federation_client_id": identity_federation_client_id, } - + try: with sql.connect(**connection_params) as connection: logger.info("Connection established successfully") - + # Execute a simple query cursor = connection.cursor() cursor.execute("SELECT 1 + 1 as result") result = cursor.fetchall() logger.info(f"Query result: {result[0][0]}") - + # Show current user cursor.execute("SELECT current_user() as user") result = cursor.fetchall() logger.info(f"Connected as user: {result[0][0]}") - + logger.info("Token federation test successful!") return True except Exception as e: @@ -125,29 +150,34 @@ def main(): """Main entry point for the test script.""" try: # Get environment variables - github_token, host, 
http_path, identity_federation_client_id = get_environment_variables() - + github_token, host, http_path, identity_federation_client_id = ( + get_environment_variables() + ) + if not github_token: logger.error("Missing GitHub OIDC token (OIDC_TOKEN)") sys.exit(1) - + if not host or not http_path: - logger.error("Missing Databricks connection parameters (DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)") + logger.error( + "Missing Databricks connection parameters " + "(DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)" + ) sys.exit(1) - + # Display token claims claims = decode_jwt(github_token) display_token_info(claims) - + # Test Databricks connection success = test_databricks_connection( host, http_path, github_token, identity_federation_client_id ) - + if not success: logger.error("Token federation test failed") sys.exit(1) - + except Exception as e: logger.error(f"Unexpected error: {str(e)}") sys.exit(1) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 1ba550a6b..d1664c55d 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -13,7 +13,6 @@ Token, DatabricksTokenFederationProvider, SimpleCredentialsProvider, - create_token_federation_provider, TOKEN_REFRESH_BUFFER_SECONDS, ) @@ -136,19 +135,25 @@ def test_init_oidc_discovery(mock_request_get, mock_get_oauth_endpoints): @pytest.fixture def mock_parse_jwt_claims(): - with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims") as mock: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" + ) as mock: yield mock @pytest.fixture def mock_exchange_token(): - with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token") as mock: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" + ) as mock: yield mock @pytest.fixture def mock_is_same_host(): 
- with patch("databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host") as mock: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + ) as mock: yield mock From e6733cbef26727d6c7b8adfd36a5757d9d5b30b7 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Sun, 11 May 2025 08:21:40 +0000 Subject: [PATCH 33/46] Apply black formatting to auth files --- src/databricks/sql/auth/auth.py | 20 ++++++++++---------- src/databricks/sql/auth/token_federation.py | 10 ++++++---- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index f1f5543f7..3931356d0 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -52,22 +52,22 @@ def __init__( def get_auth_provider(cfg: ClientContext): """ Get an appropriate auth provider based on the provided configuration. - + Token Federation Support: ----------------------- Currently, token federation is implemented as a separate auth type, but the goal is to refactor it as a feature that can work with any auth type. The current implementation is maintained for backward compatibility while the refactoring is planned. - + Future refactoring will introduce a `use_token_federation` flag that can be combined with any auth type to enable token federation. - + Args: cfg: The client context containing configuration parameters - + Returns: An appropriate AuthProvider instance - + Raises: RuntimeError: If no valid authentication settings are provided """ @@ -163,22 +163,22 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): def get_python_sql_connector_auth_provider(hostname: str, **kwargs): """ Get an auth provider for the Python SQL connector. - + This function is the main entry point for authentication in the SQL connector. It processes the parameters and creates an appropriate auth provider. - + TODO: Future refactoring needed: 1. 
Add a use_token_federation flag that can be combined with any auth type 2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility 3. Create a token federation wrapper that can wrap any existing auth provider - + Args: hostname: The Databricks server hostname **kwargs: Additional configuration parameters - + Returns: An appropriate AuthProvider instance - + Raises: ValueError: If username/password authentication is attempted (no longer supported) """ diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 61d5033d7..2b21c1836 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -150,7 +150,9 @@ def get_headers() -> Dict[str, str]: else: # Token is from a different host, need to exchange logger.debug("Token from different host, attempting exchange") - return self._try_token_exchange_or_fallback(access_token, token_type) + return self._try_token_exchange_or_fallback( + access_token, token_type + ) except Exception as e: logger.error(f"Error processing token: {str(e)}") # Fall back to original headers in case of error @@ -324,13 +326,13 @@ def _send_token_exchange_request( ) -> Dict[str, Any]: """ Send the token exchange request to the token endpoint. 
- + Args: token_exchange_data: The data to send in the request - + Returns: Dict[str, Any]: The parsed JSON response - + Raises: Exception: If the request fails """ From 29f95f2a69adddc9603b5311134321bf88d5724b Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Sun, 11 May 2025 12:01:26 +0000 Subject: [PATCH 34/46] Fix token refresh to use fresh token from provider --- src/databricks/sql/auth/token_federation.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 2b21c1836..b2c878ce0 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -275,10 +275,18 @@ def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: Dict[str, str]: Headers with the refreshed token """ try: - # Exchange the token for a new one - exchanged_token = self._exchange_token(access_token) + # Get a fresh token from the underlying provider + fresh_headers = self.credentials_provider()() + + # Extract the fresh token from the headers + fresh_token_type, fresh_access_token = self._extract_token_info_from_header( + fresh_headers + ) + + # Exchange the fresh token for a new Databricks token + exchanged_token = self._exchange_token(fresh_access_token) self.last_exchanged_token = exchanged_token - self.last_external_token = access_token + self.last_external_token = fresh_access_token # Update the headers with the new token return { From 2e12935be7c9f1020c4292993d4b5cbe57f569d5 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 12 May 2025 05:27:22 +0000 Subject: [PATCH 35/46] general improvements --- src/databricks/sql/auth/oidc_utils.py | 58 ++ src/databricks/sql/auth/token.py | 65 +++ src/databricks/sql/auth/token_federation.py | 467 +++++++-------- tests/token_federation/github_oidc_test.py | 88 ++- tests/unit/test_token_federation.py | 604 ++++++++++++-------- 5 files changed, 751 
insertions(+), 531 deletions(-) create mode 100644 src/databricks/sql/auth/oidc_utils.py create mode 100644 src/databricks/sql/auth/token.py diff --git a/src/databricks/sql/auth/oidc_utils.py b/src/databricks/sql/auth/oidc_utils.py new file mode 100644 index 000000000..b0421cf7f --- /dev/null +++ b/src/databricks/sql/auth/oidc_utils.py @@ -0,0 +1,58 @@ +import logging +import requests +from typing import Optional + +from databricks.sql.auth.endpoint import ( + get_oauth_endpoints, + infer_cloud_from_host, +) + +logger = logging.getLogger(__name__) + + +class OIDCDiscoveryUtil: + """ + Utility class for OIDC discovery operations. + + This class handles discovery of OIDC endpoints through standard + discovery mechanisms, with fallback to default endpoints if needed. + """ + + # Standard token endpoint path for Databricks workspaces + DEFAULT_TOKEN_PATH = "oidc/v1/token" + + @staticmethod + def discover_token_endpoint(hostname: str) -> str: + """ + Get the token endpoint for the given Databricks hostname. + + For Databricks workspaces, the token endpoint is always at host/oidc/v1/token. + + Args: + hostname: The hostname to get token endpoint for + + Returns: + str: The token endpoint URL + """ + # Format the hostname and return the standard endpoint + hostname = OIDCDiscoveryUtil.format_hostname(hostname) + token_endpoint = f"{hostname}{OIDCDiscoveryUtil.DEFAULT_TOKEN_PATH}" + logger.info(f"Using token endpoint: {token_endpoint}") + return token_endpoint + + @staticmethod + def format_hostname(hostname: str) -> str: + """ + Format hostname to ensure it has proper https:// prefix and trailing slash. 
+ + Args: + hostname: The hostname to format + + Returns: + str: The formatted hostname + """ + if not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" + return hostname diff --git a/src/databricks/sql/auth/token.py b/src/databricks/sql/auth/token.py new file mode 100644 index 000000000..5abd1e028 --- /dev/null +++ b/src/databricks/sql/auth/token.py @@ -0,0 +1,65 @@ +""" +Token class for authentication tokens with expiry handling. +""" + +from datetime import datetime, timezone, timedelta +from typing import Optional + + +class Token: + """ + Represents an OAuth token with expiry information. + + This class handles token state including expiry calculation. + """ + + # Minimum time buffer before expiry to consider a token still valid (in seconds) + MIN_VALIDITY_BUFFER = 10 + + def __init__( + self, + access_token: str, + token_type: str, + refresh_token: str = "", + expiry: Optional[datetime] = None, + ): + """ + Initialize a Token object. + + Args: + access_token: The access token string + token_type: The token type (usually "Bearer") + refresh_token: Optional refresh token + expiry: Token expiry datetime, must be provided + + Raises: + ValueError: If no expiry is provided + """ + self.access_token = access_token + self.token_type = token_type + self.refresh_token = refresh_token + + # Ensure we have an expiry time + if expiry is None: + raise ValueError("Token expiry must be provided") + + # Ensure expiry is timezone-aware + if expiry.tzinfo is None: + # Convert naive datetime to aware datetime + self.expiry = expiry.replace(tzinfo=timezone.utc) + else: + self.expiry = expiry + + def is_valid(self) -> bool: + """ + Check if the token is valid (has at least MIN_VALIDITY_BUFFER seconds before expiry). 
+ + Returns: + bool: True if the token is valid, False otherwise + """ + buffer = timedelta(seconds=self.MIN_VALIDITY_BUFFER) + return datetime.now(tz=timezone.utc) + buffer < self.expiry + + def __str__(self) -> str: + """Return the token as a string in the format used for Authorization headers.""" + return f"{self.token_type} {self.access_token}" diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index b2c878ce0..ebce7d546 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -6,16 +6,16 @@ from urllib.parse import urlparse import requests +import jwt from requests.exceptions import RequestException from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory -from databricks.sql.auth.endpoint import ( - get_oauth_endpoints, - infer_cloud_from_host, -) +from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil +from databricks.sql.auth.token import Token logger = logging.getLogger(__name__) +# Token exchange constants TOKEN_EXCHANGE_PARAMS = { "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "scope": "sql", @@ -23,52 +23,23 @@ "return_original_token_if_authenticated": "true", } -TOKEN_REFRESH_BUFFER_SECONDS = 10 - - -class Token: - """Represents an OAuth token with expiry information.""" - - def __init__( - self, - access_token: str, - token_type: str, - refresh_token: str = "", - expiry: Optional[datetime] = None, - ): - self.access_token = access_token - self.token_type = token_type - self.refresh_token = refresh_token - - # Ensure expiry is timezone-aware - if expiry is None: - self.expiry = datetime.now(tz=timezone.utc) - elif expiry.tzinfo is None: - # Convert naive datetime to aware datetime - self.expiry = expiry.replace(tzinfo=timezone.utc) - else: - self.expiry = expiry - - def is_expired(self) -> bool: - """Check if the token is expired.""" - return datetime.now(tz=timezone.utc) >= self.expiry - - def 
needs_refresh(self) -> bool: - """Check if the token needs to be refreshed soon.""" - buffer_time = timedelta(seconds=TOKEN_REFRESH_BUFFER_SECONDS) - return datetime.now(tz=timezone.utc) >= (self.expiry - buffer_time) - - def __str__(self) -> str: - return f"{self.token_type} {self.access_token}" - class DatabricksTokenFederationProvider(CredentialsProvider): """ Implementation of the Credential Provider that exchanges a third party access token - for a Databricks token. It exchanges the token only if the issued token - is not from the same host as the Databricks host. + for a Databricks token. + + This provider wraps an existing credentials provider and handles token exchange when + the token is from a different host than the Databricks host. It also manages token + refresh when tokens are expired. """ + # HTTP request configuration + EXCHANGE_HEADERS = { + "Accept": "*/*", + "Content-Type": "application/x-www-form-urlencoded", + } + def __init__( self, credentials_provider: CredentialsProvider, @@ -86,10 +57,11 @@ def __init__( self.credentials_provider = credentials_provider self.hostname = hostname self.identity_federation_client_id = identity_federation_client_id - self.external_provider_headers: Dict[str, str] = {} self.token_endpoint: Optional[str] = None - self.last_exchanged_token: Optional[Token] = None - self.last_external_token: Optional[str] = None + + # Store the current token information + self.current_token: Optional[Token] = None + self.external_headers: Optional[Dict[str, str]] = None def auth_type(self) -> str: """Return the auth type from the underlying credentials provider.""" @@ -99,116 +71,41 @@ def auth_type(self) -> str: def host(self) -> str: """ Alias for hostname to maintain compatibility with code expecting a host attribute. - - Returns: - str: The hostname value """ return self.hostname def __call__(self, *args, **kwargs) -> HeaderFactory: """ Configure and return a HeaderFactory that provides authentication headers. 
- This is called by the ExternalAuthProvider to get headers for authentication. """ # First call the underlying credentials provider to get its headers header_factory = self.credentials_provider(*args, **kwargs) - # Initialize OIDC discovery - self._init_oidc_discovery() - - def get_headers() -> Dict[str, str]: - try: - # Get headers from the underlying provider - self.external_provider_headers = header_factory() - - # Extract the token from the headers - token_type, access_token = self._extract_token_info_from_header( - self.external_provider_headers - ) - - # Check if we need to refresh the token - if ( - self.last_exchanged_token - and self.last_external_token == access_token - and self.last_exchanged_token.needs_refresh() - ): - # The token is approaching expiry, try to refresh - logger.info( - "Exchanged token approaching expiry, refreshing with fresh external token..." - ) - return self._refresh_token(access_token, token_type) - - # Parse the JWT to get claims - token_claims = self._parse_jwt_claims(access_token) - - # Check if token needs to be exchanged - if self._is_same_host(token_claims.get("iss", ""), self.hostname): - # Token is from the same host, no need to exchange - logger.debug("Token from same host, no exchange needed") - return self.external_provider_headers - else: - # Token is from a different host, need to exchange - logger.debug("Token from different host, attempting exchange") - return self._try_token_exchange_or_fallback( - access_token, token_type - ) - except Exception as e: - logger.error(f"Error processing token: {str(e)}") - # Fall back to original headers in case of error - return self.external_provider_headers - - return get_headers - - def _init_oidc_discovery(self): - """Initialize OIDC discovery to find token endpoint.""" - if self.token_endpoint is not None: - return - - try: - # Use the existing OIDC discovery mechanism - use_azure_auth = infer_cloud_from_host(self.hostname) == "azure" - idp_endpoints = 
get_oauth_endpoints(self.hostname, use_azure_auth) - - if idp_endpoints: - # Get the OpenID configuration URL - openid_config_url = idp_endpoints.get_openid_config_url(self.hostname) - - # Fetch the OpenID configuration - response = requests.get(openid_config_url) - if response.status_code == 200: - openid_config = response.json() - # Extract token endpoint from OpenID config - self.token_endpoint = openid_config.get("token_endpoint") - logger.info(f"Discovered token endpoint: {self.token_endpoint}") - else: - logger.warning( - f"Failed to fetch OpenID configuration from {openid_config_url}: " - f"{response.status_code}" - ) - except Exception as e: - logger.warning( - f"OIDC discovery failed: {str(e)}. Using default token endpoint." + # Get the standard token endpoint if not already set + if self.token_endpoint is None: + self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( + self.hostname ) - # Fallback to default token endpoint if discovery fails - if not self.token_endpoint: - hostname = self._format_hostname(self.hostname) - self.token_endpoint = f"{hostname}oidc/v1/token" - logger.info(f"Using default token endpoint: {self.token_endpoint}") - - def _format_hostname(self, hostname: str) -> str: - """Format hostname to ensure it has proper https:// prefix and trailing slash.""" - if not hostname.startswith("https://"): - hostname = f"https://{hostname}" - if not hostname.endswith("/"): - hostname = f"{hostname}/" - return hostname + # Return a function that will get authentication headers + return self.get_auth_headers def _extract_token_info_from_header( self, headers: Dict[str, str] ) -> Tuple[str, str]: - """Extract token type and token value from authorization header.""" + """ + Extract token type and token value from authorization header. 
+ + Args: + headers: Headers dictionary + + Returns: + Tuple[str, str]: Token type and token value + + Raises: + ValueError: If no authorization header is found or it has invalid format + """ auth_header = headers.get("Authorization") if not auth_header: raise ValueError("No Authorization header found") @@ -220,27 +117,45 @@ def _extract_token_info_from_header( return parts[0], parts[1] def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: - """Parse JWT token claims without validation.""" - try: - # Split the token - parts = token.split(".") - if len(parts) != 3: - raise ValueError("Invalid JWT format") - - # Get the payload part (second part) - payload = parts[1] + """ + Parse JWT token claims without validation. - # Add padding if needed - padding = "=" * (4 - len(payload) % 4) - payload += padding + Args: + token: JWT token string - # Decode and parse JSON - decoded = base64.b64decode(payload) - return json.loads(decoded) + Returns: + Dict[str, Any]: Parsed JWT claims + """ + try: + return jwt.decode(token, options={"verify_signature": False}) except Exception as e: logger.error(f"Failed to parse JWT: {str(e)}") return {} + def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]: + """ + Extract expiry datetime from JWT token. + + Args: + token: JWT token string + + Returns: + Optional[datetime]: Expiry datetime if found in token, None otherwise + """ + claims = self._parse_jwt_claims(token) + + # Look for standard JWT expiry claim ("exp") + if "exp" in claims: + try: + # JWT expiry is in seconds since epoch + expiry_timestamp = int(claims["exp"]) + # Convert to datetime + return datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc) + except (ValueError, TypeError) as e: + logger.warning(f"Invalid JWT expiry value: {e}") + + return None + def _is_same_host(self, url1: str, url2: str) -> bool: """ Check if two URLs have the same host. 
@@ -250,9 +165,15 @@ def _is_same_host(self, url1: str, url2: str) -> bool: url2: Second URL Returns: - bool: True if the hosts match, False otherwise + bool: True if hosts are the same, False otherwise """ try: + # Add protocol if missing to ensure proper parsing + if not url1.startswith(("http://", "https://")): + url1 = f"https://{url1}" + if not url2.startswith(("http://", "https://")): + url2 = f"https://{url2}" + # Parse the URLs parsed1 = urlparse(url1) parsed2 = urlparse(url2) @@ -263,71 +184,94 @@ def _is_same_host(self, url1: str, url2: str) -> bool: logger.warning(f"Error comparing hosts: {str(e)}") return False - def _refresh_token(self, access_token: str, token_type: str) -> Dict[str, str]: + def refresh_token(self) -> Token: """ - Refresh the exchanged token by getting a fresh external token. + Refresh the token and return the new Token object. - Args: - access_token: The external access token - token_type: The token type (usually "Bearer") + This method gets a fresh token from the credentials provider, + exchanges it if necessary, and returns the new Token object. 
Returns: - Dict[str, str]: Headers with the refreshed token + Token: The new refreshed token + + Raises: + ValueError: If token refresh fails """ - try: - # Get a fresh token from the underlying provider - fresh_headers = self.credentials_provider()() + # Get fresh headers from the credentials provider + header_factory = self.credentials_provider() + self.external_headers = header_factory() - # Extract the fresh token from the headers - fresh_token_type, fresh_access_token = self._extract_token_info_from_header( - fresh_headers - ) + # Extract the new token info + token_type, access_token = self._extract_token_info_from_header( + self.external_headers + ) - # Exchange the fresh token for a new Databricks token - exchanged_token = self._exchange_token(fresh_access_token) - self.last_exchanged_token = exchanged_token - self.last_external_token = fresh_access_token - - # Update the headers with the new token - return { - "Authorization": ( - f"{exchanged_token.token_type} {exchanged_token.access_token}" - ) - } - except Exception as e: - logger.error( - f"Token refresh failed: {str(e)}, falling back to original token" - ) - return self.external_provider_headers + # Check if we need to exchange the token + token_claims = self._parse_jwt_claims(access_token) + + # Create new token based on whether it's from the same host or not + if self._is_same_host(token_claims.get("iss", ""), self.hostname): + # Token is from the same host, no need to exchange + logger.debug("Token from same host, creating token without exchange") + + expiry = self._get_expiry_from_jwt(access_token) + if expiry is None: + raise ValueError("Could not determine token expiry from JWT") + + new_token = Token(access_token, token_type, "", expiry) + else: + # Token is from a different host, need to exchange + logger.debug("Token from different host, exchanging token") + new_token = self._exchange_token(access_token) + + # Store the token + self.current_token = new_token - def 
_try_token_exchange_or_fallback( - self, access_token: str, token_type: str - ) -> Dict[str, str]: + return new_token + + def get_current_token(self) -> Token: """ - Attempt to exchange the token or fall back to the original token if exchange fails. + Get the current token, refreshing if necessary. - Args: - access_token: The external access token - token_type: The token type (usually "Bearer") + This method checks if the current token is valid and not expired. + If it is valid, it returns the current token. + If it is expired or doesn't exist, it refreshes the token. + + Returns: + Token: The current valid token + + Raises: + ValueError: If unable to get a valid token + """ + # Return current token if it exists and is valid + if self.current_token is not None and self.current_token.is_valid(): + return self.current_token + + # Token doesn't exist or is expired, get a fresh one + return self.refresh_token() + + def get_auth_headers(self) -> Dict[str, str]: + """ + Get authorization headers using the current token. + + This method gets the current token and returns it formatted + as authorization headers. 
Returns: - Dict[str, str]: Headers with either the exchanged token or the original token + Dict[str, str]: Authorization headers """ try: - exchanged_token = self._exchange_token(access_token) - self.last_exchanged_token = exchanged_token - self.last_external_token = access_token - - return { - "Authorization": ( - f"{exchanged_token.token_type} {exchanged_token.access_token}" - ) - } + token = self.get_current_token() + return {"Authorization": f"{token.token_type} {token.access_token}"} except Exception as e: - logger.warning( - f"Token exchange failed: {str(e)}, falling back to original token" - ) - return self.external_provider_headers + logger.error(f"Error getting auth headers: {str(e)}") + + # Fall back to external headers if available + if self.external_headers: + return self.external_headers + + # Return empty dict as a last resort + return {} def _send_token_exchange_request( self, token_exchange_data: Dict[str, str] @@ -336,21 +280,19 @@ def _send_token_exchange_request( Send the token exchange request to the token endpoint. Args: - token_exchange_data: The data to send in the request + token_exchange_data: Token exchange request data Returns: - Dict[str, Any]: The parsed JSON response + Dict[str, Any]: Token exchange response Raises: - Exception: If the request fails + ValueError: If token exchange fails """ if not self.token_endpoint: raise ValueError("Token endpoint not initialized") - headers = {"Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded"} - response = requests.post( - self.token_endpoint, data=token_exchange_data, headers=headers + self.token_endpoint, data=token_exchange_data, headers=self.EXCHANGE_HEADERS ) if response.status_code != 200: @@ -366,13 +308,13 @@ def _exchange_token(self, access_token: str) -> Token: Exchange an external token for a Databricks token. 
Args: - access_token: The external token to exchange + access_token: External token to exchange Returns: - Token: The exchanged token with expiry information + Token: Exchanged token Raises: - Exception: If token exchange fails + ValueError: If token exchange fails """ # Prepare the request data token_exchange_data = dict(TOKEN_EXCHANGE_PARAMS) @@ -382,46 +324,55 @@ def _exchange_token(self, access_token: str) -> Token: if self.identity_federation_client_id: token_exchange_data["client_id"] = self.identity_federation_client_id - try: - # Send the token exchange request - resp_data = self._send_token_exchange_request(token_exchange_data) - - # Extract token information - new_access_token = resp_data.get("access_token") - if not new_access_token: - raise ValueError("No access token in exchange response") - - token_type = resp_data.get("token_type", "Bearer") - refresh_token = resp_data.get("refresh_token", "") - - # Parse expiry time from token claims if possible - expiry = datetime.now(tz=timezone.utc) - - # First try to get expiry from the response's expires_in field - if "expires_in" in resp_data and resp_data["expires_in"]: - try: - expires_in = int(resp_data["expires_in"]) - expiry = datetime.now(tz=timezone.utc) + timedelta( - seconds=expires_in - ) - except (ValueError, TypeError) as e: - logger.warning(f"Invalid expires_in value: {str(e)}") - - # If that didn't work, try to parse JWT claims for expiry - if expiry == datetime.now(tz=timezone.utc): - token_claims = self._parse_jwt_claims(new_access_token) - if "exp" in token_claims: - try: - exp_timestamp = int(token_claims["exp"]) - expiry = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc) - except (ValueError, TypeError) as e: - logger.warning(f"Invalid exp claim in token: {str(e)}") - - return Token(new_access_token, token_type, refresh_token, expiry) + # Send the token exchange request + resp_data = self._send_token_exchange_request(token_exchange_data) - except Exception as e: - logger.error(f"Token 
exchange failed: {str(e)}") - raise + # Extract token information + new_access_token = resp_data.get("access_token") + if not new_access_token: + raise ValueError("No access token in exchange response") + + token_type = resp_data.get("token_type", "Bearer") + refresh_token = resp_data.get("refresh_token", "") + + # Determine token expiry - first try from JWT claims + expiry = self._get_expiry_from_jwt(new_access_token) + + # If JWT expiry not available, use expires_in from response + if expiry is None: + expiry = self._get_expiry_from_response(resp_data) + + # If we still don't have an expiry, we can't proceed + if expiry is None: + raise ValueError( + "Unable to determine token expiry from response or JWT claims" + ) + + return Token(new_access_token, token_type, refresh_token, expiry) + + def _get_expiry_from_response( + self, resp_data: Dict[str, Any] + ) -> Optional[datetime]: + """ + Extract expiry datetime from response data. + + Args: + resp_data: Response data from token exchange + + Returns: + Optional[datetime]: Expiry datetime if found in response, None otherwise + """ + if "expires_in" not in resp_data or not resp_data["expires_in"]: + return None + + try: + expires_in = int(resp_data["expires_in"]) + expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in) + logger.debug(f"Using expiry from expires_in: {expiry}") + return expiry + except (ValueError, TypeError) as e: + logger.warning(f"Invalid expires_in value: {str(e)}") + return None class SimpleCredentialsProvider(CredentialsProvider): @@ -430,14 +381,22 @@ class SimpleCredentialsProvider(CredentialsProvider): def __init__( self, token: str, token_type: str = "Bearer", auth_type_value: str = "token" ): + """ + Initialize a SimpleCredentialsProvider. 
+ """ self.token = token self.token_type = token_type self.auth_type_value = auth_type_value def auth_type(self) -> str: + """Return the auth type value.""" return self.auth_type_value def __call__(self, *args, **kwargs) -> HeaderFactory: + """ + Return a HeaderFactory that provides a fixed token. + """ + def get_headers() -> Dict[str, str]: return {"Authorization": f"{self.token_type} {self.token}"} diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py index 71c510c34..74f8f97e4 100755 --- a/tests/token_federation/github_oidc_test.py +++ b/tests/token_federation/github_oidc_test.py @@ -10,18 +10,10 @@ import os import sys -import json -import base64 import logging +import jwt from databricks import sql -try: - import jwt - - HAS_JWT_LIBRARY = True -except ImportError: - HAS_JWT_LIBRARY = False - logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" @@ -32,34 +24,16 @@ def decode_jwt(token): """ Decode and return the claims from a JWT token. 
- + Args: token: The JWT token string - + Returns: - dict: The decoded token claims or None if decoding fails + dict: The decoded token claims or empty dict if decoding fails """ - if HAS_JWT_LIBRARY: - try: - # Using PyJWT library (preferred method) - # Note: we're not verifying the signature as this is just for debugging - return jwt.decode(token, options={"verify_signature": False}) - except Exception as e: - logger.error(f"Failed to decode token with PyJWT: {str(e)}") - - # Fallback to manual decoding try: - parts = token.split(".") - if len(parts) != 3: - raise ValueError("Invalid JWT format") - - payload = parts[1] - # Add padding if needed - padding = "=" * (4 - len(payload) % 4) - payload += padding - - decoded = base64.b64decode(payload) - return json.loads(decoded) + # Using PyJWT library to decode token without verification + return jwt.decode(token, options={"verify_signature": False}) except Exception as e: logger.error(f"Failed to decode token: {str(e)}") return {} @@ -68,7 +42,7 @@ def decode_jwt(token): def get_environment_variables(): """ Get required environment variables for the test. - + Returns: tuple: (github_token, host, http_path, identity_federation_client_id) """ @@ -77,11 +51,24 @@ def get_environment_variables(): http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") + # Validate required environment variables + if not github_token: + raise ValueError("OIDC_TOKEN environment variable is required") + if not host: + raise ValueError("DATABRICKS_HOST_FOR_TF environment variable is required") + if not http_path: + raise ValueError("DATABRICKS_HTTP_PATH_FOR_TF environment variable is required") + return github_token, host, http_path, identity_federation_client_id def display_token_info(claims): - """Display token claims for debugging.""" + """ + Display token claims for debugging. 
+ + Args: + claims: Dictionary containing JWT token claims + """ if not claims: logger.warning("No token claims available to display") return @@ -102,13 +89,13 @@ def test_databricks_connection( ): """ Test connection to Databricks using token federation. - + Args: host: Databricks host http_path: Databricks HTTP path github_token: GitHub OIDC token identity_federation_client_id: Identity federation client ID - + Returns: bool: True if the test is successful, False otherwise """ @@ -121,9 +108,14 @@ def test_databricks_connection( "http_path": http_path, "access_token": github_token, "auth_type": "token-federation", - "identity_federation_client_id": identity_federation_client_id, } + # Add identity federation client ID if provided + if identity_federation_client_id: + connection_params[ + "identity_federation_client_id" + ] = identity_federation_client_id + try: with sql.connect(**connection_params) as connection: logger.info("Connection established successfully") @@ -150,20 +142,12 @@ def main(): """Main entry point for the test script.""" try: # Get environment variables - github_token, host, http_path, identity_federation_client_id = ( - get_environment_variables() - ) - - if not github_token: - logger.error("Missing GitHub OIDC token (OIDC_TOKEN)") - sys.exit(1) - - if not host or not http_path: - logger.error( - "Missing Databricks connection parameters " - "(DATABRICKS_HOST_FOR_TF, DATABRICKS_HTTP_PATH_FOR_TF)" - ) - sys.exit(1) + ( + github_token, + host, + http_path, + identity_federation_client_id, + ) = get_environment_variables() # Display token claims claims = decode_jwt(github_token) @@ -184,4 +168,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index d1664c55d..8fa3fa30f 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -6,243 +6,397 @@ import pytest from unittest.mock import 
MagicMock, patch -import json from datetime import datetime, timezone, timedelta +import jwt +from databricks.sql.auth.token import Token from databricks.sql.auth.token_federation import ( - Token, DatabricksTokenFederationProvider, SimpleCredentialsProvider, - TOKEN_REFRESH_BUFFER_SECONDS, ) +from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil # Tests for Token class -def test_token_initialization(): - """Test Token initialization.""" - token = Token("access_token_value", "Bearer", "refresh_token_value") - assert token.access_token == "access_token_value" - assert token.token_type == "Bearer" - assert token.refresh_token == "refresh_token_value" - - -def test_token_is_expired(): - """Test Token is_expired method.""" - # Token with expiry in the past - past = datetime.now(tz=timezone.utc) - timedelta(hours=1) - token = Token("access_token", "Bearer", expiry=past) - assert token.is_expired() - - # Token with expiry in the future - future = datetime.now(tz=timezone.utc) + timedelta(hours=1) - token = Token("access_token", "Bearer", expiry=future) - assert not token.is_expired() - - -def test_token_needs_refresh(): - """Test Token needs_refresh method using actual TOKEN_REFRESH_BUFFER_SECONDS.""" - # Token with expiry in the past - past = datetime.now(tz=timezone.utc) - timedelta(hours=1) - token = Token("access_token", "Bearer", expiry=past) - assert token.needs_refresh() - - # Token with expiry in the near future (within refresh buffer) - near_future = datetime.now(tz=timezone.utc) + timedelta( - seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 - ) - token = Token("access_token", "Bearer", expiry=near_future) - assert token.needs_refresh() - - # Token with expiry far in the future - far_future = datetime.now(tz=timezone.utc) + timedelta( - seconds=TOKEN_REFRESH_BUFFER_SECONDS + 10 - ) - token = Token("access_token", "Bearer", expiry=far_future) - assert not token.needs_refresh() +class TestToken: + """Tests for the Token class.""" + + def 
test_token_initialization_and_properties(self): + """Test Token initialization, properties and methods.""" + # Test with minimum required parameters plus expiry + future = datetime.now(tz=timezone.utc) + timedelta(hours=1) + token = Token("access_token_value", "Bearer", expiry=future) + assert token.access_token == "access_token_value" + assert token.token_type == "Bearer" + assert token.refresh_token == "" + assert token.expiry == future + assert token.is_valid() + + # Test expired token + past = datetime.now(tz=timezone.utc) - timedelta(hours=1) + expired_token = Token("expired", "Bearer", expiry=past) + assert not expired_token.is_valid() + + # Test almost expired token (will expire within buffer) + almost_expired = datetime.now(tz=timezone.utc) + timedelta( + seconds=5 + ) # Less than MIN_VALIDITY_BUFFER + almost_token = Token("almost", "Bearer", expiry=almost_expired) + assert not almost_token.is_valid() # Not valid due to buffer + + # Test string representation + assert str(token) == "Bearer access_token_value" # Tests for SimpleCredentialsProvider -def test_simple_credentials_provider(): - """Test SimpleCredentialsProvider.""" - provider = SimpleCredentialsProvider( - "token_value", "Bearer", "custom_auth_type" - ) - assert provider.auth_type() == "custom_auth_type" - - header_factory = provider() - headers = header_factory() - assert headers == {"Authorization": "Bearer token_value"} +class TestSimpleCredentialsProvider: + """Tests for the SimpleCredentialsProvider class.""" + + def test_provider_initialization(self): + """Test initialization and methods of SimpleCredentialsProvider.""" + provider = SimpleCredentialsProvider("token1", "Bearer", "token") + assert provider.auth_type() == "token" + + # Test header factory + header_factory = provider() + headers = header_factory() + assert headers == {"Authorization": "Bearer token1"} + + +# Tests for OIDCDiscoveryUtil +class TestOIDCDiscoveryUtil: + """Tests for the OIDCDiscoveryUtil class.""" + + def 
test_discover_token_endpoint(self): + """Test token endpoint creation for Databricks workspaces.""" + # Test with different hostname formats + # Without protocol and without trailing slash + token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint("databricks.com") + assert token_endpoint == "https://databricks.com/oidc/v1/token" + + # With protocol but without trailing slash + token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( + "https://databricks.com" + ) + assert token_endpoint == "https://databricks.com/oidc/v1/token" + + # With protocol and trailing slash + token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( + "https://databricks.com/" + ) + assert token_endpoint == "https://databricks.com/oidc/v1/token" + + def test_format_hostname(self): + """Test hostname formatting.""" + # Without protocol and without trailing slash + assert ( + OIDCDiscoveryUtil.format_hostname("databricks.com") + == "https://databricks.com/" + ) + + # With protocol but without trailing slash + assert ( + OIDCDiscoveryUtil.format_hostname("https://databricks.com") + == "https://databricks.com/" + ) + + # With protocol and trailing slash + assert ( + OIDCDiscoveryUtil.format_hostname("https://databricks.com/") + == "https://databricks.com/" + ) # Tests for DatabricksTokenFederationProvider -def test_host_property(): - """Test the host property of DatabricksTokenFederationProvider.""" - creds_provider = SimpleCredentialsProvider("token") - federation_provider = DatabricksTokenFederationProvider( - creds_provider, "example.com", "client_id" - ) - assert federation_provider.host == "example.com" - assert federation_provider.hostname == "example.com" - - -@pytest.fixture -def mock_request_get(): - with patch("databricks.sql.auth.token_federation.requests.get") as mock: - yield mock - - -@pytest.fixture -def mock_get_oauth_endpoints(): - with patch("databricks.sql.auth.token_federation.get_oauth_endpoints") as mock: - yield mock - - -def 
test_init_oidc_discovery(mock_request_get, mock_get_oauth_endpoints): - """Test _init_oidc_discovery method.""" - # Mock the get_oauth_endpoints function - mock_endpoints = MagicMock() - mock_endpoints.get_openid_config_url.return_value = ( - "https://example.com/openid-config" - ) - mock_get_oauth_endpoints.return_value = mock_endpoints - - # Mock the requests.get response - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "token_endpoint": "https://example.com/token" - } - mock_request_get.return_value = mock_response - - # Create the provider - creds_provider = SimpleCredentialsProvider("token") - federation_provider = DatabricksTokenFederationProvider( - creds_provider, "example.com", "client_id" - ) - - # Call the method - federation_provider._init_oidc_discovery() - - # Check if the token endpoint was set correctly - assert federation_provider.token_endpoint == "https://example.com/token" - - # Test fallback when discovery fails - mock_request_get.side_effect = Exception("Connection error") - federation_provider.token_endpoint = None - federation_provider._init_oidc_discovery() - assert federation_provider.token_endpoint == "https://example.com/oidc/v1/token" - - -@pytest.fixture -def mock_parse_jwt_claims(): - with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" - ) as mock: - yield mock - - -@pytest.fixture -def mock_exchange_token(): - with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" - ) as mock: - yield mock - - -@pytest.fixture -def mock_is_same_host(): - with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" - ) as mock: - yield mock - - -def test_token_refresh(mock_parse_jwt_claims, mock_exchange_token, mock_is_same_host): - """Test token refresh functionality for approaching expiry.""" - # Set up mocks - mock_parse_jwt_claims.return_value = { - 
"iss": "https://login.microsoftonline.com/tenant" - } - mock_is_same_host.return_value = False - - # Create the initial header factory - initial_headers = {"Authorization": "Bearer initial_token"} - initial_header_factory = MagicMock() - initial_header_factory.return_value = initial_headers - - # Create the fresh header factory for later use - fresh_headers = {"Authorization": "Bearer fresh_token"} - fresh_header_factory = MagicMock() - fresh_header_factory.return_value = fresh_headers - - # Create the credentials provider that will return the header factory - mock_creds_provider = MagicMock() - mock_creds_provider.return_value = initial_header_factory - - # Set up the token federation provider - federation_provider = DatabricksTokenFederationProvider( - mock_creds_provider, "example.com", "client_id" - ) - - # Mock the token exchange to return a known token - future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) - mock_exchange_token.return_value = Token( - "exchanged_token_1", "Bearer", expiry=future_time - ) - - # First call to get initial headers and token - this should trigger an exchange - headers_factory = federation_provider() - headers = headers_factory() - - # Verify the exchange happened with the initial token - mock_exchange_token.assert_called_with("initial_token") - assert headers["Authorization"] == "Bearer exchanged_token_1" - - # Reset the mocks to track the next call - mock_exchange_token.reset_mock() - - # Now simulate an approaching expiry - near_expiry = datetime.now(tz=timezone.utc) + timedelta( - seconds=TOKEN_REFRESH_BUFFER_SECONDS - 1 - ) - federation_provider.last_exchanged_token = Token( - "exchanged_token_1", "Bearer", expiry=near_expiry - ) - federation_provider.last_external_token = "initial_token" - - # For the refresh call, we need the credentials provider to return a fresh token - # Update the mock to return fresh_header_factory for the second call - mock_creds_provider.return_value = fresh_header_factory - - # Set up the 
mock to return a different token for the refresh - mock_exchange_token.return_value = Token( - "exchanged_token_2", "Bearer", expiry=future_time - ) - - # Make a second call which should trigger refresh - headers = headers_factory() - - # Verify the exchange was performed with the fresh token - mock_exchange_token.assert_called_once_with("fresh_token") - - # Verify the headers contain the new token - assert headers["Authorization"] == "Bearer exchanged_token_2" - - -def test_create_token_federation_provider(): - """Test creation of a federation provider with a simple token provider.""" - # Create a simple provider - simple_provider = SimpleCredentialsProvider("token_value", "Bearer") - - # Create a federation provider with the simple provider - federation_provider = DatabricksTokenFederationProvider( - simple_provider, "example.com", "client_id" - ) - - assert isinstance(federation_provider, DatabricksTokenFederationProvider) - assert federation_provider.hostname == "example.com" - assert federation_provider.identity_federation_client_id == "client_id" - - # Test that the underlying credentials provider was set up correctly - assert federation_provider.credentials_provider.token == "token_value" - assert federation_provider.credentials_provider.token_type == "Bearer" +class TestDatabricksTokenFederationProvider: + """Tests for the DatabricksTokenFederationProvider class.""" + + @pytest.fixture + def mock_credentials_provider(self): + """Fixture for a mock credentials provider.""" + provider = MagicMock() + provider.auth_type.return_value = "mock_auth_type" + header_factory = MagicMock() + header_factory.return_value = {"Authorization": "Bearer mock_token"} + provider.return_value = header_factory + return provider + + @pytest.fixture + def federation_provider(self, mock_credentials_provider): + """Fixture for a token federation provider.""" + return DatabricksTokenFederationProvider( + mock_credentials_provider, "databricks.com", "client_id" + ) + + @pytest.fixture 
+ def mock_discover_token_endpoint(self): + """Fixture for mocking OIDCDiscoveryUtil.discover_token_endpoint.""" + with patch( + "databricks.sql.auth.oidc_utils.OIDCDiscoveryUtil.discover_token_endpoint" + ) as mock: + mock.return_value = "https://databricks.com/token" + yield mock + + @pytest.fixture + def mock_parse_jwt_claims(self): + """Fixture for mocking _parse_jwt_claims.""" + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" + ) as mock: + yield mock + + @pytest.fixture + def mock_exchange_token(self): + """Fixture for mocking _exchange_token.""" + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" + ) as mock: + yield mock + + @pytest.fixture + def mock_is_same_host(self): + """Fixture for mocking _is_same_host.""" + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + ) as mock: + yield mock + + @pytest.fixture + def mock_request_post(self): + """Fixture for mocking requests.post.""" + with patch("databricks.sql.auth.token_federation.requests.post") as mock: + yield mock + + def test_host_and_auth_type(self, federation_provider): + """Test the host property and auth_type of DatabricksTokenFederationProvider.""" + assert federation_provider.host == "databricks.com" + assert federation_provider.hostname == "databricks.com" + assert federation_provider.auth_type() == "mock_auth_type" + + def test_is_same_host(self, federation_provider): + """Test the _is_same_host method with various URL combinations.""" + # Same host + assert federation_provider._is_same_host( + "https://databricks.com", "https://databricks.com" + ) + # Different host + assert not federation_provider._is_same_host( + "https://databricks.com", "https://different.com" + ) + # Same host with paths + assert federation_provider._is_same_host( + "https://databricks.com/path", "https://databricks.com/other" + ) + # Missing protocol + assert 
federation_provider._is_same_host( + "databricks.com", "https://databricks.com" + ) + + def test_extract_token_info_from_header(self, federation_provider): + """Test _extract_token_info_from_header with valid and invalid headers.""" + # Valid headers + assert federation_provider._extract_token_info_from_header( + {"Authorization": "Bearer token"} + ) == ("Bearer", "token") + + assert federation_provider._extract_token_info_from_header( + {"Authorization": "CustomType token"} + ) == ("CustomType", "token") + + # Invalid headers + with pytest.raises(ValueError): + federation_provider._extract_token_info_from_header({}) + + with pytest.raises(ValueError): + federation_provider._extract_token_info_from_header({"Authorization": ""}) + + with pytest.raises(ValueError): + federation_provider._extract_token_info_from_header( + {"Authorization": "Bearer"} + ) + + def test_token_reuse( + self, + federation_provider, + mock_exchange_token, + ): + """Test token reuse when token is still valid.""" + # Set up the initial token + future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) + initial_token = Token("exchanged_token", "Bearer", expiry=future_time) + federation_provider.current_token = initial_token + federation_provider.external_headers = { + "Authorization": "Bearer external_token" + } + + # Get headers and verify the token is reused without calling exchange + headers = federation_provider.get_auth_headers() + assert headers["Authorization"] == "Bearer exchanged_token" + # Verify exchange was not called + mock_exchange_token.assert_not_called() + + def test_refresh_token_method( + self, + federation_provider, + mock_parse_jwt_claims, + mock_exchange_token, + mock_is_same_host, + mock_discover_token_endpoint, + ): + """Test the refactored refresh_token method for both exchange and non-exchange cases.""" + # CASE 1: Token from different host (needs exchange) + # Set up mocks + mock_parse_jwt_claims.return_value = { + "iss": 
"https://login.microsoftonline.com/tenant" + } + mock_is_same_host.return_value = False + + # Set up headers that the credentials provider will return + headers = {"Authorization": "Bearer test_token"} + header_factory = MagicMock() + header_factory.return_value = headers + + # Configure the credentials provider + mock_creds_provider = MagicMock() + mock_creds_provider.return_value = header_factory + federation_provider.credentials_provider = mock_creds_provider + + # Configure the mock token exchange + future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) + mock_exchange_token.return_value = Token( + "exchanged_token", "Bearer", expiry=future_time + ) + + # Call the refresh_token method + token = federation_provider.refresh_token() + + # Verify the token was exchanged + mock_exchange_token.assert_called_with("test_token") + assert token.access_token == "exchanged_token" + assert token == federation_provider.current_token + + # CASE 2: Token from same host (no exchange needed) + mock_is_same_host.return_value = True + mock_exchange_token.reset_mock() + + # Mock the JWT expiry extraction + expiry_time = datetime.now(tz=timezone.utc) + timedelta(hours=2) + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._get_expiry_from_jwt", + return_value=expiry_time, + ): + # Call refresh_token again + token = federation_provider.refresh_token() + + # Verify no exchange was performed + mock_exchange_token.assert_not_called() + # Verify token was created directly + assert token.access_token == "test_token" + assert token.expiry == expiry_time + + def test_call_method_returns_auth_headers_directly( + self, + federation_provider, + mock_discover_token_endpoint, + ): + """Test that __call__ directly returns the get_auth_headers method.""" + # Mock get_auth_headers to verify it's called directly + with patch.object( + federation_provider, + "get_auth_headers", + return_value={"Authorization": "Bearer test_auth"}, + ) as mock_get_auth: + 
# Get the header factory from __call__ + result = federation_provider() + + # In our refactored implementation, __call__ returns get_auth_headers directly + assert result is federation_provider.get_auth_headers + + # Now call the result and verify it returns what get_auth_headers returns + headers = result() + assert headers == {"Authorization": "Bearer test_auth"} + mock_get_auth.assert_called_once() + + def test_get_expiry_from_jwt(self, federation_provider): + """Test extracting expiry from JWT token.""" + # Create a JWT token with expiry + expiry_timestamp = int( + (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp() + ) + payload = { + "exp": expiry_timestamp, + "iat": int(datetime.now(tz=timezone.utc).timestamp()), + "sub": "test-subject", + } + + # Create JWT token + token = jwt.encode(payload, "secret", algorithm="HS256") + + # Test the method + expiry = federation_provider._get_expiry_from_jwt(token) + + # Verify the expiry is extracted correctly + assert expiry is not None + assert isinstance(expiry, datetime) + assert expiry.tzinfo is not None # Should be timezone-aware + assert ( + abs( + ( + expiry - datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc) + ).total_seconds() + ) + < 1 + ) # Allow for small rounding differences + + # Test with invalid token + expiry = federation_provider._get_expiry_from_jwt("invalid-token") + assert expiry is None + + # Test with token missing expiry + payload = {"sub": "test-subject"} + token_without_exp = jwt.encode(payload, "secret", algorithm="HS256") + expiry = federation_provider._get_expiry_from_jwt(token_without_exp) + assert expiry is None + + def test_exchange_token( + self, federation_provider, mock_request_post, mock_discover_token_endpoint + ): + """Test the _exchange_token method with success and failure cases.""" + # SUCCESS CASE + # Mock the response data + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_token", + 
"token_type": "Bearer", + "refresh_token": "refresh_value", + "expires_in": 3600, + } + mock_request_post.return_value = mock_response + + # Set the token endpoint + federation_provider.token_endpoint = "https://databricks.com/token" + + # Call the method + token = federation_provider._exchange_token("original_token") + + # Verify the token was created correctly + assert token.access_token == "new_token" + assert token.token_type == "Bearer" + assert token.refresh_token == "refresh_value" + # Expiry should be around 1 hour in the future + assert token.expiry > datetime.now(tz=timezone.utc) + assert token.expiry < datetime.now(tz=timezone.utc) + timedelta(seconds=3601) + + # FAILURE CASE + # Mock the response data for failure + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_request_post.return_value = mock_response + + # Call the method and expect an exception + with pytest.raises( + ValueError, match="Token exchange failed with status code 401" + ): + federation_provider._exchange_token("original_token") From e9de21a85d32f87bc912c81f5d09f18e6ba0514d Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 12 May 2025 05:30:31 +0000 Subject: [PATCH 36/46] minor --- tests/token_federation/github_oidc_test.py | 2 -- tests/unit/test_token_federation.py | 10 ++-------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py index 74f8f97e4..10bd86868 100755 --- a/tests/token_federation/github_oidc_test.py +++ b/tests/token_federation/github_oidc_test.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - """ Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. 
diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 8fa3fa30f..8dc49b1db 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -1,9 +1,3 @@ -#!/usr/bin/env python3 - -""" -Unit tests for token federation functionality in the Databricks SQL connector. -""" - import pytest from unittest.mock import MagicMock, patch from datetime import datetime, timezone, timedelta @@ -134,7 +128,7 @@ def mock_discover_token_endpoint(self): with patch( "databricks.sql.auth.oidc_utils.OIDCDiscoveryUtil.discover_token_endpoint" ) as mock: - mock.return_value = "https://databricks.com/token" + mock.return_value = "https://databricks.com/oidc/v1/token" yield mock @pytest.fixture @@ -375,7 +369,7 @@ def test_exchange_token( mock_request_post.return_value = mock_response # Set the token endpoint - federation_provider.token_endpoint = "https://databricks.com/token" + federation_provider.token_endpoint = "https://databricks.com/oidc/v1/token" # Call the method token = federation_provider._exchange_token("original_token") From efb91492f7d57582a524132bef4044b70b020af8 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 12 May 2025 05:58:57 +0000 Subject: [PATCH 37/46] test improvements --- tests/unit/test_token_federation.py | 578 ++++++++++++++-------------- 1 file changed, 285 insertions(+), 293 deletions(-) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 8dc49b1db..53656b218 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -11,103 +11,109 @@ from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil -# Tests for Token class +@pytest.fixture +def future_time(): + """Fixture providing a future time for token expiry.""" + return datetime.now(tz=timezone.utc) + timedelta(hours=1) + + +@pytest.fixture +def valid_token(future_time): + """Fixture providing a valid token.""" + return Token("access_token_value", "Bearer", 
expiry=future_time) + + class TestToken: """Tests for the Token class.""" - def test_token_initialization_and_properties(self): - """Test Token initialization, properties and methods.""" - # Test with minimum required parameters plus expiry - future = datetime.now(tz=timezone.utc) + timedelta(hours=1) - token = Token("access_token_value", "Bearer", expiry=future) + def test_valid_token_properties(self, future_time): + """Test that a valid token has the expected properties.""" + # Create token with future expiry + token = Token("access_token_value", "Bearer", expiry=future_time) + + # Verify properties assert token.access_token == "access_token_value" assert token.token_type == "Bearer" assert token.refresh_token == "" - assert token.expiry == future + assert token.expiry == future_time assert token.is_valid() + assert str(token) == "Bearer access_token_value" + + def test_expired_token_is_invalid(self): + """Test that an expired token is recognized as invalid.""" + past_time = datetime.now(tz=timezone.utc) - timedelta(hours=1) + token = Token("expired", "Bearer", expiry=past_time) - # Test expired token - past = datetime.now(tz=timezone.utc) - timedelta(hours=1) - expired_token = Token("expired", "Bearer", expiry=past) - assert not expired_token.is_valid() + assert not token.is_valid() - # Test almost expired token (will expire within buffer) + def test_almost_expired_token_is_invalid(self): + """Test that a token about to expire is recognized as invalid.""" almost_expired = datetime.now(tz=timezone.utc) + timedelta( seconds=5 ) # Less than MIN_VALIDITY_BUFFER - almost_token = Token("almost", "Bearer", expiry=almost_expired) - assert not almost_token.is_valid() # Not valid due to buffer + token = Token("almost", "Bearer", expiry=almost_expired) - # Test string representation - assert str(token) == "Bearer access_token_value" + assert not token.is_valid() -# Tests for SimpleCredentialsProvider class TestSimpleCredentialsProvider: """Tests for the 
SimpleCredentialsProvider class.""" - def test_provider_initialization(self): - """Test initialization and methods of SimpleCredentialsProvider.""" + def test_provider_initialization_and_headers(self): + """Test SimpleCredentialsProvider initialization and header generation.""" provider = SimpleCredentialsProvider("token1", "Bearer", "token") + + # Check auth type assert provider.auth_type() == "token" - # Test header factory - header_factory = provider() - headers = header_factory() + # Check header generation + headers = provider()() assert headers == {"Authorization": "Bearer token1"} -# Tests for OIDCDiscoveryUtil class TestOIDCDiscoveryUtil: """Tests for the OIDCDiscoveryUtil class.""" - def test_discover_token_endpoint(self): - """Test token endpoint creation for Databricks workspaces.""" - # Test with different hostname formats - # Without protocol and without trailing slash - token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint("databricks.com") - assert token_endpoint == "https://databricks.com/oidc/v1/token" - - # With protocol but without trailing slash - token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( - "https://databricks.com" - ) - assert token_endpoint == "https://databricks.com/oidc/v1/token" - - # With protocol and trailing slash - token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( - "https://databricks.com/" - ) - assert token_endpoint == "https://databricks.com/oidc/v1/token" - - def test_format_hostname(self): - """Test hostname formatting.""" - # Without protocol and without trailing slash - assert ( - OIDCDiscoveryUtil.format_hostname("databricks.com") - == "https://databricks.com/" - ) + @pytest.mark.parametrize( + "hostname,expected", + [ + # Without protocol and without trailing slash + ("databricks.com", "https://databricks.com/oidc/v1/token"), + # With protocol but without trailing slash + ("https://databricks.com", "https://databricks.com/oidc/v1/token"), + # With protocol and trailing slash + 
("https://databricks.com/", "https://databricks.com/oidc/v1/token"), + ], + ) + def test_discover_token_endpoint(self, hostname, expected): + """Test token endpoint creation for various hostname formats.""" + token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint(hostname) + assert token_endpoint == expected + + @pytest.mark.parametrize( + "hostname,expected", + [ + # Without protocol and without trailing slash + ("databricks.com", "https://databricks.com/"), + # With protocol but without trailing slash + ("https://databricks.com", "https://databricks.com/"), + # With protocol and trailing slash + ("https://databricks.com/", "https://databricks.com/"), + ], + ) + def test_format_hostname(self, hostname, expected): + """Test hostname formatting with various input formats.""" + formatted = OIDCDiscoveryUtil.format_hostname(hostname) + assert formatted == expected - # With protocol but without trailing slash - assert ( - OIDCDiscoveryUtil.format_hostname("https://databricks.com") - == "https://databricks.com/" - ) - # With protocol and trailing slash - assert ( - OIDCDiscoveryUtil.format_hostname("https://databricks.com/") - == "https://databricks.com/" - ) - - -# Tests for DatabricksTokenFederationProvider class TestDatabricksTokenFederationProvider: """Tests for the DatabricksTokenFederationProvider class.""" + # ==== Fixtures ==== @pytest.fixture def mock_credentials_provider(self): - """Fixture for a mock credentials provider.""" + """Fixture providing a mock credentials provider.""" provider = MagicMock() provider.auth_type.return_value = "mock_auth_type" header_factory = MagicMock() @@ -117,280 +123,266 @@ def mock_credentials_provider(self): @pytest.fixture def federation_provider(self, mock_credentials_provider): - """Fixture for a token federation provider.""" - return DatabricksTokenFederationProvider( + """Fixture providing a token federation provider with mocked dependencies.""" + provider = DatabricksTokenFederationProvider( mock_credentials_provider, 
"databricks.com", "client_id" ) + # Initialize token endpoint to avoid discovery during tests + provider.token_endpoint = "https://databricks.com/oidc/v1/token" + return provider @pytest.fixture - def mock_discover_token_endpoint(self): - """Fixture for mocking OIDCDiscoveryUtil.discover_token_endpoint.""" - with patch( - "databricks.sql.auth.oidc_utils.OIDCDiscoveryUtil.discover_token_endpoint" - ) as mock: - mock.return_value = "https://databricks.com/oidc/v1/token" - yield mock - - @pytest.fixture - def mock_parse_jwt_claims(self): - """Fixture for mocking _parse_jwt_claims.""" - with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" - ) as mock: - yield mock - - @pytest.fixture - def mock_exchange_token(self): - """Fixture for mocking _exchange_token.""" - with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" - ) as mock: - yield mock - - @pytest.fixture - def mock_is_same_host(self): - """Fixture for mocking _is_same_host.""" + def mock_dependencies(self): + """Mock all external dependencies of the federation provider.""" with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" - ) as mock: - yield mock - - @pytest.fixture - def mock_request_post(self): - """Fixture for mocking requests.post.""" - with patch("databricks.sql.auth.token_federation.requests.post") as mock: - yield mock - - def test_host_and_auth_type(self, federation_provider): - """Test the host property and auth_type of DatabricksTokenFederationProvider.""" + "databricks.sql.auth.oidc_utils.OIDCDiscoveryUtil.discover_token_endpoint", + return_value="https://databricks.com/oidc/v1/token", + ) as mock_discover: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" + ) as mock_parse_jwt: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" + ) as 
mock_exchange: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + ) as mock_is_same_host: + with patch( + "databricks.sql.auth.token_federation.requests.post" + ) as mock_post: + yield { + "discover": mock_discover, + "parse_jwt": mock_parse_jwt, + "exchange": mock_exchange, + "is_same_host": mock_is_same_host, + "post": mock_post, + } + + # ==== Basic functionality tests ==== + def test_provider_initialization(self, federation_provider): + """Test basic provider initialization and properties.""" assert federation_provider.host == "databricks.com" assert federation_provider.hostname == "databricks.com" assert federation_provider.auth_type() == "mock_auth_type" - def test_is_same_host(self, federation_provider): - """Test the _is_same_host method with various URL combinations.""" - # Same host - assert federation_provider._is_same_host( - "https://databricks.com", "https://databricks.com" - ) - # Different host - assert not federation_provider._is_same_host( - "https://databricks.com", "https://different.com" - ) - # Same host with paths - assert federation_provider._is_same_host( - "https://databricks.com/path", "https://databricks.com/other" - ) - # Missing protocol - assert federation_provider._is_same_host( - "databricks.com", "https://databricks.com" - ) - - def test_extract_token_info_from_header(self, federation_provider): - """Test _extract_token_info_from_header with valid and invalid headers.""" - # Valid headers - assert federation_provider._extract_token_info_from_header( - {"Authorization": "Bearer token"} - ) == ("Bearer", "token") + # ==== Utility method tests ==== + @pytest.mark.parametrize( + "url1,url2,expected", + [ + # Same host with same protocol + ("https://databricks.com", "https://databricks.com", True), + # Different hosts + ("https://databricks.com", "https://different.com", False), + # Same host with different paths + ("https://databricks.com/path", "https://databricks.com/other", 
True), + # Same host with missing protocol + ("databricks.com", "https://databricks.com", True), + ], + ) + def test_is_same_host(self, federation_provider, url1, url2, expected): + """Test host comparison logic with various URL formats.""" + assert federation_provider._is_same_host(url1, url2) is expected + + @pytest.mark.parametrize( + "headers,expected_result,should_raise", + [ + # Valid Bearer token + ({"Authorization": "Bearer token"}, ("Bearer", "token"), False), + # Valid custom token type + ({"Authorization": "CustomType token"}, ("CustomType", "token"), False), + # Missing Authorization header + ({}, None, True), + # Empty Authorization header + ({"Authorization": ""}, None, True), + # Malformed Authorization header + ({"Authorization": "Bearer"}, None, True), + ], + ) + def test_extract_token_info( + self, federation_provider, headers, expected_result, should_raise + ): + """Test token extraction from headers with various formats.""" + if should_raise: + with pytest.raises(ValueError): + federation_provider._extract_token_info_from_header(headers) + else: + result = federation_provider._extract_token_info_from_header(headers) + assert result == expected_result - assert federation_provider._extract_token_info_from_header( - {"Authorization": "CustomType token"} - ) == ("CustomType", "token") + def test_get_expiry_from_jwt(self, federation_provider): + """Test JWT token expiry extraction.""" + # Create a valid JWT token with expiry + expiry_timestamp = int( + (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp() + ) + valid_payload = { + "exp": expiry_timestamp, + "iat": int(datetime.now(tz=timezone.utc).timestamp()), + "sub": "test-subject", + } + valid_token = jwt.encode(valid_payload, "secret", algorithm="HS256") - # Invalid headers - with pytest.raises(ValueError): - federation_provider._extract_token_info_from_header({}) + # Test with valid token + expiry = federation_provider._get_expiry_from_jwt(valid_token) + assert expiry is not None + 
assert isinstance(expiry, datetime) + assert expiry.tzinfo is not None # Should be timezone-aware + # Allow for small rounding differences + assert ( + abs( + ( + expiry - datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc) + ).total_seconds() + ) + < 1 + ) - with pytest.raises(ValueError): - federation_provider._extract_token_info_from_header({"Authorization": ""}) + # Test with invalid token format + assert federation_provider._get_expiry_from_jwt("invalid-token") is None - with pytest.raises(ValueError): - federation_provider._extract_token_info_from_header( - {"Authorization": "Bearer"} + # Test with token missing expiry claim + token_without_exp = jwt.encode( + {"sub": "test-subject"}, "secret", algorithm="HS256" + ) + assert federation_provider._get_expiry_from_jwt(token_without_exp) is None + + # ==== Core functionality tests ==== + def test_token_reuse_when_valid(self, federation_provider, future_time): + """Test that a valid token is reused without exchange.""" + # Prepare mock for exchange function + with patch.object(federation_provider, "_exchange_token") as mock_exchange: + # Set up a valid token + federation_provider.current_token = Token( + "existing_token", "Bearer", expiry=future_time ) + federation_provider.external_headers = { + "Authorization": "Bearer external_token" + } - def test_token_reuse( - self, - federation_provider, - mock_exchange_token, - ): - """Test token reuse when token is still valid.""" - # Set up the initial token - future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) - initial_token = Token("exchanged_token", "Bearer", expiry=future_time) - federation_provider.current_token = initial_token - federation_provider.external_headers = { - "Authorization": "Bearer external_token" - } + # Get headers + headers = federation_provider.get_auth_headers() - # Get headers and verify the token is reused without calling exchange - headers = federation_provider.get_auth_headers() - assert headers["Authorization"] == "Bearer 
exchanged_token" - # Verify exchange was not called - mock_exchange_token.assert_not_called() - - def test_refresh_token_method( - self, - federation_provider, - mock_parse_jwt_claims, - mock_exchange_token, - mock_is_same_host, - mock_discover_token_endpoint, + # Verify token was reused without exchange + assert headers["Authorization"] == "Bearer existing_token" + mock_exchange.assert_not_called() + + def test_token_exchange_from_different_host( + self, federation_provider, mock_dependencies ): - """Test the refactored refresh_token method for both exchange and non-exchange cases.""" - # CASE 1: Token from different host (needs exchange) - # Set up mocks - mock_parse_jwt_claims.return_value = { + """Test token exchange when token is from a different host.""" + # Configure mocks for token from different host + mock_dependencies["parse_jwt"].return_value = { "iss": "https://login.microsoftonline.com/tenant" } - mock_is_same_host.return_value = False - - # Set up headers that the credentials provider will return - headers = {"Authorization": "Bearer test_token"} - header_factory = MagicMock() - header_factory.return_value = headers + mock_dependencies["is_same_host"].return_value = False - # Configure the credentials provider - mock_creds_provider = MagicMock() - mock_creds_provider.return_value = header_factory - federation_provider.credentials_provider = mock_creds_provider + # Configure credentials provider + headers = {"Authorization": "Bearer external_token"} + header_factory = MagicMock(return_value=headers) + mock_creds = MagicMock(return_value=header_factory) + federation_provider.credentials_provider = mock_creds - # Configure the mock token exchange + # Configure mock token exchange future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) - mock_exchange_token.return_value = Token( - "exchanged_token", "Bearer", expiry=future_time - ) + exchanged_token = Token("databricks_token", "Bearer", expiry=future_time) + 
mock_dependencies["exchange"].return_value = exchanged_token - # Call the refresh_token method + # Call refresh_token token = federation_provider.refresh_token() - # Verify the token was exchanged - mock_exchange_token.assert_called_with("test_token") - assert token.access_token == "exchanged_token" - assert token == federation_provider.current_token + # Verify token was exchanged + mock_dependencies["exchange"].assert_called_with("external_token") + assert token.access_token == "databricks_token" + assert federation_provider.current_token == token - # CASE 2: Token from same host (no exchange needed) - mock_is_same_host.return_value = True - mock_exchange_token.reset_mock() + def test_token_from_same_host(self, federation_provider, mock_dependencies): + """Test handling of token from the same host (no exchange needed).""" + # Configure mocks for token from same host + mock_dependencies["parse_jwt"].return_value = {"iss": "https://databricks.com"} + mock_dependencies["is_same_host"].return_value = True - # Mock the JWT expiry extraction + # Configure credentials provider + headers = {"Authorization": "Bearer databricks_token"} + header_factory = MagicMock(return_value=headers) + mock_creds = MagicMock(return_value=header_factory) + federation_provider.credentials_provider = mock_creds + + # Mock JWT expiry extraction expiry_time = datetime.now(tz=timezone.utc) + timedelta(hours=2) - with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._get_expiry_from_jwt", - return_value=expiry_time, + with patch.object( + federation_provider, "_get_expiry_from_jwt", return_value=expiry_time ): - # Call refresh_token again + # Call refresh_token token = federation_provider.refresh_token() # Verify no exchange was performed - mock_exchange_token.assert_not_called() - # Verify token was created directly - assert token.access_token == "test_token" + mock_dependencies["exchange"].assert_not_called() + assert token.access_token == "databricks_token" 
assert token.expiry == expiry_time - def test_call_method_returns_auth_headers_directly( - self, - federation_provider, - mock_discover_token_endpoint, + def test_call_returns_auth_headers_function( + self, federation_provider, mock_dependencies ): - """Test that __call__ directly returns the get_auth_headers method.""" - # Mock get_auth_headers to verify it's called directly + """Test that __call__ returns the get_auth_headers method directly.""" with patch.object( federation_provider, "get_auth_headers", - return_value={"Authorization": "Bearer test_auth"}, + return_value={"Authorization": "Bearer test_token"}, ) as mock_get_auth: # Get the header factory from __call__ result = federation_provider() - # In our refactored implementation, __call__ returns get_auth_headers directly + # Verify it's the get_auth_headers method assert result is federation_provider.get_auth_headers - # Now call the result and verify it returns what get_auth_headers returns + # Call the result and verify it returns headers headers = result() - assert headers == {"Authorization": "Bearer test_auth"} + assert headers == {"Authorization": "Bearer test_token"} mock_get_auth.assert_called_once() - def test_get_expiry_from_jwt(self, federation_provider): - """Test extracting expiry from JWT token.""" - # Create a JWT token with expiry - expiry_timestamp = int( - (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp() - ) - payload = { - "exp": expiry_timestamp, - "iat": int(datetime.now(tz=timezone.utc).timestamp()), - "sub": "test-subject", - } - - # Create JWT token - token = jwt.encode(payload, "secret", algorithm="HS256") - - # Test the method - expiry = federation_provider._get_expiry_from_jwt(token) - - # Verify the expiry is extracted correctly - assert expiry is not None - assert isinstance(expiry, datetime) - assert expiry.tzinfo is not None # Should be timezone-aware - assert ( - abs( - ( - expiry - datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc) - 
).total_seconds() - ) - < 1 - ) # Allow for small rounding differences - - # Test with invalid token - expiry = federation_provider._get_expiry_from_jwt("invalid-token") - assert expiry is None - - # Test with token missing expiry - payload = {"sub": "test-subject"} - token_without_exp = jwt.encode(payload, "secret", algorithm="HS256") - expiry = federation_provider._get_expiry_from_jwt(token_without_exp) - assert expiry is None - - def test_exchange_token( - self, federation_provider, mock_request_post, mock_discover_token_endpoint - ): - """Test the _exchange_token method with success and failure cases.""" - # SUCCESS CASE - # Mock the response data - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "access_token": "new_token", - "token_type": "Bearer", - "refresh_token": "refresh_value", - "expires_in": 3600, - } - mock_request_post.return_value = mock_response - - # Set the token endpoint - federation_provider.token_endpoint = "https://databricks.com/oidc/v1/token" - - # Call the method - token = federation_provider._exchange_token("original_token") - - # Verify the token was created correctly - assert token.access_token == "new_token" - assert token.token_type == "Bearer" - assert token.refresh_token == "refresh_value" - # Expiry should be around 1 hour in the future - assert token.expiry > datetime.now(tz=timezone.utc) - assert token.expiry < datetime.now(tz=timezone.utc) + timedelta(seconds=3601) - - # FAILURE CASE - # Mock the response data for failure - mock_response = MagicMock() - mock_response.status_code = 401 - mock_response.text = "Unauthorized" - mock_request_post.return_value = mock_response - - # Call the method and expect an exception - with pytest.raises( - ValueError, match="Token exchange failed with status code 401" - ): - federation_provider._exchange_token("original_token") + def test_token_exchange_success(self, federation_provider): + """Test successful token exchange.""" + # Mock 
successful response + with patch("databricks.sql.auth.token_federation.requests.post") as mock_post: + # Configure mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_token", + "token_type": "Bearer", + "refresh_token": "refresh_value", + "expires_in": 3600, + } + mock_post.return_value = mock_response + + # Patch the _get_expiry_from_jwt method to return None (forcing use of expires_in) + with patch.object( + federation_provider, "_get_expiry_from_jwt", return_value=None + ): + # Call the exchange method + token = federation_provider._exchange_token("original_token") + + # Verify token properties + assert token.access_token == "new_token" + assert token.token_type == "Bearer" + assert token.refresh_token == "refresh_value" + + # Verify expiry time (should be ~1 hour in future) + now = datetime.now(tz=timezone.utc) + assert token.expiry > now + assert token.expiry < now + timedelta(seconds=3601) + + def test_token_exchange_failure(self, federation_provider): + """Test token exchange failure handling.""" + # Mock error response + with patch("databricks.sql.auth.token_federation.requests.post") as mock_post: + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_post.return_value = mock_response + + # Call the method and expect an exception + with pytest.raises( + ValueError, match="Token exchange failed with status code 401" + ): + federation_provider._exchange_token("original_token") From 7ab406805bfdfaa42d9b5361237238cfb4b77e73 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 12 May 2025 06:04:49 +0000 Subject: [PATCH 38/46] Refactor token exchange parameters to be instance-specific in DatabricksTokenFederationProvider --- src/databricks/sql/auth/token_federation.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py 
b/src/databricks/sql/auth/token_federation.py index ebce7d546..00f037add 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -15,14 +15,6 @@ logger = logging.getLogger(__name__) -# Token exchange constants -TOKEN_EXCHANGE_PARAMS = { - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "scope": "sql", - "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", - "return_original_token_if_authenticated": "true", -} - class DatabricksTokenFederationProvider(CredentialsProvider): """ @@ -40,6 +32,14 @@ class DatabricksTokenFederationProvider(CredentialsProvider): "Content-Type": "application/x-www-form-urlencoded", } + # Token exchange parameters + TOKEN_EXCHANGE_PARAMS = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "scope": "sql", + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "return_original_token_if_authenticated": "true", + } + def __init__( self, credentials_provider: CredentialsProvider, @@ -317,7 +317,7 @@ def _exchange_token(self, access_token: str) -> Token: ValueError: If token exchange fails """ # Prepare the request data - token_exchange_data = dict(TOKEN_EXCHANGE_PARAMS) + token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS) token_exchange_data["subject_token"] = access_token # Add client_id if provided From 9fc4c0c3637e99a33aabd415d4d113d7d98c73ed Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 12 May 2025 06:09:52 +0000 Subject: [PATCH 39/46] Refactor token expiry handling in DatabricksTokenFederationProvider and enhance unit tests for accurate expiry verification --- src/databricks/sql/auth/token_federation.py | 36 ++------------------- tests/unit/test_token_federation.py | 21 ++++++++---- 2 files changed, 16 insertions(+), 41 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 00f037add..7c2ed9b29 100644 --- 
a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -335,45 +335,13 @@ def _exchange_token(self, access_token: str) -> Token: token_type = resp_data.get("token_type", "Bearer") refresh_token = resp_data.get("refresh_token", "") - # Determine token expiry - first try from JWT claims + # Extract expiry from JWT claims expiry = self._get_expiry_from_jwt(new_access_token) - - # If JWT expiry not available, use expires_in from response if expiry is None: - expiry = self._get_expiry_from_response(resp_data) - - # If we still don't have an expiry, we can't proceed - if expiry is None: - raise ValueError( - "Unable to determine token expiry from response or JWT claims" - ) + raise ValueError("Unable to determine token expiry from JWT claims") return Token(new_access_token, token_type, refresh_token, expiry) - def _get_expiry_from_response( - self, resp_data: Dict[str, Any] - ) -> Optional[datetime]: - """ - Extract expiry datetime from response data. 
- - Args: - resp_data: Response data from token exchange - - Returns: - Optional[datetime]: Expiry datetime if found in response, None otherwise - """ - if "expires_in" not in resp_data or not resp_data["expires_in"]: - return None - - try: - expires_in = int(resp_data["expires_in"]) - expiry = datetime.now(tz=timezone.utc) + timedelta(seconds=expires_in) - logger.debug(f"Using expiry from expires_in: {expiry}") - return expiry - except (ValueError, TypeError) as e: - logger.warning(f"Invalid expires_in value: {str(e)}") - return None - class SimpleCredentialsProvider(CredentialsProvider): """A simple credentials provider that returns a fixed token.""" diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 53656b218..e4344fd51 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -344,6 +344,11 @@ def test_token_exchange_success(self, federation_provider): """Test successful token exchange.""" # Mock successful response with patch("databricks.sql.auth.token_federation.requests.post") as mock_post: + # Create a token with a valid expiry + expiry_timestamp = int( + (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp() + ) + # Configure mock response mock_response = MagicMock() mock_response.status_code = 200 @@ -351,13 +356,14 @@ def test_token_exchange_success(self, federation_provider): "access_token": "new_token", "token_type": "Bearer", "refresh_token": "refresh_value", - "expires_in": 3600, } mock_post.return_value = mock_response - # Patch the _get_expiry_from_jwt method to return None (forcing use of expires_in) + # Mock JWT expiry extraction to return a valid expiry with patch.object( - federation_provider, "_get_expiry_from_jwt", return_value=None + federation_provider, + "_get_expiry_from_jwt", + return_value=datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc), ): # Call the exchange method token = federation_provider._exchange_token("original_token") @@ -367,10 
+373,11 @@ def test_token_exchange_success(self, federation_provider): assert token.token_type == "Bearer" assert token.refresh_token == "refresh_value" - # Verify expiry time (should be ~1 hour in future) - now = datetime.now(tz=timezone.utc) - assert token.expiry > now - assert token.expiry < now + timedelta(seconds=3601) + # Verify expiry time is correctly set + expiry_datetime = datetime.fromtimestamp( + expiry_timestamp, tz=timezone.utc + ) + assert token.expiry == expiry_datetime def test_token_exchange_failure(self, federation_provider): """Test token exchange failure handling.""" From 85d0cd9b4b8f88d49f6be5b414dbb096ce8add21 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 21 May 2025 11:53:51 +0000 Subject: [PATCH 40/46] addresses comments --- src/databricks/sql/auth/oidc_utils.py | 17 ++++ src/databricks/sql/auth/token_federation.py | 98 ++++++++------------- tests/unit/test_token_federation.py | 12 ++- 3 files changed, 60 insertions(+), 67 deletions(-) diff --git a/src/databricks/sql/auth/oidc_utils.py b/src/databricks/sql/auth/oidc_utils.py index b0421cf7f..74c37591d 100644 --- a/src/databricks/sql/auth/oidc_utils.py +++ b/src/databricks/sql/auth/oidc_utils.py @@ -1,6 +1,7 @@ import logging import requests from typing import Optional +from urllib.parse import urlparse from databricks.sql.auth.endpoint import ( get_oauth_endpoints, @@ -56,3 +57,19 @@ def format_hostname(hostname: str) -> str: if not hostname.endswith("/"): hostname = f"{hostname}/" return hostname + + +def is_same_host(url1: str, url2: str) -> bool: + """ + Check if two URLs have the same host. 
+ """ + try: + if not url1.startswith(("http://", "https://")): + url1 = f"https://{url1}" + if not url2.startswith(("http://", "https://")): + url2 = f"https://{url2}" + parsed1 = urlparse(url1) + parsed2 = urlparse(url2) + return parsed1.netloc.lower() == parsed2.netloc.lower() + except Exception: + return False diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 7c2ed9b29..458954755 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -10,7 +10,7 @@ from requests.exceptions import RequestException from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory -from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil +from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil, is_same_host from databricks.sql.auth.token import Token logger = logging.getLogger(__name__) @@ -79,15 +79,6 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: Configure and return a HeaderFactory that provides authentication headers. This is called by the ExternalAuthProvider to get headers for authentication. """ - # First call the underlying credentials provider to get its headers - header_factory = self.credentials_provider(*args, **kwargs) - - # Get the standard token endpoint if not already set - if self.token_endpoint is None: - self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( - self.hostname - ) - # Return a function that will get authentication headers return self.get_auth_headers @@ -156,34 +147,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]: return None - def _is_same_host(self, url1: str, url2: str) -> bool: - """ - Check if two URLs have the same host. 
- - Args: - url1: First URL - url2: Second URL - - Returns: - bool: True if hosts are the same, False otherwise - """ - try: - # Add protocol if missing to ensure proper parsing - if not url1.startswith(("http://", "https://")): - url1 = f"https://{url1}" - if not url2.startswith(("http://", "https://")): - url2 = f"https://{url2}" - - # Parse the URLs - parsed1 = urlparse(url1) - parsed2 = urlparse(url2) - - # Compare the hostnames - return parsed1.netloc.lower() == parsed2.netloc.lower() - except Exception as e: - logger.warning(f"Error comparing hosts: {str(e)}") - return False - def refresh_token(self) -> Token: """ Refresh the token and return the new Token object. @@ -210,24 +173,34 @@ def refresh_token(self) -> Token: token_claims = self._parse_jwt_claims(access_token) # Create new token based on whether it's from the same host or not - if self._is_same_host(token_claims.get("iss", ""), self.hostname): + if is_same_host(token_claims.get("iss", ""), self.hostname): # Token is from the same host, no need to exchange logger.debug("Token from same host, creating token without exchange") - expiry = self._get_expiry_from_jwt(access_token) if expiry is None: raise ValueError("Could not determine token expiry from JWT") - new_token = Token(access_token, token_type, "", expiry) + self.current_token = new_token + return new_token else: # Token is from a different host, need to exchange logger.debug("Token from different host, exchanging token") - new_token = self._exchange_token(access_token) - - # Store the token - self.current_token = new_token - - return new_token + try: + new_token = self._exchange_token(access_token) + self.current_token = new_token + return new_token + except Exception as e: + logger.error( + f"Token exchange failed: {e}. Using external token as fallback." 
+ ) + expiry = self._get_expiry_from_jwt(access_token) + if expiry is None: + raise ValueError( + "Could not determine token expiry from JWT (after exchange failure)" + ) + fallback_token = Token(access_token, token_type, "", expiry) + self.current_token = fallback_token + return fallback_token def get_current_token(self) -> Token: """ @@ -254,24 +227,19 @@ def get_auth_headers(self) -> Dict[str, str]: """ Get authorization headers using the current token. - This method gets the current token and returns it formatted - as authorization headers. - Returns: - Dict[str, str]: Authorization headers + Dict[str, str]: Authorization headers (may include extra headers from provider) """ try: token = self.get_current_token() - return {"Authorization": f"{token.token_type} {token.access_token}"} + # Always get the latest headers from the credentials provider + header_factory = self.credentials_provider() + headers = dict(header_factory()) if header_factory else {} + headers["Authorization"] = f"{token.token_type} {token.access_token}" + return headers except Exception as e: logger.error(f"Error getting auth headers: {str(e)}") - - # Fall back to external headers if available - if self.external_headers: - return self.external_headers - - # Return empty dict as a last resort - return {} + return dict(self.external_headers) if self.external_headers else {} def _send_token_exchange_request( self, token_exchange_data: Dict[str, str] @@ -286,7 +254,7 @@ def _send_token_exchange_request( Dict[str, Any]: Token exchange response Raises: - ValueError: If token exchange fails + requests.HTTPError: If token exchange fails """ if not self.token_endpoint: raise ValueError("Token endpoint not initialized") @@ -296,9 +264,9 @@ def _send_token_exchange_request( ) if response.status_code != 200: - raise ValueError( - f"Token exchange failed with status code {response.status_code}: " - f"{response.text}" + raise requests.HTTPError( + f"Token exchange failed with status code 
{response.status_code}: {response.text}", + response=response, ) return response.json() @@ -316,6 +284,10 @@ def _exchange_token(self, access_token: str) -> Token: Raises: ValueError: If token exchange fails """ + if self.token_endpoint is None: + self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( + self.hostname + ) # Prepare the request data token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS) token_exchange_data["subject_token"] = access_token diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index e4344fd51..2bb57645c 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -145,7 +145,7 @@ def mock_dependencies(self): "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" ) as mock_exchange: with patch( - "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + "databricks.sql.auth.oidc_utils.is_same_host" ) as mock_is_same_host: with patch( "databricks.sql.auth.token_federation.requests.post" @@ -179,9 +179,11 @@ def test_provider_initialization(self, federation_provider): ("databricks.com", "https://databricks.com", True), ], ) - def test_is_same_host(self, federation_provider, url1, url2, expected): + def test_is_same_host(self, url1, url2, expected): """Test host comparison logic with various URL formats.""" - assert federation_provider._is_same_host(url1, url2) is expected + from databricks.sql.auth.oidc_utils import is_same_host + + assert is_same_host(url1, url2) is expected @pytest.mark.parametrize( "headers,expected_result,should_raise", @@ -389,7 +391,9 @@ def test_token_exchange_failure(self, federation_provider): mock_post.return_value = mock_response # Call the method and expect an exception + import requests + with pytest.raises( - ValueError, match="Token exchange failed with status code 401" + requests.HTTPError, match="Token exchange failed with status code 401" ): 
federation_provider._exchange_token("original_token") From 504056940f8318765a9e4a31e4192d8c4bc341ab Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 22 May 2025 06:02:48 +0000 Subject: [PATCH 41/46] initial commit --- src/databricks/sql/auth/auth.py | 58 ++++++++++----------------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 3931356d0..348d3b698 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -35,6 +35,7 @@ def __init__( oauth_persistence=None, credentials_provider=None, identity_federation_client_id: Optional[str] = None, + use_token_federation: bool = False, ): self.hostname = hostname self.access_token = access_token @@ -47,6 +48,7 @@ def __init__( self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider self.identity_federation_client_id = identity_federation_client_id + self.use_token_federation = use_token_federation def get_auth_provider(cfg: ClientContext): @@ -71,45 +73,16 @@ def get_auth_provider(cfg: ClientContext): Raises: RuntimeError: If no valid authentication settings are provided """ - # If credentials_provider is explicitly provided + from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider if cfg.credentials_provider: - # If token federation is enabled and credentials provider is provided, - # wrap the credentials provider with DatabricksTokenFederationProvider - if cfg.auth_type == AuthType.TOKEN_FEDERATION.value: - from databricks.sql.auth.token_federation import ( - DatabricksTokenFederationProvider, - ) - - federation_provider = DatabricksTokenFederationProvider( - cfg.credentials_provider, - cfg.hostname, - cfg.identity_federation_client_id, - ) - return ExternalAuthProvider(federation_provider) - - # If not token federation, just use the credentials provider directly - return ExternalAuthProvider(cfg.credentials_provider) - - # If we don't 
have a credentials provider but have token federation auth type with access token - if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: - # Create a simple credentials provider and wrap it with token federation provider - from databricks.sql.auth.token_federation import ( - DatabricksTokenFederationProvider, - SimpleCredentialsProvider, - ) - - simple_provider = SimpleCredentialsProvider(cfg.access_token) - federation_provider = DatabricksTokenFederationProvider( - simple_provider, cfg.hostname, cfg.identity_federation_client_id - ) - return ExternalAuthProvider(federation_provider) - - if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: + base_provider = ExternalAuthProvider(cfg.credentials_provider) + elif cfg.access_token is not None: + base_provider = AccessTokenAuthProvider(cfg.access_token) + elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None assert cfg.oauth_scopes is not None - - return DatabricksOAuthProvider( + base_provider = DatabricksOAuthProvider( cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, @@ -117,18 +90,15 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_scopes, cfg.auth_type, ) - elif cfg.access_token is not None: - return AccessTokenAuthProvider(cfg.access_token) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: - # no op authenticator. 
authentication is performed using ssl certificate outside of headers - return AuthProvider() + base_provider = AuthProvider() else: if ( cfg.oauth_redirect_port_range is not None and cfg.oauth_client_id is not None and cfg.oauth_scopes is not None ): - return DatabricksOAuthProvider( + base_provider = DatabricksOAuthProvider( cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, @@ -138,6 +108,13 @@ def get_auth_provider(cfg: ClientContext): else: raise RuntimeError("No valid authentication settings!") + if getattr(cfg, "use_token_federation", False): + base_provider = DatabricksTokenFederationProvider( + base_provider, cfg.hostname, cfg.identity_federation_client_id + ) + + return base_provider + PYSQL_OAUTH_SCOPES = ["sql", "offline_access"] PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python" @@ -206,5 +183,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), identity_federation_client_id=kwargs.get("identity_federation_client_id"), + use_token_federation=kwargs.get("use_token_federation", False), ) return get_auth_provider(cfg) From 22a46817514ba08ef28e6706d2aaf8aaf1052817 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 22 May 2025 06:15:03 +0000 Subject: [PATCH 42/46] change github test to adapt --- tests/token_federation/github_oidc_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py index 10bd86868..7202f616d 100755 --- a/tests/token_federation/github_oidc_test.py +++ b/tests/token_federation/github_oidc_test.py @@ -105,10 +105,9 @@ def test_databricks_connection( "server_hostname": host, "http_path": http_path, "access_token": github_token, - "auth_type": "token-federation", + "use_token_federation": True, } - # Add identity federation client ID if provided if 
identity_federation_client_id: connection_params[ "identity_federation_client_id" From f1346b0a383b33c9cd390f3e36e293e0002bcb46 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 22 May 2025 06:18:12 +0000 Subject: [PATCH 43/46] implement add headers to tf provider --- src/databricks/sql/auth/token_federation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 458954755..d40ab62ed 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -314,6 +314,14 @@ def _exchange_token(self, access_token: str) -> Token: return Token(new_access_token, token_type, refresh_token, expiry) + def add_headers(self, request_headers: Dict[str, str]): + """ + Add authentication headers to the request. + """ + headers = self.get_auth_headers() + for k, v in headers.items(): + request_headers[k] = v + class SimpleCredentialsProvider(CredentialsProvider): """A simple credentials provider that returns a fixed token.""" From 4c5bce1a2185256e0227ce33de26835b9ed4078f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 22 May 2025 06:26:40 +0000 Subject: [PATCH 44/46] Enhance authentication providers by implementing CredentialsProvider interface, adding auth_type and __call__ methods for AccessTokenAuthProvider, DatabricksOAuthProvider, and ExternalAuthProvider. --- src/databricks/sql/auth/authenticators.py | 29 ++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index c425f0888..1baf1f8c9 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -41,17 +41,25 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. 
-class AccessTokenAuthProvider(AuthProvider): +class AccessTokenAuthProvider(AuthProvider, CredentialsProvider): def __init__(self, access_token: str): self.__authorization_header_value = "Bearer {}".format(access_token) def add_headers(self, request_headers: Dict[str, str]): request_headers["Authorization"] = self.__authorization_header_value + def auth_type(self) -> str: + return "access-token" + + def __call__(self, *args, **kwargs) -> HeaderFactory: + def get_headers(): + return {"Authorization": self.__authorization_header_value} + return get_headers + # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. -class DatabricksOAuthProvider(AuthProvider): +class DatabricksOAuthProvider(AuthProvider, CredentialsProvider): SCOPE_DELIM = " " def __init__( @@ -93,6 +101,15 @@ def add_headers(self, request_headers: Dict[str, str]): self._update_token_if_expired() request_headers["Authorization"] = f"Bearer {self._access_token}" + def auth_type(self) -> str: + return "databricks-oauth" + + def __call__(self, *args, **kwargs) -> HeaderFactory: + def get_headers(): + self._update_token_if_expired() + return {"Authorization": f"Bearer {self._access_token}"} + return get_headers + def _initial_get_token(self): try: if self._access_token is None or self._refresh_token is None: @@ -144,7 +161,7 @@ def _update_token_if_expired(self): raise e -class ExternalAuthProvider(AuthProvider): +class ExternalAuthProvider(AuthProvider, CredentialsProvider): def __init__(self, credentials_provider: CredentialsProvider) -> None: self._header_factory = credentials_provider() @@ -152,3 +169,9 @@ def add_headers(self, request_headers: Dict[str, str]): headers = self._header_factory() for k, v in headers.items(): request_headers[k] = v + + def auth_type(self) -> str: + return "external-auth" + + def __call__(self, *args, **kwargs) -> HeaderFactory: + return self._header_factory From 
bafef75008f5117c45649db0d1eab956ba02e8d9 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 28 May 2025 08:20:35 +0000 Subject: [PATCH 45/46] Add Databricks SQL Token Federation examples and enhance authentication with ClientCredentialsProvider - Introduced a new script for demonstrating various token federation flows in Databricks SQL. - Implemented ClientCredentialsProvider for machine-to-machine authentication, supporting Azure and Databricks service principal flows. - Refactored token federation handling to allow integration with existing authentication methods. - Updated the DatabricksTokenFederationProvider to improve token exchange logic and error handling. --- examples/token_federation_examples.py | 109 +++++++++++++++ src/databricks/sql/auth/auth.py | 115 +++++++++++----- src/databricks/sql/auth/authenticators.py | 141 ++++++++++++++++---- src/databricks/sql/auth/token_federation.py | 97 +++++--------- tests/unit/test_token_federation.py | 18 +-- 5 files changed, 341 insertions(+), 139 deletions(-) create mode 100644 examples/token_federation_examples.py diff --git a/examples/token_federation_examples.py b/examples/token_federation_examples.py new file mode 100644 index 000000000..44cc4a4ff --- /dev/null +++ b/examples/token_federation_examples.py @@ -0,0 +1,109 @@ +""" +Databricks SQL Token Federation Examples + +This script demonstrates token federation flows: +1. U2M + Account-wide federation +2. U2M + Workflow-level federation +3. M2M + Account-wide federation +4. M2M + Workflow-level federation +5. Access Token + Workflow-level federation +6. 
Access Token + Account-wide federation + +Token Federation Documentation: +------------------------------ +For detailed setup instructions, refer to the official Databricks documentation: + +- General Token Federation Overview: + https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation.html + +- Token Exchange Process: + https://docs.databricks.com/aws/en/dev-tools/auth/oauth-federation-howto.html + +- Azure OAuth Token Federation: + https://learn.microsoft.com/en-us/azure/databricks/dev-tools/auth/oauth-federation + +Environment variables required: +- DATABRICKS_HOST: Databricks workspace hostname +- DATABRICKS_HTTP_PATH: HTTP path for the SQL warehouse +- AZURE_TENANT_ID: Azure tenant ID +- AZURE_CLIENT_ID: Azure client ID for service principal +- AZURE_CLIENT_SECRET: Azure client secret +- DATABRICKS_SERVICE_PRINCIPAL_ID: Databricks service principal ID for workflow federation +""" + +import os +from databricks import sql + +def run_query(connection, description): + cursor = connection.cursor() + cursor.execute("SELECT 1+1 AS result") + result = cursor.fetchall() + print(f"Query result: {result[0][0]}") + + cursor.close() + +def demonstrate_m2m_federation(env_vars, use_workflow_federation=False): + """Demonstrate M2M (service principal) token federation""" + + connection_params = { + "server_hostname": env_vars["DATABRICKS_HOST"], + "http_path": env_vars["DATABRICKS_HTTP_PATH"], + "auth_type": "client-credentials", + "oauth_client_id": env_vars["AZURE_CLIENT_ID"], + "client_secret": env_vars["AZURE_CLIENT_SECRET"], + "tenant_id": env_vars["AZURE_TENANT_ID"], + "use_token_federation": True + } + + if use_workflow_federation and env_vars["DATABRICKS_SERVICE_PRINCIPAL_ID"]: + connection_params["identity_federation_client_id"] = env_vars["DATABRICKS_SERVICE_PRINCIPAL_ID"] + description = "M2M + Workflow-level Federation" + else: + description = "M2M + Account-wide Federation" + + with sql.connect(**connection_params) as connection: + run_query(connection, 
description) + + +def demonstrate_u2m_federation(env_vars, use_workflow_federation=False): + """Demonstrate U2M (interactive) token federation""" + + connection_params = { + "server_hostname": env_vars["DATABRICKS_HOST"], + "http_path": env_vars["DATABRICKS_HTTP_PATH"], + "auth_type": "databricks-oauth", # Will open browser for interactive auth + "use_token_federation": True + } + + if use_workflow_federation and env_vars["DATABRICKS_SERVICE_PRINCIPAL_ID"]: + connection_params["identity_federation_client_id"] = env_vars["DATABRICKS_SERVICE_PRINCIPAL_ID"] + description = "U2M + Workflow-level Federation (Interactive)" + else: + description = "U2M + Account-wide Federation (Interactive)" + + # This will open a browser for interactive auth + with sql.connect(**connection_params) as connection: + run_query(connection, description) + +def demonstrate_access_token_federation(env_vars): + """Demonstrate access token token federation""" + + access_token = os.environ.get("ACCESS_TOKEN") # This is to demonstrate a token obtained from an identity provider + + connection_params = { + "server_hostname": env_vars["DATABRICKS_HOST"], + "http_path": env_vars["DATABRICKS_HTTP_PATH"], + "access_token": access_token, + "use_token_federation": True + } + + # Add workflow federation if available + if env_vars["DATABRICKS_SERVICE_PRINCIPAL_ID"]: + connection_params["identity_federation_client_id"] = env_vars["DATABRICKS_SERVICE_PRINCIPAL_ID"] + description = "Access Token + Workflow-level Federation" + else: + description = "Access Token + Account-wide Federation" + + with sql.connect(**connection_params) as connection: + run_query(connection, description) + diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 348d3b698..d600f8f59 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -5,21 +5,15 @@ AuthProvider, AccessTokenAuthProvider, ExternalAuthProvider, - CredentialsProvider, DatabricksOAuthProvider, + 
ClientCredentialsProvider, ) class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" AZURE_OAUTH = "azure-oauth" - # TODO: Token federation should be a feature that works with different auth types, - # not an auth type itself. This will be refactored in a future change. - # We will add a use_token_federation flag that can be used with any auth type. - TOKEN_FEDERATION = "token-federation" - # other supported types (access_token) can be inferred - # we can add more types as needed later - + CLIENT_CREDENTIALS = "client-credentials" class ClientContext: def __init__( @@ -34,8 +28,10 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, - identity_federation_client_id: Optional[str] = None, + oauth_client_secret: Optional[str] = None, + tenant_id: Optional[str] = None, use_token_federation: bool = False, + identity_federation_client_id: Optional[str] = None, ): self.hostname = hostname self.access_token = access_token @@ -49,20 +45,52 @@ def __init__( self.credentials_provider = credentials_provider self.identity_federation_client_id = identity_federation_client_id self.use_token_federation = use_token_federation + self.oauth_client_secret = oauth_client_secret + self.tenant_id = tenant_id + +def _create_azure_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider: + """Create an Azure client credentials provider.""" + if not cfg.oauth_client_id or not cfg.oauth_client_secret or not cfg.tenant_id: + raise ValueError("Azure client credentials flow requires oauth_client_id, oauth_client_secret, and tenant_id") + + token_endpoint = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(cfg.tenant_id) + return ClientCredentialsProvider( + client_id=cfg.oauth_client_id, + client_secret=cfg.oauth_client_secret, + token_endpoint=token_endpoint, + auth_type_value="azure-client-credentials" + ) + + +def _create_databricks_client_credentials_provider(cfg: ClientContext) -> 
ClientCredentialsProvider: + """Create a Databricks client credentials provider for service principals.""" + if not cfg.oauth_client_id or not cfg.oauth_client_secret: + raise ValueError("Databricks client credentials flow requires oauth_client_id and oauth_client_secret") + + token_endpoint = "{}oidc/v1/token".format(cfg.hostname) + return ClientCredentialsProvider( + client_id=cfg.oauth_client_id, + client_secret=cfg.oauth_client_secret, + token_endpoint=token_endpoint, + auth_type_value="client-credentials" + ) def get_auth_provider(cfg: ClientContext): """ Get an appropriate auth provider based on the provided configuration. + OAuth Flow Support: + This function supports multiple OAuth flows: + 1. Interactive OAuth (databricks-oauth, azure-oauth) - for user authentication + 2. Client Credentials (client-credentials) - for machine-to-machine authentication + 3. Token Federation - implemented as a feature flag that wraps any auth type + Token Federation Support: ----------------------- - Currently, token federation is implemented as a separate auth type, but the goal is to - refactor it as a feature that can work with any auth type. The current implementation - is maintained for backward compatibility while the refactoring is planned. - - Future refactoring will introduce a `use_token_federation` flag that can be combined - with any auth type to enable token federation. + Token federation is implemented as a feature flag (`use_token_federation=True`) that + can be combined with any auth type. When enabled, it wraps the base auth provider + in a DatabricksTokenFederationProvider for token exchange functionality. 
Args: cfg: The client context containing configuration parameters @@ -74,21 +102,31 @@ def get_auth_provider(cfg: ClientContext): RuntimeError: If no valid authentication settings are provided """ from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider + + base_provider = None + if cfg.credentials_provider: base_provider = ExternalAuthProvider(cfg.credentials_provider) elif cfg.access_token is not None: base_provider = AccessTokenAuthProvider(cfg.access_token) + elif cfg.auth_type == AuthType.CLIENT_CREDENTIALS.value: + if cfg.tenant_id: + # Azure client credentials flow + base_provider = _create_azure_client_credentials_provider(cfg) + else: + # Databricks service principal client credentials flow + base_provider = _create_databricks_client_credentials_provider(cfg) elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None assert cfg.oauth_scopes is not None base_provider = DatabricksOAuthProvider( - cfg.hostname, - cfg.oauth_persistence, - cfg.oauth_redirect_port_range, - cfg.oauth_client_id, - cfg.oauth_scopes, - cfg.auth_type, + hostname=cfg.hostname, + oauth_persistence=cfg.oauth_persistence, + redirect_port_range=cfg.oauth_redirect_port_range, + client_id=cfg.oauth_client_id, + scopes=cfg.oauth_scopes, + auth_type=cfg.auth_type, ) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: base_provider = AuthProvider() @@ -99,11 +137,11 @@ def get_auth_provider(cfg: ClientContext): and cfg.oauth_scopes is not None ): base_provider = DatabricksOAuthProvider( - cfg.hostname, - cfg.oauth_persistence, - cfg.oauth_redirect_port_range, - cfg.oauth_client_id, - cfg.oauth_scopes, + hostname=cfg.hostname, + oauth_persistence=cfg.oauth_persistence, + redirect_port_range=cfg.oauth_redirect_port_range, + client_id=cfg.oauth_client_id, + scopes=cfg.oauth_scopes, ) else: raise RuntimeError("No valid authentication settings!") @@ -126,7 
+164,7 @@ def get_auth_provider(cfg: ClientContext): def normalize_host_name(hostname: str): maybe_scheme = "https://" if not hostname.startswith("https://") else "" maybe_trailing_slash = "/" if not hostname.endswith("/") else "" - return f"{maybe_scheme}{hostname}{maybe_trailing_slash}" + return "{}{}{}".format(maybe_scheme, hostname, maybe_trailing_slash) def get_client_id_and_redirect_port(use_azure_auth: bool): @@ -144,14 +182,25 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): This function is the main entry point for authentication in the SQL connector. It processes the parameters and creates an appropriate auth provider. - TODO: Future refactoring needed: - 1. Add a use_token_federation flag that can be combined with any auth type - 2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility - 3. Create a token federation wrapper that can wrap any existing auth provider + Supported Authentication Methods: + -------------------------------- + 1. Access Token: Provide 'access_token' parameter + 2. Interactive OAuth: Set 'auth_type' to 'databricks-oauth' or 'azure-oauth' + 3. Client Credentials: Set 'auth_type' to 'client-credentials' with client_id, client_secret, tenant_id + 4. External Provider: Provide 'credentials_provider' parameter + 5. 
Token Federation: Set 'use_token_federation=True' with any of the above Args: hostname: The Databricks server hostname - **kwargs: Additional configuration parameters + **kwargs: Additional configuration parameters including: + - auth_type: Authentication type + - access_token: Static access token + - oauth_client_id: OAuth client ID + - oauth_client_secret: OAuth client secret + - tenant_id: Azure AD tenant ID (for Azure flows) + - credentials_provider: External credentials provider + - use_token_federation: Enable token federation + - identity_federation_client_id: Federation client ID Returns: An appropriate AuthProvider instance @@ -182,6 +231,8 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), + oauth_client_secret=kwargs.get("oauth_client_secret"), + tenant_id=kwargs.get("tenant_id"), identity_federation_client_id=kwargs.get("identity_federation_client_id"), use_token_federation=kwargs.get("use_token_federation", False), ) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 1baf1f8c9..5f0d4ea01 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -1,7 +1,9 @@ import abc import base64 import logging -from typing import Callable, Dict, List +import time +from typing import Callable, Dict, List, Optional +import requests from databricks.sql.auth.oauth import OAuthManager from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host @@ -9,7 +11,7 @@ # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. 
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence - +from databricks.sql.auth.endpoint import AzureOAuthEndpointCollection, InHouseOAuthEndpointCollection class AuthProvider: def add_headers(self, request_headers: Dict[str, str]): @@ -56,7 +58,6 @@ def get_headers(): return {"Authorization": self.__authorization_header_value} return get_headers - # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class DatabricksOAuthProvider(AuthProvider, CredentialsProvider): @@ -71,43 +72,41 @@ def __init__( scopes: List[str], auth_type: str = "databricks-oauth", ): - try: - idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth") - if not idp_endpoint: - raise NotImplementedError( + self._hostname = hostname + self._oauth_persistence = oauth_persistence + self._client_id = client_id + self._auth_type = auth_type + self._access_token = None + self._refresh_token = None + + idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth") + if not idp_endpoint: + raise NotImplementedError( f"OAuth is not supported for host ${hostname}" ) - # Convert to the corresponding scopes in the corresponding IdP - cloud_scopes = idp_endpoint.get_scopes_mapping(scopes) + + cloud_scopes = idp_endpoint.get_scopes_mapping(scopes) + self._scopes_as_str = self.SCOPE_DELIM.join(cloud_scopes) - self.oauth_manager = OAuthManager( - port_range=redirect_port_range, - client_id=client_id, - idp_endpoint=idp_endpoint, - ) - self._hostname = hostname - self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes) - self._oauth_persistence = oauth_persistence - self._client_id = client_id - self._access_token = None - self._refresh_token = None - self._initial_get_token() - except Exception as e: - logging.error(f"unexpected error", e, exc_info=True) - raise e + self.oauth_manager = OAuthManager( + idp_endpoint=idp_endpoint, + client_id=client_id, + 
port_range=redirect_port_range, + ) + self._initial_get_token() def add_headers(self, request_headers: Dict[str, str]): self._update_token_if_expired() - request_headers["Authorization"] = f"Bearer {self._access_token}" + request_headers["Authorization"] = "Bearer {}".format(self._access_token) def auth_type(self) -> str: - return "databricks-oauth" + return self._auth_type def __call__(self, *args, **kwargs) -> HeaderFactory: def get_headers(): self._update_token_if_expired() - return {"Authorization": f"Bearer {self._access_token}"} + return {"Authorization": "Bearer {}".format(self._access_token)} return get_headers def _initial_get_token(self): @@ -161,8 +160,96 @@ def _update_token_if_expired(self): raise e +class ClientCredentialsProvider(CredentialsProvider, AuthProvider): + """Provider for OAuth client credentials flow (machine-to-machine authentication).""" + + AZURE_DATABRICKS_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default" + + def __init__( + self, + client_id: str, + client_secret: str, + token_endpoint: str, + auth_type_value: str = "client-credentials" + ): + """ + Initialize a ClientCredentialsProvider. 
+ + Args: + client_id: OAuth client ID + client_secret: OAuth client secret + token_endpoint: OAuth token endpoint URL + auth_type_value: Auth type identifier + """ + self.client_id = client_id + self.client_secret = client_secret + self.token_endpoint = token_endpoint + self.auth_type_value = auth_type_value + + self._cached_token = None + self._token_expires_at = None + + + def auth_type(self) -> str: + return self.auth_type_value + + def __call__(self, *args, **kwargs) -> HeaderFactory: + def get_headers() -> Dict[str, str]: + token = self._get_access_token() + return {"Authorization": "Bearer {}".format(token)} + return get_headers + + def add_headers(self, request_headers: Dict[str, str]): + token = self._get_access_token() + request_headers["Authorization"] = "Bearer {}".format(token) + + def _get_access_token(self) -> str: + """Get a valid access token using client credentials flow, with caching.""" + # Check if we have a valid cached token (with 40 second buffer since azure doesn't respect a token with less than 30s expiry) + if (self._cached_token and self._token_expires_at and + time.time() < self._token_expires_at - 40): + return self._cached_token + + # Get new token using client credentials flow + token_data = self._request_token() + + self._cached_token = token_data['access_token'] + # expires_in is in seconds, convert to absolute time + self._token_expires_at = time.time() + token_data.get('expires_in', 3600) + + return self._cached_token + + def _request_token(self) -> dict: + """Request a new token using OAuth client credentials flow.""" + data = { + 'grant_type': 'client_credentials', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'scope': self.AZURE_DATABRICKS_SCOPE, + } + + headers = {'Content-Type': 'application/x-www-form-urlencoded'} + + try: + response = requests.post(self.token_endpoint, data=data, headers=headers) + response.raise_for_status() + + token_data = response.json() + + if 'access_token' not in 
token_data: + raise ValueError("No access_token in response: {}".format(token_data)) + + return token_data + + except requests.exceptions.RequestException as e: + raise RuntimeError("Token request failed: {}".format(e)) from e + except ValueError as e: + raise RuntimeError("Invalid token response: {}".format(e)) from e + + class ExternalAuthProvider(AuthProvider, CredentialsProvider): def __init__(self, credentials_provider: CredentialsProvider) -> None: + self._credentials_provider = credentials_provider self._header_factory = credentials_provider() def add_headers(self, request_headers: Dict[str, str]): diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index d40ab62ed..75202a5dc 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -18,21 +18,21 @@ class DatabricksTokenFederationProvider(CredentialsProvider): """ - Implementation of the Credential Provider that exchanges a third party access token - for a Databricks token. - - This provider wraps an existing credentials provider and handles token exchange when - the token is from a different host than the Databricks host. It also manages token - refresh when tokens are expired. + Token federation provider that exchanges external tokens for Databricks tokens. + + This implementation follows the JDBC pattern: + 1. Try token exchange without HTTP Basic authentication (per RFC 8693) + 2. Fall back to using external token directly if exchange fails + 3. 
Compare token issuer with Databricks host to determine if exchange is needed """ - # HTTP request configuration + # HTTP request configuration (no authentication) EXCHANGE_HEADERS = { "Accept": "*/*", "Content-Type": "application/x-www-form-urlencoded", } - # Token exchange parameters + # Token exchange parameters following RFC 8693 TOKEN_EXCHANGE_PARAMS = { "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "scope": "sql", @@ -118,9 +118,9 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: Dict[str, Any]: Parsed JWT claims """ try: - return jwt.decode(token, options={"verify_signature": False}) + return jwt.decode(token, options={"verify_signature": False, "verify_aud": False}) except Exception as e: - logger.error(f"Failed to parse JWT: {str(e)}") + logger.debug("Failed to parse JWT: %s", str(e)) return {} def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]: @@ -138,14 +138,11 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]: # Look for standard JWT expiry claim ("exp") if "exp" in claims: try: - # JWT expiry is in seconds since epoch expiry_timestamp = int(claims["exp"]) - # Convert to datetime return datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc) except (ValueError, TypeError) as e: - logger.warning(f"Invalid JWT expiry value: {e}") + logger.warning("Invalid JWT expiry value: %s", e) - return None def refresh_token(self) -> Token: """ @@ -177,27 +174,18 @@ def refresh_token(self) -> Token: # Token is from the same host, no need to exchange logger.debug("Token from same host, creating token without exchange") expiry = self._get_expiry_from_jwt(access_token) - if expiry is None: - raise ValueError("Could not determine token expiry from JWT") new_token = Token(access_token, token_type, "", expiry) self.current_token = new_token return new_token else: - # Token is from a different host, need to exchange - logger.debug("Token from different host, exchanging token") + logger.debug("Token from 
different host, attempting token exchange") try: new_token = self._exchange_token(access_token) self.current_token = new_token return new_token except Exception as e: - logger.error( - f"Token exchange failed: {e}. Using external token as fallback." - ) + logger.debug("Token exchange failed: %s. Using external token as fallback.", e) expiry = self._get_expiry_from_jwt(access_token) - if expiry is None: - raise ValueError( - "Could not determine token expiry from JWT (after exchange failure)" - ) fallback_token = Token(access_token, token_type, "", expiry) self.current_token = fallback_token return fallback_token @@ -235,10 +223,9 @@ def get_auth_headers(self) -> Dict[str, str]: # Always get the latest headers from the credentials provider header_factory = self.credentials_provider() headers = dict(header_factory()) if header_factory else {} - headers["Authorization"] = f"{token.token_type} {token.access_token}" + headers["Authorization"] = "{} {}".format(token.token_type, token.access_token) return headers except Exception as e: - logger.error(f"Error getting auth headers: {str(e)}") return dict(self.external_headers) if self.external_headers else {} def _send_token_exchange_request( @@ -246,6 +233,9 @@ def _send_token_exchange_request( ) -> Dict[str, Any]: """ Send the token exchange request to the token endpoint. + + For M2M flows, this should include HTTP Basic authentication using client credentials. + For U2M flows, token exchange is validated purely based on the JWT token and federation policies. 
Args: token_exchange_data: Token exchange request data @@ -259,13 +249,24 @@ def _send_token_exchange_request( if not self.token_endpoint: raise ValueError("Token endpoint not initialized") + auth = None + if hasattr(self.credentials_provider, 'client_id') and hasattr(self.credentials_provider, 'client_secret'): + client_id = self.credentials_provider.client_id + client_secret = self.credentials_provider.client_secret + auth = (client_id, client_secret) + else: + logger.debug("No client credentials available, sending request without authentication") + response = requests.post( - self.token_endpoint, data=token_exchange_data, headers=self.EXCHANGE_HEADERS + self.token_endpoint, + data=token_exchange_data, + headers=self.EXCHANGE_HEADERS, + auth=auth ) if response.status_code != 200: raise requests.HTTPError( - f"Token exchange failed with status code {response.status_code}: {response.text}", + "Token exchange failed with status code {}: {}".format(response.status_code, response.text), response=response, ) @@ -288,15 +289,15 @@ def _exchange_token(self, access_token: str) -> Token: self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( self.hostname ) - # Prepare the request data + + # Prepare the request data according to RFC 8693 token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS) token_exchange_data["subject_token"] = access_token - # Add client_id if provided + # Add client_id if provided for federation policy identification if self.identity_federation_client_id: token_exchange_data["client_id"] = self.identity_federation_client_id - # Send the token exchange request resp_data = self._send_token_exchange_request(token_exchange_data) # Extract token information @@ -309,8 +310,6 @@ def _exchange_token(self, access_token: str) -> Token: # Extract expiry from JWT claims expiry = self._get_expiry_from_jwt(new_access_token) - if expiry is None: - raise ValueError("Unable to determine token expiry from JWT claims") return Token(new_access_token, token_type, 
refresh_token, expiry) @@ -320,32 +319,4 @@ def add_headers(self, request_headers: Dict[str, str]): """ headers = self.get_auth_headers() for k, v in headers.items(): - request_headers[k] = v - - -class SimpleCredentialsProvider(CredentialsProvider): - """A simple credentials provider that returns a fixed token.""" - - def __init__( - self, token: str, token_type: str = "Bearer", auth_type_value: str = "token" - ): - """ - Initialize a SimpleCredentialsProvider. - """ - self.token = token - self.token_type = token_type - self.auth_type_value = auth_type_value - - def auth_type(self) -> str: - """Return the auth type value.""" - return self.auth_type_value - - def __call__(self, *args, **kwargs) -> HeaderFactory: - """ - Return a HeaderFactory that provides a fixed token. - """ - - def get_headers() -> Dict[str, str]: - return {"Authorization": f"{self.token_type} {self.token}"} - - return get_headers + request_headers[k] = v \ No newline at end of file diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 2bb57645c..4a77aa98a 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -5,8 +5,7 @@ from databricks.sql.auth.token import Token from databricks.sql.auth.token_federation import ( - DatabricksTokenFederationProvider, - SimpleCredentialsProvider, + DatabricksTokenFederationProvider ) from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil @@ -56,21 +55,6 @@ def test_almost_expired_token_is_invalid(self): assert not token.is_valid() -class TestSimpleCredentialsProvider: - """Tests for the SimpleCredentialsProvider class.""" - - def test_provider_initialization_and_headers(self): - """Test SimpleCredentialsProvider initialization and header generation.""" - provider = SimpleCredentialsProvider("token1", "Bearer", "token") - - # Check auth type - assert provider.auth_type() == "token" - - # Check header generation - headers = provider()() - assert headers == {"Authorization": 
"Bearer token1"} - - class TestOIDCDiscoveryUtil: """Tests for the OIDCDiscoveryUtil class.""" From 19dc0b1948118048b91d0e9afeeaac52c9744137 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 28 May 2025 08:25:28 +0000 Subject: [PATCH 46/46] formatted --- src/databricks/sql/auth/auth.py | 34 +++++++---- src/databricks/sql/auth/authenticators.py | 67 ++++++++++++--------- src/databricks/sql/auth/token_federation.py | 41 ++++++++----- tests/unit/test_token_federation.py | 4 +- 4 files changed, 87 insertions(+), 59 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index d600f8f59..b151f3ca0 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -15,6 +15,7 @@ class AuthType(Enum): AZURE_OAUTH = "azure-oauth" CLIENT_CREDENTIALS = "client-credentials" + class ClientContext: def __init__( self, @@ -48,31 +49,42 @@ def __init__( self.oauth_client_secret = oauth_client_secret self.tenant_id = tenant_id -def _create_azure_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider: + +def _create_azure_client_credentials_provider( + cfg: ClientContext, +) -> ClientCredentialsProvider: """Create an Azure client credentials provider.""" if not cfg.oauth_client_id or not cfg.oauth_client_secret or not cfg.tenant_id: - raise ValueError("Azure client credentials flow requires oauth_client_id, oauth_client_secret, and tenant_id") - - token_endpoint = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(cfg.tenant_id) + raise ValueError( + "Azure client credentials flow requires oauth_client_id, oauth_client_secret, and tenant_id" + ) + + token_endpoint = "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format( + cfg.tenant_id + ) return ClientCredentialsProvider( client_id=cfg.oauth_client_id, client_secret=cfg.oauth_client_secret, token_endpoint=token_endpoint, - auth_type_value="azure-client-credentials" + auth_type_value="azure-client-credentials", ) -def 
_create_databricks_client_credentials_provider(cfg: ClientContext) -> ClientCredentialsProvider: +def _create_databricks_client_credentials_provider( + cfg: ClientContext, +) -> ClientCredentialsProvider: """Create a Databricks client credentials provider for service principals.""" if not cfg.oauth_client_id or not cfg.oauth_client_secret: - raise ValueError("Databricks client credentials flow requires oauth_client_id and oauth_client_secret") - + raise ValueError( + "Databricks client credentials flow requires oauth_client_id and oauth_client_secret" + ) + token_endpoint = "{}oidc/v1/token".format(cfg.hostname) return ClientCredentialsProvider( client_id=cfg.oauth_client_id, client_secret=cfg.oauth_client_secret, token_endpoint=token_endpoint, - auth_type_value="client-credentials" + auth_type_value="client-credentials", ) @@ -102,9 +114,9 @@ def get_auth_provider(cfg: ClientContext): RuntimeError: If no valid authentication settings are provided """ from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider - + base_provider = None - + if cfg.credentials_provider: base_provider = ExternalAuthProvider(cfg.credentials_provider) elif cfg.access_token is not None: diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 5f0d4ea01..a3befa4ba 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -11,7 +11,11 @@ # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. 
from databricks.sql.experimental.oauth_persistence import OAuthToken, OAuthPersistence -from databricks.sql.auth.endpoint import AzureOAuthEndpointCollection, InHouseOAuthEndpointCollection +from databricks.sql.auth.endpoint import ( + AzureOAuthEndpointCollection, + InHouseOAuthEndpointCollection, +) + class AuthProvider: def add_headers(self, request_headers: Dict[str, str]): @@ -56,8 +60,10 @@ def auth_type(self) -> str: def __call__(self, *args, **kwargs) -> HeaderFactory: def get_headers(): return {"Authorization": self.__authorization_header_value} + return get_headers + # Private API: this is an evolving interface and it will change in the future. # Please must not depend on it in your applications. class DatabricksOAuthProvider(AuthProvider, CredentialsProvider): @@ -81,11 +87,8 @@ def __init__( idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth") if not idp_endpoint: - raise NotImplementedError( - f"OAuth is not supported for host ${hostname}" - ) + raise NotImplementedError(f"OAuth is not supported for host ${hostname}") - cloud_scopes = idp_endpoint.get_scopes_mapping(scopes) self._scopes_as_str = self.SCOPE_DELIM.join(cloud_scopes) @@ -107,6 +110,7 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: def get_headers(): self._update_token_if_expired() return {"Authorization": "Bearer {}".format(self._access_token)} + return get_headers def _initial_get_token(self): @@ -170,14 +174,14 @@ def __init__( client_id: str, client_secret: str, token_endpoint: str, - auth_type_value: str = "client-credentials" + auth_type_value: str = "client-credentials", ): """ Initialize a ClientCredentialsProvider. 
- + Args: client_id: OAuth client ID - client_secret: OAuth client secret + client_secret: OAuth client secret token_endpoint: OAuth token endpoint URL auth_type_value: Auth type identifier """ @@ -185,10 +189,9 @@ def __init__( self.client_secret = client_secret self.token_endpoint = token_endpoint self.auth_type_value = auth_type_value - + self._cached_token = None self._token_expires_at = None - def auth_type(self) -> str: return self.auth_type_value @@ -197,8 +200,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: def get_headers() -> Dict[str, str]: token = self._get_access_token() return {"Authorization": "Bearer {}".format(token)} + return get_headers - + def add_headers(self, request_headers: Dict[str, str]): token = self._get_access_token() request_headers["Authorization"] = "Bearer {}".format(token) @@ -206,41 +210,44 @@ def add_headers(self, request_headers: Dict[str, str]): def _get_access_token(self) -> str: """Get a valid access token using client credentials flow, with caching.""" # Check if we have a valid cached token (with 40 second buffer since azure doesn't respect a token with less than 30s expiry) - if (self._cached_token and self._token_expires_at and - time.time() < self._token_expires_at - 40): + if ( + self._cached_token + and self._token_expires_at + and time.time() < self._token_expires_at - 40 + ): return self._cached_token - + # Get new token using client credentials flow token_data = self._request_token() - - self._cached_token = token_data['access_token'] + + self._cached_token = token_data["access_token"] # expires_in is in seconds, convert to absolute time - self._token_expires_at = time.time() + token_data.get('expires_in', 3600) - + self._token_expires_at = time.time() + token_data.get("expires_in", 3600) + return self._cached_token def _request_token(self) -> dict: """Request a new token using OAuth client credentials flow.""" data = { - 'grant_type': 'client_credentials', - 'client_id': self.client_id, - 
'client_secret': self.client_secret, - 'scope': self.AZURE_DATABRICKS_SCOPE, + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + "scope": self.AZURE_DATABRICKS_SCOPE, } - - headers = {'Content-Type': 'application/x-www-form-urlencoded'} - + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + try: response = requests.post(self.token_endpoint, data=data, headers=headers) response.raise_for_status() - + token_data = response.json() - - if 'access_token' not in token_data: + + if "access_token" not in token_data: raise ValueError("No access_token in response: {}".format(token_data)) - + return token_data - + except requests.exceptions.RequestException as e: raise RuntimeError("Token request failed: {}".format(e)) from e except ValueError as e: diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 75202a5dc..1be46cab4 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -19,7 +19,7 @@ class DatabricksTokenFederationProvider(CredentialsProvider): """ Token federation provider that exchanges external tokens for Databricks tokens. - + This implementation follows the JDBC pattern: 1. Try token exchange without HTTP Basic authentication (per RFC 8693) 2. 
Fall back to using external token directly if exchange fails @@ -118,7 +118,9 @@ def _parse_jwt_claims(self, token: str) -> Dict[str, Any]: Dict[str, Any]: Parsed JWT claims """ try: - return jwt.decode(token, options={"verify_signature": False, "verify_aud": False}) + return jwt.decode( + token, options={"verify_signature": False, "verify_aud": False} + ) except Exception as e: logger.debug("Failed to parse JWT: %s", str(e)) return {} @@ -143,7 +145,6 @@ def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]: except (ValueError, TypeError) as e: logger.warning("Invalid JWT expiry value: %s", e) - def refresh_token(self) -> Token: """ Refresh the token and return the new Token object. @@ -184,7 +185,9 @@ def refresh_token(self) -> Token: self.current_token = new_token return new_token except Exception as e: - logger.debug("Token exchange failed: %s. Using external token as fallback.", e) + logger.debug( + "Token exchange failed: %s. Using external token as fallback.", e + ) expiry = self._get_expiry_from_jwt(access_token) fallback_token = Token(access_token, token_type, "", expiry) self.current_token = fallback_token @@ -223,7 +226,9 @@ def get_auth_headers(self) -> Dict[str, str]: # Always get the latest headers from the credentials provider header_factory = self.credentials_provider() headers = dict(header_factory()) if header_factory else {} - headers["Authorization"] = "{} {}".format(token.token_type, token.access_token) + headers["Authorization"] = "{} {}".format( + token.token_type, token.access_token + ) return headers except Exception as e: return dict(self.external_headers) if self.external_headers else {} @@ -233,7 +238,7 @@ def _send_token_exchange_request( ) -> Dict[str, Any]: """ Send the token exchange request to the token endpoint. - + For M2M flows, this should include HTTP Basic authentication using client credentials. For U2M flows, token exchange is validated purely based on the JWT token and federation policies. 
@@ -250,23 +255,29 @@ def _send_token_exchange_request( raise ValueError("Token endpoint not initialized") auth = None - if hasattr(self.credentials_provider, 'client_id') and hasattr(self.credentials_provider, 'client_secret'): + if hasattr(self.credentials_provider, "client_id") and hasattr( + self.credentials_provider, "client_secret" + ): client_id = self.credentials_provider.client_id client_secret = self.credentials_provider.client_secret auth = (client_id, client_secret) else: - logger.debug("No client credentials available, sending request without authentication") - + logger.debug( + "No client credentials available, sending request without authentication" + ) + response = requests.post( - self.token_endpoint, - data=token_exchange_data, + self.token_endpoint, + data=token_exchange_data, headers=self.EXCHANGE_HEADERS, - auth=auth + auth=auth, ) if response.status_code != 200: raise requests.HTTPError( - "Token exchange failed with status code {}: {}".format(response.status_code, response.text), + "Token exchange failed with status code {}: {}".format( + response.status_code, response.text + ), response=response, ) @@ -289,7 +300,7 @@ def _exchange_token(self, access_token: str) -> Token: self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint( self.hostname ) - + # Prepare the request data according to RFC 8693 token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS) token_exchange_data["subject_token"] = access_token @@ -319,4 +330,4 @@ def add_headers(self, request_headers: Dict[str, str]): """ headers = self.get_auth_headers() for k, v in headers.items(): - request_headers[k] = v \ No newline at end of file + request_headers[k] = v diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 4a77aa98a..1d1395953 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -4,9 +4,7 @@ import jwt from databricks.sql.auth.token import Token -from databricks.sql.auth.token_federation 
import ( - DatabricksTokenFederationProvider -) +from databricks.sql.auth.token_federation import DatabricksTokenFederationProvider from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil