From 23813809efa0ddc86378d0e6970a69a864554573 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 1 Dec 2022 18:39:47 -0600 Subject: [PATCH 001/133] PYTHON-3460 Implement OIDC SASL mechanism wip oidc-sasl initial implementation wip add expires_in_seconds handling prep for refresh test finish refresh implementation clean up refresh handling support aws device workflow try oidc-enabled mongod syntax enforce crypt shared url set orchestration file debug debug syntax syntax syntax syntax syntax fix handling of orch file fix handling of orch file run test in eg add oidc test files update config fix config fix syntax fix syntax print token more debug try again try fixing log upload try fixing log upload more debug more debug try again try with rapid apply new link use rhel80-xlarge use rhel80-xlarge use rhel80-xlarge install lsof try with rhel7 install libcrypto try lsof again update for new api try with auth claim update for new oidc provider print config use string use hard-coded python remove virtualenv remove virtualenv manual venv another explicit python install deps fix install install virtualenv try to fix how virtualenv is used try to fix how virtualenv is used try to fix how virtualenv is used adjust issuer use new creds use more new settings try again fix audience use admin db try on test db again wip fix handling of client_resp refactor oidc handling clean up config --- .evergreen/config.yml | 122 ++++++++++++++++++++++++-- .evergreen/run-mongodb-oidc-test.sh | 72 ++++++++++++++++ pymongo/auth.py | 127 +++++++++++++++++++++++++++- pymongo/common.py | 20 +++++ pymongo/mongo_client.py | 1 + test/auth_aws/test_auth_oidc.py | 39 +++++++++ test_oidc.py | 119 ++++++++++++++++++++++++++ 7 files changed, 494 insertions(+), 6 deletions(-) create mode 100755 .evergreen/run-mongodb-oidc-test.sh create mode 100644 test/auth_aws/test_auth_oidc.py create mode 100644 test_oidc.py diff --git a/.evergreen/config.yml b/.evergreen/config.yml index ab61725a20..3f478ae6ec 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -75,7 +75,7 @@ functions: export TMPDIR="$MONGO_ORCHESTRATION_HOME/db" # Installation of cryptography requires a rust compiler on some machines - export PATH="$MONGODB_BINARIES:/home/admin/.cargo/bin:$PATH" + export PATH="/opt/mongodbtoolchain/v4/bin/:$MONGODB_BINARIES:/home/admin/.cargo/bin:$PATH" export PROJECT="${project}" EOT # See what we've done @@ -97,7 +97,7 @@ functions: # If this was a patch build, doing a fresh clone would not actually test the patch cp -R ${PROJECT_DIRECTORY}/ $DRIVERS_TOOLS else - git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git $DRIVERS_TOOLS + git clone --branch DRIVERS-2508 https://github.com/blink1073/drivers-evergreen-tools.git $DRIVERS_TOOLS fi echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config @@ -165,6 +165,7 @@ functions: script: | set -o xtrace ${PREPARE_SHELL} + cat ${DRIVERS_TOOLS}/server.log find $MONGO_ORCHESTRATION_HOME -name \*.log | xargs tar czf mongodb-logs.tar.gz - command: archive.targz_pack params: @@ -284,8 +285,14 @@ functions: fi ${PREPARE_SHELL} + + # TODO: remove path hack and mongodb download url + /opt/mongodbtoolchain/v4/bin/python3 -m pip install virtualenv + export PATH=/opt/mongodbtoolchain/v4/bin/:$PATH + MONGODB_VERSION=${VERSION} \ TOPOLOGY=${TOPOLOGY} \ + MONGODB_DOWNLOAD_URL=https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/linux-x86-dynamic-compile-required/d2bb64fbd29269667d665c1f09066be0725b1d78/dist/mongo-mongodb_mongo_master_linux_x86_dynamic_compile_required_d2bb64fbd29269667d665c1f09066be0725b1d78_23_01_03_16_35_49.tgz \ AUTH=${AUTH} \ SSL=${SSL} \ STORAGE_ENGINE=${STORAGE_ENGINE} \ @@ -294,6 +301,7 @@ functions: REQUIRE_API_VERSION=${REQUIRE_API_VERSION} \ LOAD_BALANCER=${LOAD_BALANCER} \ bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh + # run-orchestration generates expansion file with the MONGODB_URI for the cluster - command: expansions.update params: @@ -566,15 +574,12 @@ functions: "iam_auth_ecs_subnet_a": "${iam_auth_ecs_subnet_a}", "iam_auth_ecs_subnet_b": "${iam_auth_ecs_subnet_b}", "iam_auth_ecs_security_group": "${iam_auth_ecs_security_group}", - "iam_auth_assume_aws_account" : "${iam_auth_assume_aws_account}", "iam_auth_assume_aws_secret_access_key" : "${iam_auth_assume_aws_secret_access_key}", "iam_auth_assume_role_name" : "${iam_auth_assume_role_name}", - "iam_auth_ec2_instance_account" : "${iam_auth_ec2_instance_account}", "iam_auth_ec2_instance_secret_access_key" : "${iam_auth_ec2_instance_secret_access_key}", "iam_auth_ec2_instance_profile" : "${iam_auth_ec2_instance_profile}", - "iam_auth_assume_web_role_name": "${iam_auth_assume_web_role_name}", "iam_web_identity_issuer": "${iam_web_identity_issuer}", "iam_web_identity_rsa_key": "${iam_web_identity_rsa_key}", @@ -738,6 +743,70 @@ functions: fi PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-aws-test.sh + "bootstrap oidc": + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + - command: shell.exec + type: test + params: + working_dir: "src" + script: | + ${PREPARE_SHELL} + if [ "${skip_EC2_auth_test}" = "true" ]; then + echo "This platform does not support the oidc auth test, skipping..." + exit 0 + fi + + cd ${DRIVERS_TOOLS}/.evergreen/auth_aws + export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} + export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} + export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} + export AWS_TOKEN_DIR=/tmp/tokens + + . ./activate_venv.sh + python lib/aws_oidc_bootstrap.py + + # TODO: remove once we can target Ubuntu build + sudo yum install -y lsof + + "run oidc auth test with aws credentials": + - command: shell.exec + type: test + params: + working_dir: "src" + script: | + ${PREPARE_SHELL} + if [ "${skip_EC2_auth_test}" = "true" ]; then + echo "This platform does not support the oidc auth test, skipping..." + exit 0 + fi + cd ${DRIVERS_TOOLS}/.evergreen/auth_aws + . ./activate_venv.sh + mongo aws_setup_oidc.js + - command: shell.exec + type: test + params: + working_dir: "src" + silent: true + script: | + # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) + cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" + export AWS_WEB_IDENTITY_TOKEN_FILE=${AWS_WEB_IDENTITY_TOKEN_FILE|/tmp/tokens/test1} + export USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS|false} + EOF + - command: shell.exec + type: test + params: + working_dir: "src" + script: | + ${PREPARE_SHELL} + if [ "${skip_web_identity_auth_test}" = "true" ]; then + echo "This platform does not support the oidc auth test, skipping..." + exit 0 + fi + PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS|false} .evergreen/run-mongodb-oidc-test.sh + "run aws auth test with aws credentials as environment variables": - command: shell.exec type: test @@ -1956,6 +2025,7 @@ tasks: - func: "run aws auth test with aws EC2 credentials" - func: "run aws auth test with aws web identity credentials" - func: "run aws ECS auth test" + - name: "aws-auth-test-rapid" commands: - func: "bootstrap mongo-orchestration" @@ -1973,6 +2043,33 @@ tasks: - func: "run aws auth test with aws web identity credentials" - func: "run aws ECS auth test" + - name: "oidc-auth-test-single-latest" + commands: + - func: "bootstrap oidc" + - func: "bootstrap mongo-orchestration" + vars: + AUTH: "auth" + ORCHESTRATION_FILE: "auth-oidc.json" + TOPOLOGY: "server" + VERSION: "latest" + - func: "run oidc auth test with aws credentials" + + - name: "oidc-auth-test-multiple-latest" + commands: + - func: "bootstrap oidc" + vars: + USE_MULTIPLE_PRINCIPALS: "true" + - func: "bootstrap mongo-orchestration" + vars: + AUTH: "auth" + ORCHESTRATION_FILE: "auth-oidc.json" + TOPOLOGY: "server" + VERSION: "latest" + - func: "run oidc auth test with aws credentials" + vars: + USE_MULTIPLE_PRINCIPALS: "true" + AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test2 + - name: load-balancer-test commands: - func: "bootstrap mongo-orchestration" @@ -2179,6 +2276,12 @@ axes: batchtime: 10080 # 7 days variables: libmongocrypt_url: https://s3.amazonaws.com/mciuploads/libmongocrypt/rhel-70-64-bit/master/latest/libmongocrypt.tar.gz + - id: rhel80-xlarge + display_name: "RHEL 8.0" + run_on: rhel80-xlarge + batchtime: 10080 # 7 days + variables: + libmongocrypt_url: https://s3.amazonaws.com/mciuploads/libmongocrypt/rhel-70-64-bit/master/latest/libmongocrypt.tar.gz - id: rhel70-fips display_name: "RHEL 7.0 FIPS" run_on: rhel70-fips @@ -3019,6 +3122,15 @@ buildvariants: # macOS MongoDB servers do not staple OCSP responses and only support RSA. - name: ".ocsp-rsa !.ocsp-staple" +- matrix_name: "oidc-auth-test" + matrix_spec: + platform: [ rhel80-xlarge ] + python-version: ["3.7"] + display_name: "MONGODB-OIDC Auth ${platform} ${python-version}" + tasks: + - name: "oidc-auth-test-single-latest" + - name: "oidc-auth-test-multiple-latest" + - matrix_name: "aws-auth-test" matrix_spec: platform: [ubuntu-18.04] diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh new file mode 100755 index 0000000000..b42cc1d9e3 --- /dev/null +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# MONGODB_URI Set the URI, including an optional username/password to use +# to connect to the server via MONGODB-AWS authentication +# mechanism. +# PYTHON_BINARY The Python version to use. + +echo "Running MONGODB-OIDC authentication tests" +# ensure no secrets are printed in log files +set +x + +# load the script +shopt -s expand_aliases # needed for `urlencode` alias +[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" + +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} +MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" + +export MONGODB_URI="$MONGODB_URI" + +if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then + if echo "$MONGODB_URI" | grep -q "@"; then + echo "MONGODB_URI unexpectedly contains user credentials!"; + exit 1 + fi +fi + +# show test output +set -x + +# Workaround macOS python 3.9 incompatibility with system virtualenv. +if [ "$(uname -s)" = "Darwin" ]; then + VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 -m virtualenv" +else + VIRTUALENV=$(command -v virtualenv) +fi + +authtest () { + if [ "Windows_NT" = "$OS" ]; then + PYTHON=$(cygpath -m $PYTHON) + fi + + echo "Running MONGODB-AWS authentication tests with $PYTHON" + $PYTHON --version + + $VIRTUALENV -p $PYTHON --never-download venvoidc + if [ "Windows_NT" = "$OS" ]; then + . venvoidc/Scripts/activate + else + . venvoidc/bin/activate + fi + python -m pip install '.[aws]' + python test/auth_aws/test_auth_oidc.py -v + deactivate + rm -rf venvoidc +} + +PYTHON=${PYTHON_BINARY:-} +if [ -z "$PYTHON" ]; then + echo "Cannot test without specifying PYTHON_BINARY" + exit 1 +fi + +authtest diff --git a/pymongo/auth.py b/pymongo/auth.py index 3d259335b0..fccd7cd6b9 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -21,9 +21,11 @@ import socket from base64 import standard_b64decode, standard_b64encode from collections import namedtuple +from datetime import datetime, timedelta, timezone from typing import Callable, Mapping from urllib.parse import quote +import bson from bson.binary import Binary from bson.son import SON from pymongo.auth_aws import _authenticate_aws @@ -48,6 +50,7 @@ [ "GSSAPI", "MONGODB-CR", + "MONGODB-OIDC", "MONGODB-X509", "MONGODB-AWS", "PLAIN", @@ -99,9 +102,15 @@ def __hash__(self): """Mechanism properties for MONGODB-AWS authentication.""" +_OIDCProperties = namedtuple( + "_OIDCProperties", ["on_oidc_request_token", "on_oidc_refresh_token", "principal_name"] +) +"""Mechanism properties for MONGODB-OIDC authentication.""" + + def _build_credentials_tuple(mech, source, user, passwd, extra, database): """Build and return a mechanism specific credentials tuple.""" - if mech not in ("MONGODB-X509", "MONGODB-AWS") and user is None: + if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: raise ConfigurationError("%s requires a username." % (mech,)) if mech == "GSSAPI": if source is not None and source != "$external": @@ -137,6 +146,20 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): aws_props = _AWSProperties(aws_session_token=aws_session_token) # user can be None for temporary link-local EC2 credentials. return MongoCredential(mech, "$external", user, passwd, aws_props, None) + elif mech == "MONGODB-OIDC": + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for MONGODB-ODIC") + properties = extra.get("authmechanismproperties", {}) + on_oidc_request_token = properties.get("on_oidc_request_token") + on_oidc_refresh_token = properties.get("on_oidc_refresh_token", on_oidc_request_token) + principal_name = properties.get("principal_name", "") + oidc_props = _OIDCProperties( + on_oidc_request_token=on_oidc_request_token, + on_oidc_refresh_token=on_oidc_refresh_token, + principal_name=principal_name, + ) + return MongoCredential(mech, "$external", user, passwd, oidc_props, None) + elif mech == "PLAIN": source_database = source or database or "$external" return MongoCredential(mech, source_database, user, passwd, None, None) @@ -458,6 +481,107 @@ def _authenticate_mongo_cr(credentials, sock_info): sock_info.command(source, query) +""" +interface OIDCRequestTokenParams { + authorizeEndpoint?: string; + tokenEndpoint?: string; + deviceAuthorizeEndpoint?: string; + clientId: string; + clientSecret?: string; + requestScopes?: string[]; +} + +interface OIDCRequestTokenResult { + accessToken: string + expiresInSeconds?: number + refreshToken?: string +} +""" + +_oidc_auth_cache = {} +_oidc_exp_utc = {} +# TODO: Offer another parameter that is the refresh buffer? +# TOOD: Make a dataclass for the client resp and the internal storage +_oidc_buffer_seconds = 5 * 60 + + +def _authenticate_oidc(credentials, sock_info): + """Authenticate using MONGODB-OIDC.""" + properties: _OIDCProperties = credentials.mechanism_properties + + # Send the SASL start with the optional principal name. + payload = dict() + principal_name = properties.principal_name + if principal_name: + payload["n"] = principal_name + + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), + ] + ) + response = sock_info.command("$external", cmd) + server_payload = bson.decode(response["payload"]) + client_resp = None + token = None + + if principal_name in _oidc_auth_cache: + client_resp = _oidc_auth_cache[principal_name] + now_utc = datetime.now(timezone.utc) + exp_utc = _oidc_exp_utc[principal_name] + if (exp_utc - now_utc).total_seconds() <= _oidc_buffer_seconds: + del _oidc_auth_cache[principal_name] + if properties.on_oidc_refresh_token: + client_resp = properties.on_oidc_refresh_token(server_payload, auth) + else: + client_resp = None + + if client_resp is None and properties.on_oidc_request_token is not None: + if principal_name in _oidc_auth_cache: + auth = _oidc_auth_cache[principal_name] + token = auth["access_token"] + else: + client_resp = properties.on_oidc_request_token(server_payload) + + if client_resp is not None: + token = client_resp["access_token"] + if "expires_in_seconds" in client_resp: + expires_in = client_resp["expires_in_seconds"] + if expires_in >= _oidc_buffer_seconds: + now_utc = datetime.now(timezone.utc) + exp_utc = now_utc + timedelta(seconds=expires_in) + _oidc_exp_utc[principal_name] = exp_utc + _oidc_auth_cache[principal_name] = client_resp.copy() + + else: + aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] + with open(aws_identity_file) as fid: + token = fid.read().strip() + + payload = dict(jwt=token) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", response["conversationId"]), + ("payload", Binary(bson.encode(payload))), + ] + ) + + try: + response = sock_info.command("$external", cmd) + except Exception: + if principal_name in _oidc_auth_cache: + del _oidc_auth_cache[principal_name] + raise + + if not response["done"]: + del _oidc_auth_cache[principal_name] + raise OperationFailure("SASL conversation failed to complete.") + + def _authenticate_default(credentials, sock_info): if sock_info.max_wire_version >= 7: if sock_info.negotiated_mechs: @@ -482,6 +606,7 @@ def _authenticate_default(credentials, sock_info): "MONGODB-CR": _authenticate_mongo_cr, "MONGODB-X509": _authenticate_x509, "MONGODB-AWS": _authenticate_aws, + "MONGODB-OIDC": _authenticate_oidc, "PLAIN": _authenticate_plain, "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"), "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"), diff --git a/pymongo/common.py b/pymongo/common.py index add70cfb5f..162a27631e 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -16,6 +16,7 @@ """Functions and classes common to multiple pymongo modules.""" import datetime +import inspect import warnings from collections import OrderedDict, abc from typing import ( @@ -422,6 +423,25 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]: """Validate authMechanismProperties.""" + if not isinstance(value, str): + if not isinstance(value, dict): + raise ValueError("Auth mechansim properties must be given as a string or a dictionary") + props = {} + for key, value in value.items(): + if isinstance(value, str): + props[key] = value + elif inspect.isfunction(value): + signature = inspect.signature(value) + if len(signature.parameters) == 0: + msg = "Auth mechanisim properity callbacks must accept at least one value" + raise ValueError(msg) + props[key] = value + else: + raise ValueError( + "Auth mechanisim properity values must be strings or callback functions" + ) + return props + value = validate_string(option, value) props: Dict[str, Any] = {} for opt in value.split(","): diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index dccd4bb6b1..f65797bb66 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -38,6 +38,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, FrozenSet, Generic, diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py new file mode 100644 index 0000000000..d6b11eac73 --- /dev/null +++ b/test/auth_aws/test_auth_oidc.py @@ -0,0 +1,39 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test MONGODB-OIDC Authentication.""" + +import os +import sys +import unittest + +sys.path[0:0] = [""] + +from pymongo import MongoClient + + +class TestAuthOIDC(unittest.TestCase): + uri: str + + @classmethod + def setUpClass(cls): + cls.uri = os.environ["MONGODB_URI"] + + def test_connect_environment_var(self): + with MongoClient(self.uri) as client: + client.get_database().test.find_one() + + +if __name__ == "__main__": + unittest.main() diff --git a/test_oidc.py b/test_oidc.py new file mode 100644 index 0000000000..30f76af4ef --- /dev/null +++ b/test_oidc.py @@ -0,0 +1,119 @@ +import os +import threading +import time +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from queue import Queue + +from requests_oauth2client import AuthorizationRequest, OAuth2Client + +from pymongo import MongoClient + +client_secret = os.getenv("IDP_CLIENT_SECRET") + +auth_data = dict( + authorizeEndpoint="https://corp.mongodb.com/oauth2/v1/authorize", + tokenEndpoint="https://corp.mongodb.com/oauth2/v1/token", + issuer="https://corp.mongodb.com", + clientId="0oadp0hpl7q3UIehP297", + clientSecret=client_secret, +) + + +LOCAL_PORT = 8888 +REDIRECT_URI = f"http://localhost:{LOCAL_PORT}/authorization-code/callback" +RESPONSE_QUEUE = Queue() +INIT_CALLED = 0 +REFRESH_CALLED = 0 + + +class MyRequestHandler(BaseHTTPRequestHandler): + def do_GET(self): + RESPONSE_QUEUE.put(self.path) + self.send_response(200) + + +def run_server(): + server = HTTPServer(("localhost", LOCAL_PORT), MyRequestHandler) + try: + server.serve_forever() + except KeyboardInterrupt: + pass + + server.server_close() + + +# Start a server on 8888 and expose a callback endpoint +# the tunnel address will be 8889 + + +def get_auth_token(auth_data): + print("Getting auth token") + global INIT_CALLED + INIT_CALLED += 1 + client_id = auth_data["clientId"] + client_secret = auth_data["clientSecret"] + token_endpoint = auth_data["tokenEndpoint"] + authorization_endpoint = auth_data["authorizeEndpoint"] + request = AuthorizationRequest( + authorization_endpoint, + client_id, + scope="openid", + redirect_uri=REDIRECT_URI, + code_challenge_method="S256", + ) + webbrowser.open(str(request)) + response_uri = RESPONSE_QUEUE.get() + response = request.validate_callback(response_uri) + client = OAuth2Client(token_endpoint, auth=(client_id, client_secret)) + token_response = client.token_request( + { + "grant_type": "authorization_code", + "code": response.code, + "redirect_uri": REDIRECT_URI, + "code_verifier": response.code_verifier, + } + ) + print("token:") + print(str(token_response.id_token)) + return dict(access_token=str(token_response.id_token), expires_in_seconds=5 * 60 + 3) + + +def refresh_auth_token(auth_data, orig_data): + global REFRESH_CALLED + REFRESH_CALLED += 1 + print("Refreshing auth token") + access_token = orig_data["access_token"] + return dict(access_token=access_token, expires_in_seconds=10 * 60) + + +thread = threading.Thread(target=run_server, daemon=True) +thread.start() + +# print(get_auth_token(auth_data)) + +# AWS device workflow test. +if "AWS_WEB_IDENTITY_TOKEN_FILE" in os.environ: + props = dict() + client = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") + print(client.test.command("ping")) + +# Browser workflow test. +else: + # Test token expiration and refresh + props = dict(on_oidc_request_token=get_auth_token, on_oidc_refresh_token=refresh_auth_token) + client = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") + print(client.test.command("ping")) + assert INIT_CALLED == 1 + print("Sleeping...") + time.sleep(4) + client2 = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") + print(client2.test.command("ping")) + assert INIT_CALLED == 1 + assert REFRESH_CALLED == 1 + print("Sleeping...") + time.sleep(2) + client3 = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") + print(client3.test.command("ping")) + assert INIT_CALLED == 1 + assert REFRESH_CALLED == 1 From 5048aea437c83c2947e09bfc705257f7a2f1d8f1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 11 Jan 2023 16:42:40 -0600 Subject: [PATCH 002/133] fix log cat --- .evergreen/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 3f478ae6ec..1c6b9e6afd 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -165,7 +165,7 @@ functions: script: | set -o xtrace ${PREPARE_SHELL} - cat ${DRIVERS_TOOLS}/server.log + find $MONGO_ORCHESTRATION_HOME -name \*.log | xargs cat find $MONGO_ORCHESTRATION_HOME -name \*.log | xargs tar czf mongodb-logs.tar.gz - command: archive.targz_pack params: From 17972ebfb31a29fe282a18a8934bdc2f40f02c08 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 11 Jan 2023 17:02:26 -0600 Subject: [PATCH 003/133] print tokens --- .evergreen/config.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 1c6b9e6afd..bfd8dcad6b 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -770,6 +770,10 @@ functions: # TODO: remove once we can target Ubuntu build sudo yum install -y lsof + echo "Got tokens!" + cat /tmp/tokens/test1 + cat /tmp/tokens/test2 + "run oidc auth test with aws credentials": - command: shell.exec type: test From 0380ded91a755c313ca6b7a1900cd5209284057d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 11 Jan 2023 17:16:33 -0600 Subject: [PATCH 004/133] try again --- .evergreen/config.yml | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index bfd8dcad6b..5fff136314 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -165,7 +165,6 @@ functions: script: | set -o xtrace ${PREPARE_SHELL} - find $MONGO_ORCHESTRATION_HOME -name \*.log | xargs cat find $MONGO_ORCHESTRATION_HOME -name \*.log | xargs tar czf mongodb-logs.tar.gz - command: archive.targz_pack params: @@ -195,16 +194,16 @@ functions: permissions: public-read content_type: ${content_type|application/x-gzip} display_name: "mongodb-logs.tar.gz" - - command: s3.put - params: - aws_key: ${aws_key} - aws_secret: ${aws_secret} - local_file: drivers-tools/.evergreen/orchestration/server.log - remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log - bucket: mciuploads - permissions: public-read - content_type: ${content_type|text/plain} - display_name: "orchestration.log" + # - command: s3.put + # params: + # aws_key: ${aws_key} + # aws_secret: ${aws_secret} + # local_file: drivers-tools/.evergreen/orchestration/server.log + # remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log + # bucket: mciuploads + # permissions: public-read + # content_type: ${content_type|text/plain} + # display_name: "orchestration.log" "upload working dir": - command: archive.targz_pack @@ -770,10 +769,6 @@ functions: # TODO: remove once we can target Ubuntu build sudo yum install -y lsof - echo "Got tokens!" - cat /tmp/tokens/test1 - cat /tmp/tokens/test2 - "run oidc auth test with aws credentials": - command: shell.exec type: test From 950df799b519134076b33752868361175fc04e48 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 17 Jan 2023 09:45:27 -0600 Subject: [PATCH 005/133] undo comments --- .evergreen/config.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 5fff136314..b4e50c9aca 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -194,16 +194,16 @@ functions: permissions: public-read content_type: ${content_type|application/x-gzip} display_name: "mongodb-logs.tar.gz" - # - command: s3.put - # params: - # aws_key: ${aws_key} - # aws_secret: ${aws_secret} - # local_file: drivers-tools/.evergreen/orchestration/server.log - # remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log - # bucket: mciuploads - # permissions: public-read - # content_type: ${content_type|text/plain} - # display_name: "orchestration.log" + - command: s3.put + params: + aws_key: ${aws_key} + aws_secret: ${aws_secret} + local_file: drivers-tools/.evergreen/orchestration/server.log + remote_file: ${UPLOAD_BUCKET}/${build_variant}/${revision}/${version_id}/${build_id}/logs/${task_id}-${execution}-orchestration.log + bucket: mciuploads + permissions: public-read + content_type: ${content_type|text/plain} + display_name: "orchestration.log" "upload working dir": - command: archive.targz_pack From ab10e68969f67511c722deef36a5f55f8fe49df0 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 19 Jan 2023 16:31:30 -0600 Subject: [PATCH 006/133] test multiple principals --- .evergreen/config.yml | 9 +++++++-- .evergreen/run-mongodb-oidc-test.sh | 6 +++++- pymongo/auth.py | 26 ++++++++++++++++++++++++++ test/auth_aws/test_auth_oidc.py | 2 +- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index b4e50c9aca..c30a39ef94 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -792,7 +792,7 @@ functions: # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" export AWS_WEB_IDENTITY_TOKEN_FILE=${AWS_WEB_IDENTITY_TOKEN_FILE|/tmp/tokens/test1} - export USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS|false} + export USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} EOF - command: shell.exec type: test @@ -804,7 +804,7 @@ functions: echo "This platform does not support the oidc auth test, skipping..." exit 0 fi - PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS|false} .evergreen/run-mongodb-oidc-test.sh + PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} .evergreen/run-mongodb-oidc-test.sh "run aws auth test with aws credentials as environment variables": - command: shell.exec @@ -2045,6 +2045,8 @@ tasks: - name: "oidc-auth-test-single-latest" commands: - func: "bootstrap oidc" + vars: + USE_MULTIPLE_PRINCIPALS: "false" - func: "bootstrap mongo-orchestration" vars: AUTH: "auth" @@ -2052,6 +2054,9 @@ tasks: TOPOLOGY: "server" VERSION: "latest" - func: "run oidc auth test with aws credentials" + vars: + USE_MULTIPLE_PRINCIPALS: "true" + AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test2 - name: "oidc-auth-test-multiple-latest" commands: diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index b42cc1d9e3..46fee17c9d 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -22,7 +22,11 @@ shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} -MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" +MONGODB_URI="${MONGODB_URI}/test/?authMechanism=MONGODB-OIDC" + +if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then + MONGODB_URI="${MONGODB_URI}&authMechanismProperties=PRINCIPAL_NAME=717cc021e105be9843cd2005e5a4607beae5a4960ef8098cb1247481626090f8" +fi export MONGODB_URI="$MONGODB_URI" diff --git a/pymongo/auth.py b/pymongo/auth.py index fccd7cd6b9..4f8efce595 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -512,9 +512,34 @@ def _authenticate_oidc(credentials, sock_info): # Send the SASL start with the optional principal name. payload = dict() principal_name = properties.principal_name + if principal_name: payload["n"] = principal_name + if principal_name in _oidc_auth_cache: + auth = _oidc_auth_cache[principal_name] + payload = dict(jwt=auth["access_token"]) + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), + ] + ) + + try: + response = sock_info.command("$external", cmd) + except Exception: + if principal_name in _oidc_auth_cache: + del _oidc_auth_cache[principal_name] + raise + + if not response["done"]: + del _oidc_auth_cache[principal_name] + raise OperationFailure("SASL conversation failed to complete.") + return + cmd = SON( [ ("saslStart", 1), @@ -560,6 +585,7 @@ def _authenticate_oidc(credentials, sock_info): aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: token = fid.read().strip() + _oidc_auth_cache[principal_name] = dict(access_token=token) payload = dict(jwt=token) cmd = SON( diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index d6b11eac73..81cac09004 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -32,7 +32,7 @@ def setUpClass(cls): def test_connect_environment_var(self): with MongoClient(self.uri) as client: - client.get_database().test.find_one() + client.test.test.find_one() if __name__ == "__main__": From e9a8992293a4beb0fd8028feacbcaabb1087d265 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 19 Jan 2023 16:40:17 -0600 Subject: [PATCH 007/133] fix handling of principal name --- .evergreen/run-mongodb-oidc-test.sh | 6 ++++-- pymongo/auth.py | 2 +- pymongo/common.py | 9 ++++++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 46fee17c9d..17eda55d08 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -22,14 +22,16 @@ shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} -MONGODB_URI="${MONGODB_URI}/test/?authMechanism=MONGODB-OIDC" +MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then - MONGODB_URI="${MONGODB_URI}&authMechanismProperties=PRINCIPAL_NAME=717cc021e105be9843cd2005e5a4607beae5a4960ef8098cb1247481626090f8" + MONGODB_URI="${MONGODB_URI}&authMechanismProperties=PRINCIPAL_NAME:717cc021e105be9843cd2005e5a4607beae5a4960ef8098cb1247481626090f8" fi export MONGODB_URI="$MONGODB_URI" +echo $MONGODB_URI + if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then if echo "$MONGODB_URI" | grep -q "@"; then echo "MONGODB_URI unexpectedly contains user credentials!"; diff --git a/pymongo/auth.py b/pymongo/auth.py index 4f8efce595..a72af7db05 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -152,7 +152,7 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): properties = extra.get("authmechanismproperties", {}) on_oidc_request_token = properties.get("on_oidc_request_token") on_oidc_refresh_token = properties.get("on_oidc_refresh_token", on_oidc_request_token) - principal_name = properties.get("principal_name", "") + principal_name = properties.get("PRINCIPAL_NAME", "") oidc_props = _OIDCProperties( on_oidc_request_token=on_oidc_request_token, on_oidc_refresh_token=on_oidc_refresh_token, diff --git a/pymongo/common.py b/pymongo/common.py index 162a27631e..04c4117432 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -417,7 +417,14 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] _MECHANISM_PROPS = frozenset( - ["SERVICE_NAME", "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN"] + [ + "SERVICE_NAME", + "CANONICALIZE_HOST_NAME", + "SERVICE_REALM", + "AWS_SESSION_TOKEN", + "PRINCIPAL_NAME", + "DEVICE_NAME", + ] ) From 883d4273e8a67c1733847be9cda3c08909643ceb Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 19 Jan 2023 17:40:37 -0600 Subject: [PATCH 008/133] try different token --- .evergreen/config.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index c30a39ef94..60b1468b77 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2055,8 +2055,8 @@ tasks: VERSION: "latest" - func: "run oidc auth test with aws credentials" vars: - USE_MULTIPLE_PRINCIPALS: "true" - AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test2 + USE_MULTIPLE_PRINCIPALS: "false" + AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1 - name: "oidc-auth-test-multiple-latest" commands: @@ -2072,7 +2072,7 @@ tasks: - func: "run oidc auth test with aws credentials" vars: USE_MULTIPLE_PRINCIPALS: "true" - AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test2 + AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1 - name: load-balancer-test commands: From ccf5a3f6029488e78bf329911e3b36a203fba5d8 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 20 Jan 2023 12:16:30 -0600 Subject: [PATCH 009/133] test with auth_oidc scripts --- .evergreen/config.yml | 27 ++++++++++----------------- .evergreen/run-mongodb-oidc-test.sh | 8 ++++---- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 60b1468b77..e23becdf02 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -97,7 +97,8 @@ functions: # If this was a patch build, doing a fresh clone would not actually test the patch cp -R ${PROJECT_DIRECTORY}/ $DRIVERS_TOOLS else - git clone --branch DRIVERS-2508 https://github.com/blink1073/drivers-evergreen-tools.git $DRIVERS_TOOLS + # TODO: switch back to master before merging + git clone --branch DRIVERS-2415 https://github.com/blink1073/drivers-evergreen-tools.git $DRIVERS_TOOLS fi echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config @@ -285,13 +286,8 @@ functions: ${PREPARE_SHELL} - # TODO: remove path hack and mongodb download url - /opt/mongodbtoolchain/v4/bin/python3 -m pip install virtualenv - export PATH=/opt/mongodbtoolchain/v4/bin/:$PATH - MONGODB_VERSION=${VERSION} \ TOPOLOGY=${TOPOLOGY} \ - MONGODB_DOWNLOAD_URL=https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/linux-x86-dynamic-compile-required/d2bb64fbd29269667d665c1f09066be0725b1d78/dist/mongo-mongodb_mongo_master_linux_x86_dynamic_compile_required_d2bb64fbd29269667d665c1f09066be0725b1d78_23_01_03_16_35_49.tgz \ AUTH=${AUTH} \ SSL=${SSL} \ STORAGE_ENGINE=${STORAGE_ENGINE} \ @@ -757,17 +753,14 @@ functions: exit 0 fi - cd ${DRIVERS_TOOLS}/.evergreen/auth_aws + cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} export AWS_TOKEN_DIR=/tmp/tokens . ./activate_venv.sh - python lib/aws_oidc_bootstrap.py - - # TODO: remove once we can target Ubuntu build - sudo yum install -y lsof + python oidc_bootstrap.py "run oidc auth test with aws credentials": - command: shell.exec @@ -780,9 +773,9 @@ functions: echo "This platform does not support the oidc auth test, skipping..." exit 0 fi - cd ${DRIVERS_TOOLS}/.evergreen/auth_aws + cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc . ./activate_venv.sh - mongo aws_setup_oidc.js + mongo setup_oidc.js - command: shell.exec type: test params: @@ -790,7 +783,7 @@ functions: silent: true script: | # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) - cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" + cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" export AWS_WEB_IDENTITY_TOKEN_FILE=${AWS_WEB_IDENTITY_TOKEN_FILE|/tmp/tokens/test1} export USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} EOF @@ -804,7 +797,7 @@ functions: echo "This platform does not support the oidc auth test, skipping..." exit 0 fi - PYTHON_BINARY=/opt/mongodbtoolchain/v4/bin/python3 ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} .evergreen/run-mongodb-oidc-test.sh + ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} .evergreen/run-mongodb-oidc-test.sh "run aws auth test with aws credentials as environment variables": - command: shell.exec @@ -3128,8 +3121,8 @@ buildvariants: - matrix_name: "oidc-auth-test" matrix_spec: - platform: [ rhel80-xlarge ] - python-version: ["3.7"] + platform: [ ubuntu-20.04 ] + python-version: ["3.9"] display_name: "MONGODB-OIDC Auth ${platform} ${python-version}" tasks: - name: "oidc-auth-test-single-latest" diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 17eda55d08..861019f11f 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -9,7 +9,7 @@ set -o errexit # Exit the script with error if any of the commands fail # Supported/used environment variables: # MONGODB_URI Set the URI, including an optional username/password to use -# to connect to the server via MONGODB-AWS authentication +# to connect to the server via MONGODB-OIDC authentication # mechanism. # PYTHON_BINARY The Python version to use. @@ -19,7 +19,7 @@ set +x # load the script shopt -s expand_aliases # needed for `urlencode` alias -[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_aws.sh" +[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" @@ -44,7 +44,7 @@ set -x # Workaround macOS python 3.9 incompatibility with system virtualenv. if [ "$(uname -s)" = "Darwin" ]; then - VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 -m virtualenv" + VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.10/bin/python3 -m virtualenv" else VIRTUALENV=$(command -v virtualenv) fi @@ -54,7 +54,7 @@ authtest () { PYTHON=$(cygpath -m $PYTHON) fi - echo "Running MONGODB-AWS authentication tests with $PYTHON" + echo "Running MONGODB-OIDC authentication tests with $PYTHON" $PYTHON --version $VIRTUALENV -p $PYTHON --never-download venvoidc From be21e222543bd4fed5a091d515c5bd82facaad41 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 20 Jan 2023 12:39:38 -0600 Subject: [PATCH 010/133] use bash shell --- .evergreen/config.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index e23becdf02..3f7cf53542 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -746,6 +746,7 @@ functions: type: test params: working_dir: "src" + shell: bash script: | ${PREPARE_SHELL} if [ "${skip_EC2_auth_test}" = "true" ]; then @@ -767,6 +768,7 @@ functions: type: test params: working_dir: "src" + shell: bash script: | ${PREPARE_SHELL} if [ "${skip_EC2_auth_test}" = "true" ]; then From 1595f31c1f860d82bd3f39f0170f6d4e89c525d1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 20 Jan 2023 12:45:34 -0600 Subject: [PATCH 011/133] fix handling of python bin --- .evergreen/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 3f7cf53542..75c652277a 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -799,7 +799,7 @@ functions: echo "This platform does not support the oidc auth test, skipping..." exit 0 fi - ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} .evergreen/run-mongodb-oidc-test.sh + PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} .evergreen/run-mongodb-oidc-test.sh "run aws auth test with aws credentials as environment variables": - command: shell.exec From 8bd23c69ed43d58fad2f1ede74d1e9931ca9bbc6 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 23 Jan 2023 09:46:54 -0600 Subject: [PATCH 012/133] do not cache aws device creds --- pymongo/auth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index a72af7db05..b11cb84275 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -585,7 +585,6 @@ def _authenticate_oidc(credentials, sock_info): aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: token = fid.read().strip() - _oidc_auth_cache[principal_name] = dict(access_token=token) payload = dict(jwt=token) cmd = SON( From 0e9376211a4e50a3db3bff633ca841b988d04cff Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 09:41:06 -0600 Subject: [PATCH 013/133] remove skipping server step 1 and add todo --- pymongo/auth.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index b11cb84275..7ed7639051 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -516,30 +516,6 @@ def _authenticate_oidc(credentials, sock_info): if principal_name: payload["n"] = principal_name - if principal_name in _oidc_auth_cache: - auth = _oidc_auth_cache[principal_name] - payload = dict(jwt=auth["access_token"]) - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ("autoAuthorize", 1), - ] - ) - - try: - response = sock_info.command("$external", cmd) - except Exception: - if principal_name in _oidc_auth_cache: - del _oidc_auth_cache[principal_name] - raise - - if not response["done"]: - del _oidc_auth_cache[principal_name] - raise OperationFailure("SASL conversation failed to complete.") - return - cmd = SON( [ ("saslStart", 1), @@ -553,6 +529,8 @@ def _authenticate_oidc(credentials, sock_info): client_resp = None token = None + # TODO: Update cache key to include id(callback) and the client_id + # from the server response. if principal_name in _oidc_auth_cache: client_resp = _oidc_auth_cache[principal_name] now_utc = datetime.now(timezone.utc) From eef2712238d7207768bbed1d82cc76a5f550f352 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 10:52:34 -0600 Subject: [PATCH 014/133] update handling of cache key --- .evergreen/run-mongodb-oidc-test.sh | 1 + pymongo/auth.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 861019f11f..1d3b738d45 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -44,6 +44,7 @@ set -x # Workaround macOS python 3.9 incompatibility with system virtualenv. if [ "$(uname -s)" = "Darwin" ]; then + # TODO: change back to 3.9 before merging. VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.10/bin/python3 -m virtualenv" else VIRTUALENV=$(command -v virtualenv) diff --git a/pymongo/auth.py b/pymongo/auth.py index 7ed7639051..9b90124048 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -529,14 +529,18 @@ def _authenticate_oidc(credentials, sock_info): client_resp = None token = None - # TODO: Update cache key to include id(callback) and the client_id - # from the server response. - if principal_name in _oidc_auth_cache: - client_resp = _oidc_auth_cache[principal_name] + # The cache key includes the client_id, the principal name, + # and the id of the request callback if provided. + cache_key = server_payload["client_id"] + principal_name + if properties.on_oidc_request_token: + cache_key += str(id(properties.on_oidc_request_token)) + + if cache_key in _oidc_auth_cache: + client_resp = _oidc_auth_cache[cache_key] now_utc = datetime.now(timezone.utc) - exp_utc = _oidc_exp_utc[principal_name] + exp_utc = _oidc_exp_utc[cache_key] if (exp_utc - now_utc).total_seconds() <= _oidc_buffer_seconds: - del _oidc_auth_cache[principal_name] + del _oidc_auth_cache[cache_key] if properties.on_oidc_refresh_token: client_resp = properties.on_oidc_refresh_token(server_payload, auth) else: @@ -544,7 +548,7 @@ def _authenticate_oidc(credentials, sock_info): if client_resp is None and properties.on_oidc_request_token is not None: if principal_name in _oidc_auth_cache: - auth = _oidc_auth_cache[principal_name] + auth = _oidc_auth_cache[cache_key] token = auth["access_token"] else: client_resp = properties.on_oidc_request_token(server_payload) @@ -557,7 +561,7 @@ def _authenticate_oidc(credentials, sock_info): now_utc = datetime.now(timezone.utc) exp_utc = now_utc + timedelta(seconds=expires_in) _oidc_exp_utc[principal_name] = exp_utc - _oidc_auth_cache[principal_name] = client_resp.copy() + _oidc_auth_cache[cache_key] = client_resp.copy() else: aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] From 8e5638e13905617f73d18dd6705172c6664323b3 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 10:53:01 -0600 Subject: [PATCH 015/133] fix clientId name --- pymongo/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 9b90124048..1b907d7b0f 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -531,7 +531,7 @@ def _authenticate_oidc(credentials, sock_info): # The cache key includes the client_id, the principal name, # and the id of the request callback if provided. - cache_key = server_payload["client_id"] + principal_name + cache_key = server_payload["clientId"] + principal_name if properties.on_oidc_request_token: cache_key += str(id(properties.on_oidc_request_token)) From 905b89d33bbc375797961b02eadba5b819c0663a Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 12:09:56 -0600 Subject: [PATCH 016/133] handle when there is no principal name --- pymongo/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 1b907d7b0f..4f4b274f58 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -531,7 +531,7 @@ def _authenticate_oidc(credentials, sock_info): # The cache key includes the client_id, the principal name, # and the id of the request callback if provided. - cache_key = server_payload["clientId"] + principal_name + cache_key = server_payload["clientId"] + str(principal_name) if properties.on_oidc_request_token: cache_key += str(id(properties.on_oidc_request_token)) From 43d49ad6f53c47a67f441433f35cf44f0b19acb1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 12:15:27 -0600 Subject: [PATCH 017/133] handle device_name --- pymongo/auth.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 4f4b274f58..39f3d8f5c4 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -103,7 +103,8 @@ def __hash__(self): _OIDCProperties = namedtuple( - "_OIDCProperties", ["on_oidc_request_token", "on_oidc_refresh_token", "principal_name"] + "_OIDCProperties", + ["on_oidc_request_token", "on_oidc_refresh_token", "principal_name", "device_name"], ) """Mechanism properties for MONGODB-OIDC authentication.""" @@ -148,15 +149,21 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): return MongoCredential(mech, "$external", user, passwd, aws_props, None) elif mech == "MONGODB-OIDC": if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for MONGODB-ODIC") + raise ValueError("authentication source must be $external or None for MONGODB-OIDC") properties = extra.get("authmechanismproperties", {}) on_oidc_request_token = properties.get("on_oidc_request_token") on_oidc_refresh_token = properties.get("on_oidc_refresh_token", on_oidc_request_token) principal_name = properties.get("PRINCIPAL_NAME", "") + device_name = properties.get("DEVICE_NAME", "") + if not on_oidc_request_token and device_name != "aws": + raise ConfigurationError( + "authentication with MONGODB-OIDC requires providing an on_oidc_request_token or a device_name of 'aws'" + ) oidc_props = _OIDCProperties( on_oidc_request_token=on_oidc_request_token, on_oidc_refresh_token=on_oidc_refresh_token, principal_name=principal_name, + device_name=device_name, ) return MongoCredential(mech, "$external", user, passwd, oidc_props, None) @@ -500,8 +507,7 @@ def _authenticate_mongo_cr(credentials, sock_info): _oidc_auth_cache = {} _oidc_exp_utc = {} -# TODO: Offer another parameter that is the refresh buffer? -# TOOD: Make a dataclass for the client resp and the internal storage +# TOOD: Make a namedtuple for the client resp and the internal storage _oidc_buffer_seconds = 5 * 60 @@ -563,7 +569,7 @@ def _authenticate_oidc(credentials, sock_info): _oidc_exp_utc[principal_name] = exp_utc _oidc_auth_cache[cache_key] = client_resp.copy() - else: + elif properties.device_name == "aws": aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: token = fid.read().strip() From e93c4c007cc5db216cb3833003460687f5891a00 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 16:17:01 -0600 Subject: [PATCH 018/133] add lock --- pymongo/auth.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pymongo/auth.py b/pymongo/auth.py index 39f3d8f5c4..83fd39c0a7 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -19,6 +19,7 @@ import hmac import os import socket +import threading from base64 import standard_b64decode, standard_b64encode from collections import namedtuple from datetime import datetime, timedelta, timezone @@ -509,6 +510,7 @@ def _authenticate_mongo_cr(credentials, sock_info): _oidc_exp_utc = {} # TOOD: Make a namedtuple for the client resp and the internal storage _oidc_buffer_seconds = 5 * 60 +_oidc_locks = {} def _authenticate_oidc(credentials, sock_info): @@ -538,6 +540,14 @@ def _authenticate_oidc(credentials, sock_info): # The cache key includes the client_id, the principal name, # and the id of the request callback if provided. cache_key = server_payload["clientId"] + str(principal_name) + + if cache_key not in _oidc_locks: + _oidc_locks[cache_key] = threading.Lock() + + lock = _oidc_locks[cache_key] + + lock.acquire() + if properties.on_oidc_request_token: cache_key += str(id(properties.on_oidc_request_token)) @@ -559,6 +569,8 @@ def _authenticate_oidc(credentials, sock_info): else: client_resp = properties.on_oidc_request_token(server_payload) + lock.release() + if client_resp is not None: token = client_resp["access_token"] if "expires_in_seconds" in client_resp: From 324437e3a275b9ec478c25cdf122dd0532b687fd Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 24 Jan 2023 16:17:23 -0600 Subject: [PATCH 019/133] fix locking logic --- pymongo/auth.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 83fd39c0a7..5a82f0679d 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -540,17 +540,15 @@ def _authenticate_oidc(credentials, sock_info): # The cache key includes the client_id, the principal name, # and the id of the request callback if provided. cache_key = server_payload["clientId"] + str(principal_name) + if properties.on_oidc_request_token: + cache_key += str(id(properties.on_oidc_request_token)) if cache_key not in _oidc_locks: _oidc_locks[cache_key] = threading.Lock() lock = _oidc_locks[cache_key] - lock.acquire() - if properties.on_oidc_request_token: - cache_key += str(id(properties.on_oidc_request_token)) - if cache_key in _oidc_auth_cache: client_resp = _oidc_auth_cache[cache_key] now_utc = datetime.now(timezone.utc) From 0ddeaf34b1800dbf97a08bb142605d7d2b7af1ea Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 25 Jan 2023 17:05:58 -0600 Subject: [PATCH 020/133] update for running both servers on the same replicaset --- .evergreen/config.yml | 29 ++++------------------------- .evergreen/run-mongodb-oidc-test.sh | 4 ++-- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 75c652277a..341b1e3124 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -787,7 +787,6 @@ functions: # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" export AWS_WEB_IDENTITY_TOKEN_FILE=${AWS_WEB_IDENTITY_TOKEN_FILE|/tmp/tokens/test1} - export USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} EOF - command: shell.exec type: test @@ -799,7 +798,7 @@ functions: echo "This platform does not support the oidc auth test, skipping..." exit 0 fi - PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true USE_MULTIPLE_PRINCIPALS=${USE_MULTIPLE_PRINCIPALS} .evergreen/run-mongodb-oidc-test.sh + PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh "run aws auth test with aws credentials as environment variables": - command: shell.exec @@ -2037,36 +2036,17 @@ tasks: - func: "run aws auth test with aws web identity credentials" - func: "run aws ECS auth test" - - name: "oidc-auth-test-single-latest" + - name: "oidc-auth-test-latest" commands: - func: "bootstrap oidc" - vars: - USE_MULTIPLE_PRINCIPALS: "false" - - func: "bootstrap mongo-orchestration" - vars: - AUTH: "auth" - ORCHESTRATION_FILE: "auth-oidc.json" - TOPOLOGY: "server" - VERSION: "latest" - - func: "run oidc auth test with aws credentials" - vars: - USE_MULTIPLE_PRINCIPALS: "false" - AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1 - - - name: "oidc-auth-test-multiple-latest" - commands: - - func: "bootstrap oidc" - vars: - USE_MULTIPLE_PRINCIPALS: "true" - func: "bootstrap mongo-orchestration" vars: AUTH: "auth" ORCHESTRATION_FILE: "auth-oidc.json" - TOPOLOGY: "server" + TOPOLOGY: "replica_set" VERSION: "latest" - func: "run oidc auth test with aws credentials" vars: - USE_MULTIPLE_PRINCIPALS: "true" AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1 - name: load-balancer-test @@ -3127,8 +3107,7 @@ buildvariants: python-version: ["3.9"] display_name: "MONGODB-OIDC Auth ${platform} ${python-version}" tasks: - - name: "oidc-auth-test-single-latest" - - name: "oidc-auth-test-multiple-latest" + - name: "oidc-auth-test-latest" - matrix_name: "aws-auth-test" matrix_spec: diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 1d3b738d45..7bb8e91ab1 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -21,8 +21,8 @@ set +x shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" -MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} -MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27018"} +MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC&directConnection=true&authMechanismProperties=DEVICE_NAME:aws" if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then MONGODB_URI="${MONGODB_URI}&authMechanismProperties=PRINCIPAL_NAME:717cc021e105be9843cd2005e5a4607beae5a4960ef8098cb1247481626090f8" From b9c0d47252efc576645dc6ad2a45bb110e619a09 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 25 Jan 2023 17:07:11 -0600 Subject: [PATCH 021/133] update oidc bootstrap --- .evergreen/config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 341b1e3124..998044dd09 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -761,7 +761,8 @@ functions: export AWS_TOKEN_DIR=/tmp/tokens . ./activate_venv.sh - python oidc_bootstrap.py + python oidc_write_orchestration.py + python oidc_get_tokens.py "run oidc auth test with aws credentials": - command: shell.exec From 0313d6cf92f2b84fea794ad1339e1a5794ec0098 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 26 Jan 2023 12:38:26 -0600 Subject: [PATCH 022/133] fix lint and typings --- pymongo/auth.py | 12 ++++++------ pymongo/common.py | 3 +-- pymongo/mongo_client.py | 1 - 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 5a82f0679d..40249b27bb 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -23,7 +23,7 @@ from base64 import standard_b64decode, standard_b64encode from collections import namedtuple from datetime import datetime, timedelta, timezone -from typing import Callable, Mapping +from typing import Callable, Dict, Mapping from urllib.parse import quote import bson @@ -506,11 +506,11 @@ def _authenticate_mongo_cr(credentials, sock_info): } """ -_oidc_auth_cache = {} -_oidc_exp_utc = {} +_oidc_auth_cache: Dict = {} +_oidc_exp_utc: Dict = {} # TOOD: Make a namedtuple for the client resp and the internal storage _oidc_buffer_seconds = 5 * 60 -_oidc_locks = {} +_oidc_locks: Dict = {} def _authenticate_oidc(credentials, sock_info): @@ -533,7 +533,7 @@ def _authenticate_oidc(credentials, sock_info): ] ) response = sock_info.command("$external", cmd) - server_payload = bson.decode(response["payload"]) + server_payload: Dict = bson.decode(response["payload"]) client_resp = None token = None @@ -554,7 +554,7 @@ def _authenticate_oidc(credentials, sock_info): now_utc = datetime.now(timezone.utc) exp_utc = _oidc_exp_utc[cache_key] if (exp_utc - now_utc).total_seconds() <= _oidc_buffer_seconds: - del _oidc_auth_cache[cache_key] + auth = _oidc_auth_cache.pop(cache_key) if properties.on_oidc_refresh_token: client_resp = properties.on_oidc_refresh_token(server_payload, auth) else: diff --git a/pymongo/common.py b/pymongo/common.py index 04c4117432..497f470ac9 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -430,10 +430,10 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]: """Validate authMechanismProperties.""" + props: Dict[str, Any] = {} if not isinstance(value, str): if not isinstance(value, dict): raise ValueError("Auth mechansim properties must be given as a string or a dictionary") - props = {} for key, value in value.items(): if isinstance(value, str): props[key] = value @@ -450,7 +450,6 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni return props value = validate_string(option, value) - props: Dict[str, Any] = {} for opt in value.split(","): try: key, val = opt.split(":") diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 782c91f710..ab0c749889 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -38,7 +38,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, FrozenSet, Generic, From b02a969fae9d44891de6176694110b2af51a0b0c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 26 Jan 2023 12:45:28 -0600 Subject: [PATCH 023/133] test 27017 --- .evergreen/run-mongodb-oidc-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 7bb8e91ab1..0ea5d3a828 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -21,7 +21,7 @@ set +x shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" -MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27018"} +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27017"} MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC&directConnection=true&authMechanismProperties=DEVICE_NAME:aws" if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then From 8d6d503af30e559497f6075d10d5044cf4e3e112 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 26 Jan 2023 17:07:42 -0600 Subject: [PATCH 024/133] update for spec compliance --- pymongo/auth.py | 84 ++++++++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 40249b27bb..bfbf73fc4a 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -508,6 +508,7 @@ def _authenticate_mongo_cr(credentials, sock_info): _oidc_auth_cache: Dict = {} _oidc_exp_utc: Dict = {} +_oidc_server_cache: Dict = {} # TOOD: Make a namedtuple for the client resp and the internal storage _oidc_buffer_seconds = 5 * 60 _oidc_locks: Dict = {} @@ -517,38 +518,59 @@ def _authenticate_oidc(credentials, sock_info): """Authenticate using MONGODB-OIDC.""" properties: _OIDCProperties = credentials.mechanism_properties - # Send the SASL start with the optional principal name. - payload = dict() + if properties.device_name == "aws": + aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] + with open(aws_identity_file) as fid: + token = fid.read().strip() + payload = dict(jwt=token) + cmd = SON( + [ + ("saslStart", 1), + ("payload", Binary(bson.encode(payload))), + ] + ) + response = sock_info.command("$external", cmd) + if not response["done"]: + raise OperationFailure("SASL conversation failed to complete.") + return + + address = sock_info.address principal_name = properties.principal_name + cache_key = f"{principal_name}{address[0]}{address[1]}" - if principal_name: - payload["n"] = principal_name + skip_step1 = cache_key in _oidc_server_cache + conversation_id = None + + if not skip_step1: + # Send the SASL start with the optional principal name. + payload = dict() + + if principal_name: + payload["n"] = principal_name + + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), + ] + ) + response = sock_info.command("$external", cmd) + _oidc_server_cache[cache_key] = bson.decode(response["payload"]) + conversation_id = response["conversationId"] - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ("autoAuthorize", 1), - ] - ) - response = sock_info.command("$external", cmd) - server_payload: Dict = bson.decode(response["payload"]) client_resp = None token = None - # The cache key includes the client_id, the principal name, - # and the id of the request callback if provided. - cache_key = server_payload["clientId"] + str(principal_name) - if properties.on_oidc_request_token: - cache_key += str(id(properties.on_oidc_request_token)) - if cache_key not in _oidc_locks: _oidc_locks[cache_key] = threading.Lock() lock = _oidc_locks[cache_key] lock.acquire() + server_payload = _oidc_server_cache[cache_key] + if cache_key in _oidc_auth_cache: client_resp = _oidc_auth_cache[cache_key] now_utc = datetime.now(timezone.utc) @@ -579,19 +601,17 @@ def _authenticate_oidc(credentials, sock_info): _oidc_exp_utc[principal_name] = exp_utc _oidc_auth_cache[cache_key] = client_resp.copy() - elif properties.device_name == "aws": - aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] - with open(aws_identity_file) as fid: - token = fid.read().strip() + payload = Binary(bson.encode(dict(jwt=token))) - payload = dict(jwt=token) - cmd = SON( - [ - ("saslContinue", 1), - ("conversationId", response["conversationId"]), - ("payload", Binary(bson.encode(payload))), - ] - ) + if skip_step1: + cmd = SON([("saslContinue", 1), ("conversationId", conversation_id), ("payload", payload)]) + else: + cmd = SON( + [ + ("saslStart", 1), + ("payload", payload), + ] + ) try: response = sock_info.command("$external", cmd) From 1a12de9567214f4ba11a3b0e73b3633064c968a4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 26 Jan 2023 17:15:37 -0600 Subject: [PATCH 025/133] add mechanism field --- pymongo/auth.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index bfbf73fc4a..bfb26590ca 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -526,6 +526,7 @@ def _authenticate_oidc(credentials, sock_info): cmd = SON( [ ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), ("payload", Binary(bson.encode(payload))), ] ) @@ -604,11 +605,19 @@ def _authenticate_oidc(credentials, sock_info): payload = Binary(bson.encode(dict(jwt=token))) if skip_step1: - cmd = SON([("saslContinue", 1), ("conversationId", conversation_id), ("payload", payload)]) + cmd = SON( + [ + ("saslContinue", 1), + ("mechanism", "MONGODB-OIDC"), + ("conversationId", conversation_id), + ("payload", payload), + ] + ) else: cmd = SON( [ ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), ("payload", payload), ] ) From f4083865e55182488d26dcd3f9b2caa71ed3a747 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 26 Jan 2023 20:28:50 -0600 Subject: [PATCH 026/133] clean up caching behavior --- pymongo/auth.py | 123 +++++++++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 54 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index bfb26590ca..12ca4e6fca 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -21,9 +21,10 @@ import socket import threading from base64 import standard_b64decode, standard_b64encode -from collections import namedtuple +from collections import OrderedDict, namedtuple +from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Callable, Dict, Mapping +from typing import Callable, Dict, Mapping, Optional from urllib.parse import quote import bson @@ -110,6 +111,14 @@ def __hash__(self): """Mechanism properties for MONGODB-OIDC authentication.""" +@dataclass +class _OIDCCache: + token_result: Optional[Dict] + exp_utc: Optional[datetime] + server_resp: Optional[Dict] + lock: threading.Lock + + def _build_credentials_tuple(mech, source, user, passwd, extra, database): """Build and return a mechanism specific credentials tuple.""" if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: @@ -506,12 +515,30 @@ def _authenticate_mongo_cr(credentials, sock_info): } """ -_oidc_auth_cache: Dict = {} -_oidc_exp_utc: Dict = {} -_oidc_server_cache: Dict = {} -# TOOD: Make a namedtuple for the client resp and the internal storage + +class _OIDCLRUCache: + def __init__(self): + self.cache = OrderedDict() + + def get(self, key: str) -> Optional[_OIDCCache]: + if key not in self.cache: + return None + else: + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key: str, value: _OIDCCache) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + if len(self.cache) > 100: + self.cache.popitem(last=False) + + def remove(self, key: str): + self.cache.pop(key, None) + + +_oidc_cache = _OIDCLRUCache() _oidc_buffer_seconds = 5 * 60 -_oidc_locks: Dict = {} def _authenticate_oidc(credentials, sock_info): @@ -539,10 +566,14 @@ def _authenticate_oidc(credentials, sock_info): principal_name = properties.principal_name cache_key = f"{principal_name}{address[0]}{address[1]}" - skip_step1 = cache_key in _oidc_server_cache + cache_value = _oidc_cache.get(cache_key) conversation_id = None - if not skip_step1: + if cache_value is None: + lock = threading.Lock() + cache_value = _OIDCCache(lock=lock, token_result=None, server_resp=None, exp_utc=None) + _oidc_cache.put(cache_key, cache_value) + # Send the SASL start with the optional principal name. payload = dict() @@ -558,53 +589,38 @@ def _authenticate_oidc(credentials, sock_info): ] ) response = sock_info.command("$external", cmd) - _oidc_server_cache[cache_key] = bson.decode(response["payload"]) + cache_value.server_resp = bson.decode(response["payload"]) conversation_id = response["conversationId"] - client_resp = None - token = None - - if cache_key not in _oidc_locks: - _oidc_locks[cache_key] = threading.Lock() - - lock = _oidc_locks[cache_key] - lock.acquire() - - server_payload = _oidc_server_cache[cache_key] - - if cache_key in _oidc_auth_cache: - client_resp = _oidc_auth_cache[cache_key] - now_utc = datetime.now(timezone.utc) - exp_utc = _oidc_exp_utc[cache_key] - if (exp_utc - now_utc).total_seconds() <= _oidc_buffer_seconds: - auth = _oidc_auth_cache.pop(cache_key) - if properties.on_oidc_refresh_token: - client_resp = properties.on_oidc_refresh_token(server_payload, auth) - else: - client_resp = None - - if client_resp is None and properties.on_oidc_request_token is not None: - if principal_name in _oidc_auth_cache: - auth = _oidc_auth_cache[cache_key] - token = auth["access_token"] - else: - client_resp = properties.on_oidc_request_token(server_payload) - - lock.release() - - if client_resp is not None: - token = client_resp["access_token"] - if "expires_in_seconds" in client_resp: - expires_in = client_resp["expires_in_seconds"] - if expires_in >= _oidc_buffer_seconds: - now_utc = datetime.now(timezone.utc) - exp_utc = now_utc + timedelta(seconds=expires_in) - _oidc_exp_utc[principal_name] = exp_utc - _oidc_auth_cache[cache_key] = client_resp.copy() + with cache_value.lock: + if cache_value.exp_utc is not None: + now_utc = datetime.now(timezone.utc) + exp_utc = cache_value.exp_utc + if (exp_utc - now_utc).total_seconds() <= _oidc_buffer_seconds: + if properties.on_oidc_refresh_token: + cache_value.token_result = properties.on_oidc_refresh_token( + cache_value.server_resp, cache_value.token_result + ) + else: + cache_value.token_result = None + + if cache_value.token_result is None: + cache_value.token_result = properties.on_oidc_request_token(cache_value.server_resp) + + token_result = cache_value.token_result + token = token_result["access_token"] + if "expires_in_seconds" in token_result: + expires_in = token_result["expires_in_seconds"] + if expires_in >= _oidc_buffer_seconds: + now_utc = datetime.now(timezone.utc) + exp_utc = now_utc + timedelta(seconds=expires_in) + cache_value.exp_utc = exp_utc + else: + _oidc_cache.remove(cache_key) payload = Binary(bson.encode(dict(jwt=token))) - if skip_step1: + if conversation_id is not None: cmd = SON( [ ("saslContinue", 1), @@ -625,12 +641,11 @@ def _authenticate_oidc(credentials, sock_info): try: response = sock_info.command("$external", cmd) except Exception: - if principal_name in _oidc_auth_cache: - del _oidc_auth_cache[principal_name] + _oidc_cache.remove(cache_key) raise if not response["done"]: - del _oidc_auth_cache[principal_name] + _oidc_cache.remove(cache_key) raise OperationFailure("SASL conversation failed to complete.") From b7f161a70650958e0750071f74b2c5557847b671 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 27 Jan 2023 06:07:22 -0600 Subject: [PATCH 027/133] typing fixes --- pymongo/auth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 12ca4e6fca..f747f4f390 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -533,7 +533,7 @@ def put(self, key: str, value: _OIDCCache) -> None: if len(self.cache) > 100: self.cache.popitem(last=False) - def remove(self, key: str): + def remove(self, key: str) -> None: self.cache.pop(key, None) @@ -618,7 +618,7 @@ def _authenticate_oidc(credentials, sock_info): else: _oidc_cache.remove(cache_key) - payload = Binary(bson.encode(dict(jwt=token))) + bin_payload = Binary(bson.encode(dict(jwt=token))) if conversation_id is not None: cmd = SON( @@ -626,7 +626,7 @@ def _authenticate_oidc(credentials, sock_info): ("saslContinue", 1), ("mechanism", "MONGODB-OIDC"), ("conversationId", conversation_id), - ("payload", payload), + ("payload", bin_payload), ] ) else: @@ -634,7 +634,7 @@ def _authenticate_oidc(credentials, sock_info): [ ("saslStart", 1), ("mechanism", "MONGODB-OIDC"), - ("payload", payload), + ("payload", bin_payload), ] ) From 2826753e4295c6f67e307d6cc22867642d338328 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 30 Jan 2023 10:50:40 -0600 Subject: [PATCH 028/133] clean up caching behavior --- .evergreen/run-mongodb-oidc-test.sh | 2 +- pymongo/auth.py | 79 ++++++++++++++++------------- 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 0ea5d3a828..7bb8e91ab1 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -21,7 +21,7 @@ set +x shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" -MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27017"} +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27018"} MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC&directConnection=true&authMechanismProperties=DEVICE_NAME:aws" if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then diff --git a/pymongo/auth.py b/pymongo/auth.py index f747f4f390..a284d6f4ea 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -21,7 +21,7 @@ import socket import threading from base64 import standard_b64decode, standard_b64encode -from collections import OrderedDict, namedtuple +from collections import namedtuple from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Callable, Dict, Mapping, Optional @@ -114,7 +114,8 @@ def __hash__(self): @dataclass class _OIDCCache: token_result: Optional[Dict] - exp_utc: Optional[datetime] + token_exp_utc: Optional[datetime] + cache_exp_utc: datetime server_resp: Optional[Dict] lock: threading.Lock @@ -516,35 +517,25 @@ def _authenticate_mongo_cr(credentials, sock_info): """ -class _OIDCLRUCache: - def __init__(self): - self.cache = OrderedDict() - - def get(self, key: str) -> Optional[_OIDCCache]: - if key not in self.cache: - return None - else: - self.cache.move_to_end(key) - return self.cache[key] - - def put(self, key: str, value: _OIDCCache) -> None: - self.cache[key] = value - self.cache.move_to_end(key) - if len(self.cache) > 100: - self.cache.popitem(last=False) - - def remove(self, key: str) -> None: - self.cache.pop(key, None) - - -_oidc_cache = _OIDCLRUCache() -_oidc_buffer_seconds = 5 * 60 +_oidc_cache: Dict[str, _OIDCCache] = {} +_OIDC_TOKEN_BUFFER_MINUTES = 5 +_OIDC_CACHE_TIMEOUT_MINUTES = 60 * 5 def _authenticate_oidc(credentials, sock_info): """Authenticate using MONGODB-OIDC.""" properties: _OIDCProperties = credentials.mechanism_properties + # Clear out old items in the cache. + now_utc = datetime.now(timezone.utc) + to_remove = [] + for key, value in _oidc_cache.items(): + if value.cache_exp_utc > now_utc: + to_remove.append(key) + for key in to_remove: + del _oidc_cache[key] + + # Handle aws device credentials. if properties.device_name == "aws": aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: @@ -562,6 +553,7 @@ def _authenticate_oidc(credentials, sock_info): raise OperationFailure("SASL conversation failed to complete.") return + # Handle authorization code credentials. address = sock_info.address principal_name = properties.principal_name cache_key = f"{principal_name}{address[0]}{address[1]}" @@ -571,8 +563,15 @@ def _authenticate_oidc(credentials, sock_info): if cache_value is None: lock = threading.Lock() - cache_value = _OIDCCache(lock=lock, token_result=None, server_resp=None, exp_utc=None) - _oidc_cache.put(cache_key, cache_value) + cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=_OIDC_CACHE_TIMEOUT_MINUTES) + cache_value = _OIDCCache( + lock=lock, + token_result=None, + server_resp=None, + token_exp_utc=None, + cache_exp_utc=cache_exp_utc, + ) + _oidc_cache[cache_key] = cache_value # Send the SASL start with the optional principal name. payload = dict() @@ -593,30 +592,40 @@ def _authenticate_oidc(credentials, sock_info): conversation_id = response["conversationId"] with cache_value.lock: - if cache_value.exp_utc is not None: + if cache_value.token_exp_utc is not None: now_utc = datetime.now(timezone.utc) - exp_utc = cache_value.exp_utc - if (exp_utc - now_utc).total_seconds() <= _oidc_buffer_seconds: + exp_utc = cache_value.token_exp_utc + buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 + if (exp_utc - now_utc).total_seconds() <= buffer_seconds: if properties.on_oidc_refresh_token: cache_value.token_result = properties.on_oidc_refresh_token( cache_value.server_resp, cache_value.token_result ) + cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=_OIDC_CACHE_TIMEOUT_MINUTES + ) + cache_value.cache_exp_utc = cache_exp_utc else: cache_value.token_result = None if cache_value.token_result is None: cache_value.token_result = properties.on_oidc_request_token(cache_value.server_resp) + cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=_OIDC_CACHE_TIMEOUT_MINUTES + ) + cache_value.cache_exp_utc = cache_exp_utc token_result = cache_value.token_result token = token_result["access_token"] if "expires_in_seconds" in token_result: expires_in = token_result["expires_in_seconds"] - if expires_in >= _oidc_buffer_seconds: + buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 + if expires_in >= buffer_seconds: now_utc = datetime.now(timezone.utc) exp_utc = now_utc + timedelta(seconds=expires_in) - cache_value.exp_utc = exp_utc + cache_value.token_exp_utc = exp_utc else: - _oidc_cache.remove(cache_key) + _oidc_cache.pop(cache_key, None) bin_payload = Binary(bson.encode(dict(jwt=token))) @@ -641,11 +650,11 @@ def _authenticate_oidc(credentials, sock_info): try: response = sock_info.command("$external", cmd) except Exception: - _oidc_cache.remove(cache_key) + _oidc_cache.pop(cache_key, None) raise if not response["done"]: - _oidc_cache.remove(cache_key) + _oidc_cache.pop(cache_key, None) raise OperationFailure("SASL conversation failed to complete.") From 57046f8324372d55c627281024f9b46d5ca9b770 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 30 Jan 2023 16:10:14 -0600 Subject: [PATCH 029/133] cleanup --- .evergreen/config.yml | 1 - .evergreen/run-mongodb-oidc-test.sh | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 998044dd09..c08b00b7c6 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -777,7 +777,6 @@ functions: exit 0 fi cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc - . ./activate_venv.sh mongo setup_oidc.js - command: shell.exec type: test diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 7bb8e91ab1..0ea5d3a828 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -21,7 +21,7 @@ set +x shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" -MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27018"} +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27017"} MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC&directConnection=true&authMechanismProperties=DEVICE_NAME:aws" if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then From 9f4fb49c0b118846cc50e862a5acf823db1f8552 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 30 Jan 2023 16:31:22 -0600 Subject: [PATCH 030/133] add handling of auth connection string tests --- .evergreen/resync-specs.sh | 3 + pymongo/common.py | 2 + test/auth/connection-string.json | 165 ++++++++++++++++++++++++++++++- test/test_auth_spec.py | 21 +++- 4 files changed, 189 insertions(+), 2 deletions(-) diff --git a/.evergreen/resync-specs.sh b/.evergreen/resync-specs.sh index 489ff28b3a..817a2d96bc 100755 --- a/.evergreen/resync-specs.sh +++ b/.evergreen/resync-specs.sh @@ -70,6 +70,9 @@ for spec in "$@" do # Match the spec dir name, the python test dir name, and/or common abbreviations. case "$spec" in + auth) + cpjson auth/tests/ auth + ;; atlas-data-lake-testing|data_lake) cpjson atlas-data-lake-testing/tests/ data_lake ;; diff --git a/pymongo/common.py b/pymongo/common.py index 497f470ac9..97825ac984 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -437,6 +437,8 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni for key, value in value.items(): if isinstance(value, str): props[key] = value + elif isinstance(value, bool): + props[key] = str(value).lower() elif inspect.isfunction(value): signature = inspect.signature(value) if len(signature.parameters) == 0: diff --git a/test/auth/connection-string.json b/test/auth/connection-string.json index 2a37ae8df4..5f9bed4c5a 100644 --- a/test/auth/connection-string.json +++ b/test/auth/connection-string.json @@ -444,6 +444,169 @@ "AWS_SESSION_TOKEN": "token!@#$%^&*()_+" } } + }, + { + "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest", "oidcRefresh"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true, + "REFRESH_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism and principalName with request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PRINCIPAL_NAME:principalName", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true, + "PRINCIPAL_NAME": "principalName" + } + } + }, + { + "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "DEVICE_NAME": "aws" + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=DEVICE_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "DEVICE_NAME": "aws" + } + } + }, + { + "description": "should recognise the mechanism with azure device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:azure", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "DEVICE_NAME": "azure" + } + } + }, + { + "description": "should recognise the mechanism with gcp device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:gcp", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "DEVICE_NAME": "gcp" + } + } + }, + { + "description": "should throw an exception if username is specified (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if principalName and deviceName are specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PRINCIPAL_NAME:principalName,DEVICE_NAME:gcp", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:unexisted", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "valid": false, + "credential": null } ] -} +} \ No newline at end of file diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 9f2fa374ac..2bb75de28a 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -41,7 +41,16 @@ def run_test(self): if not valid: self.assertRaises(Exception, MongoClient, uri, connect=False) else: - client = MongoClient(uri, connect=False) + props = {} + if credential: + props = credential["mechanism_properties"] or {} + if props.get("REQUEST_TOKEN_CALLBACK"): + props["on_oidc_request_token"] = lambda x: 1 + del props["REQUEST_TOKEN_CALLBACK"] + if props.get("REFRESH_TOKEN_CALLBACK"): + props["on_oidc_refresh_token"] = lambda x, y: 1 + del props["REFRESH_TOKEN_CALLBACK"] + client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) @@ -70,6 +79,16 @@ def run_test(self): self.assertEqual( actual.aws_session_token, expected["AWS_SESSION_TOKEN"] ) + elif "DEVICE_NAME" in expected: + self.assertEqual(actual.device_name, expected["DEVICE_NAME"]) + elif "on_oidc_request_token" in expected: + self.assertEqual( + actual.on_oidc_request_token, expected["on_oidc_request_token"] + ) + elif "on_oidc_refresh_token" in expected: + self.assertEqual( + actual.on_oidc_refresh_token, expected["on_oidc_refresh_token"] + ) else: self.fail("Unhandled property: %s" % (key,)) else: From 62d78dc9815265ebebc21e7d8a89df9dd1684add Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 16:04:00 -0600 Subject: [PATCH 031/133] add initial handling of reauth required --- .evergreen/run-mongodb-oidc-test.sh | 6 +++++ pymongo/auth.py | 38 ++++++++++++----------------- pymongo/helpers.py | 1 + pymongo/mongo_client.py | 6 ++++- pymongo/pool.py | 4 +-- test/auth_aws/test_auth_oidc.py | 36 +++++++++++++++++++++++++-- 6 files changed, 64 insertions(+), 27 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 0ea5d3a828..28317315da 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -23,11 +23,17 @@ shopt -s expand_aliases # needed for `urlencode` alias MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27017"} MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC&directConnection=true&authMechanismProperties=DEVICE_NAME:aws" +#MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then MONGODB_URI="${MONGODB_URI}&authMechanismProperties=PRINCIPAL_NAME:717cc021e105be9843cd2005e5a4607beae5a4960ef8098cb1247481626090f8" fi +if [ -z "${AWS_TOKEN_DIR}" ]; then + echo "Must specify AWS_TOKEN_DIR" + exit 1 +fi + export MONGODB_URI="$MONGODB_URI" echo $MONGODB_URI diff --git a/pymongo/auth.py b/pymongo/auth.py index a284d6f4ea..08dc5cb4d3 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -530,7 +530,7 @@ def _authenticate_oidc(credentials, sock_info): now_utc = datetime.now(timezone.utc) to_remove = [] for key, value in _oidc_cache.items(): - if value.cache_exp_utc > now_utc: + if value.cache_exp_utc < now_utc: to_remove.append(key) for key in to_remove: del _oidc_cache[key] @@ -591,25 +591,22 @@ def _authenticate_oidc(credentials, sock_info): cache_value.server_resp = bson.decode(response["payload"]) conversation_id = response["conversationId"] - with cache_value.lock: - if cache_value.token_exp_utc is not None: - now_utc = datetime.now(timezone.utc) - exp_utc = cache_value.token_exp_utc - buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 - if (exp_utc - now_utc).total_seconds() <= buffer_seconds: - if properties.on_oidc_refresh_token: - cache_value.token_result = properties.on_oidc_refresh_token( - cache_value.server_resp, cache_value.token_result - ) - cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=_OIDC_CACHE_TIMEOUT_MINUTES - ) - cache_value.cache_exp_utc = cache_exp_utc - else: - cache_value.token_result = None + current_valid_token = False + if cache_value.token_exp_utc is not None: + now_utc = datetime.now(timezone.utc) + exp_utc = cache_value.token_exp_utc + buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 + if (exp_utc - now_utc).total_seconds() >= buffer_seconds: + current_valid_token = True - if cache_value.token_result is None: - cache_value.token_result = properties.on_oidc_request_token(cache_value.server_resp) + if not current_valid_token: + with cache_value.lock: + if cache_value.token_result is None or properties.on_oidc_refresh_token is None: + cache_value.token_result = properties.on_oidc_request_token(cache_value.server_resp) + else: + cache_value.token_result = properties.on_oidc_refresh_token( + cache_value.server_resp, cache_value.token_result + ) cache_exp_utc = datetime.now(timezone.utc) + timedelta( minutes=_OIDC_CACHE_TIMEOUT_MINUTES ) @@ -624,8 +621,6 @@ def _authenticate_oidc(credentials, sock_info): now_utc = datetime.now(timezone.utc) exp_utc = now_utc + timedelta(seconds=expires_in) cache_value.token_exp_utc = exp_utc - else: - _oidc_cache.pop(cache_key, None) bin_payload = Binary(bson.encode(dict(jwt=token))) @@ -633,7 +628,6 @@ def _authenticate_oidc(credentials, sock_info): cmd = SON( [ ("saslContinue", 1), - ("mechanism", "MONGODB-OIDC"), ("conversationId", conversation_id), ("payload", bin_payload), ] diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 31325c8af2..985d218bf9 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -65,6 +65,7 @@ 89, # NetworkTimeout 9001, # SocketException 262, # ExceededTimeLimit + 391, ] ) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ab0c749889..1f6b3e8b9f 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1394,7 +1394,11 @@ def is_retrying(): assert last_error is not None raise last_error retryable = False - return func(session, sock_info, retryable) + try: + return func(session, sock_info, retryable) + except OperationFailure as exc: + sock_info.authenticate(force=True) + return func(session, sock_info, retryable) except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted diff --git a/pymongo/pool.py b/pymongo/pool.py index 6355692ac9..e319da84cb 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -857,14 +857,14 @@ def write_command(self, request_id, msg, codec_options): helpers._check_command_response(result, self.max_wire_version) return result - def authenticate(self): + def authenticate(self, force=False): """Authenticate to the server if needed. Can raise ConnectionFailure or OperationFailure. """ # CMAP spec says to publish the ready event only after authenticating # the connection. - if not self.ready: + if not self.ready or force: creds = self.opts._credentials if creds: auth.authenticate(creds, self) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 81cac09004..2d66d56e63 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -31,8 +31,40 @@ def setUpClass(cls): cls.uri = os.environ["MONGODB_URI"] def test_connect_environment_var(self): - with MongoClient(self.uri) as client: - client.test.test.find_one() + aws_token_dir = os.environ["AWS_TOKEN_DIR"] + + def get_auth_token(info): + with open(os.path.join(aws_token_dir, "test_user1_expires")) as fid: + token = fid.read() + return dict(access_token=token) + + def refresh_auth_token(server_info, auth_info): + with open(os.path.join(aws_token_dir, "test_user1")) as fid: + token = fid.read() + return dict(access_token=token) + + props = dict( + on_oidc_request_token=get_auth_token, + on_oidc_refresh_token=refresh_auth_token, + principal_name="test_user1", + ) + client = MongoClient(self.uri, authmechanismproperties=props) + client.test.test.find_one() + + import time + + time.sleep(60) + + orders = client.test.orders + inventory = client.test.inventory + with client.start_session() as session: + with session.start_transaction(): + orders.insert_one({"sku": "abc123", "qty": 100}, session=session) + inventory.update_one( + {"sku": "abc123", "qty": {"$gte": 100}}, + {"$inc": {"qty": -100}}, + session=session, + ) if __name__ == "__main__": From 8bff0ff674c9d2f6cdc8ac117e2abc935aee70b8 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 16:41:09 -0600 Subject: [PATCH 032/133] finish reauthentication --- pymongo/helpers.py | 3 +++ pymongo/mongo_client.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 985d218bf9..8851a63c94 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -69,6 +69,9 @@ ] ) +# Server code raised when re-authentication is required +_REAUTHENTICATION_REQUIRED_CODE = 391 + def _gen_index_name(keys): """Generate an index name from the set of fields it is over.""" diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 1f6b3e8b9f..2fdf4adaa0 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1394,11 +1394,14 @@ def is_retrying(): assert last_error is not None raise last_error retryable = False + # Handle re-authentication. try: return func(session, sock_info, retryable) except OperationFailure as exc: - sock_info.authenticate(force=True) - return func(session, sock_info, retryable) + if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: + sock_info.authenticate(force=True) + return func(session, sock_info, retryable) + raise except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted @@ -1462,7 +1465,14 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True # not support retryable reads, raise the last error. assert last_error is not None raise last_error - return func(session, server, sock_info, read_pref) + # Handle re-authentication. + try: + return func(session, server, sock_info, read_pref) + except OperationFailure as exc: + if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: + sock_info.authenticate(force=True) + return func(session, server, sock_info, read_pref) + raise except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted From 35182f322cfba28b9103a9207b93cc989e73b92b Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 21:27:49 -0600 Subject: [PATCH 033/133] add timeout handling --- pymongo/auth.py | 14 ++++++++++++-- test/auth_aws/test_auth_oidc.py | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 08dc5cb4d3..c2cb286ebe 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -30,6 +30,7 @@ import bson from bson.binary import Binary from bson.son import SON +from pymongo import _csot from pymongo.auth_aws import _authenticate_aws from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -599,13 +600,22 @@ def _authenticate_oidc(credentials, sock_info): if (exp_utc - now_utc).total_seconds() >= buffer_seconds: current_valid_token = True + # CSOT: apply timeout to callback. + timeout = _csot.remaining() + if timeout is None: + timeout = sock_info.opts.connect_timeout + elif timeout <= 0: + raise TimeoutError("timed out") + if not current_valid_token: with cache_value.lock: if cache_value.token_result is None or properties.on_oidc_refresh_token is None: - cache_value.token_result = properties.on_oidc_request_token(cache_value.server_resp) + cache_value.token_result = properties.on_oidc_request_token( + cache_value.server_resp, timeout + ) else: cache_value.token_result = properties.on_oidc_refresh_token( - cache_value.server_resp, cache_value.token_result + cache_value.server_resp, cache_value.token_result, timeout ) cache_exp_utc = datetime.now(timezone.utc) + timedelta( minutes=_OIDC_CACHE_TIMEOUT_MINUTES diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 2d66d56e63..aedc4aff86 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -33,12 +33,12 @@ def setUpClass(cls): def test_connect_environment_var(self): aws_token_dir = os.environ["AWS_TOKEN_DIR"] - def get_auth_token(info): + def get_auth_token(info, timeout): with open(os.path.join(aws_token_dir, "test_user1_expires")) as fid: token = fid.read() return dict(access_token=token) - def refresh_auth_token(server_info, auth_info): + def refresh_auth_token(server_info, auth_info, timeout): with open(os.path.join(aws_token_dir, "test_user1")) as fid: token = fid.read() return dict(access_token=token) From 7ac2fed5990964117297e26df9dacef9d4318de2 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 21:52:56 -0600 Subject: [PATCH 034/133] fix connection string tests --- test/auth/connection-string.json | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/test/auth/connection-string.json b/test/auth/connection-string.json index 5f9bed4c5a..38cee87771 100644 --- a/test/auth/connection-string.json +++ b/test/auth/connection-string.json @@ -535,34 +535,6 @@ } } }, - { - "description": "should recognise the mechanism with azure device (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:azure", - "valid": true, - "credential": { - "username": null, - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "DEVICE_NAME": "azure" - } - } - }, - { - "description": "should recognise the mechanism with gcp device (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:gcp", - "valid": true, - "credential": { - "username": null, - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "DEVICE_NAME": "gcp" - } - } - }, { "description": "should throw an exception if username is specified (MONGODB-OIDC)", "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC", From 8ce95bc4bf4aace3f578ec92f562a859f57fa5cd Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 21:59:28 -0600 Subject: [PATCH 035/133] add callback return validation --- pymongo/auth.py | 7 ++++++- test/test_auth_spec.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index c2cb286ebe..0acdb54c85 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -623,9 +623,14 @@ def _authenticate_oidc(credentials, sock_info): cache_value.cache_exp_utc = cache_exp_utc token_result = cache_value.token_result + if not isinstance(token_result, dict): + raise ValueError("OIDC callback returned invalid result") + if "access_token" not in token_result: + raise ValueError("OIDC callback did not return an access_token") + token = token_result["access_token"] if "expires_in_seconds" in token_result: - expires_in = token_result["expires_in_seconds"] + expires_in = int(token_result["expires_in_seconds"]) buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 if expires_in >= buffer_seconds: now_utc = datetime.now(timezone.utc) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 2bb75de28a..61fc14edc1 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -45,10 +45,10 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["on_oidc_request_token"] = lambda x: 1 + props["on_oidc_request_token"] = lambda x, y: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["on_oidc_refresh_token"] = lambda x, y: 1 + props["on_oidc_refresh_token"] = lambda x, y, z: 1 del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From 0ecfba41b1f180677c11fb2942b0c14f5aec3347 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 22:09:07 -0600 Subject: [PATCH 036/133] add more callback validation --- .evergreen/config.yml | 1 + pymongo/common.py | 10 ++++++++-- test/test_auth_spec.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index c08b00b7c6..77d03644ec 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -787,6 +787,7 @@ functions: # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" export AWS_WEB_IDENTITY_TOKEN_FILE=${AWS_WEB_IDENTITY_TOKEN_FILE|/tmp/tokens/test1} + export AWS_TOKEN_DIR=/tmp/tokens EOF - command: shell.exec type: test diff --git a/pymongo/common.py b/pymongo/common.py index 97825ac984..3b0cdd32f1 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -441,8 +441,14 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni props[key] = str(value).lower() elif inspect.isfunction(value): signature = inspect.signature(value) - if len(signature.parameters) == 0: - msg = "Auth mechanisim properity callbacks must accept at least one value" + if key == "on_oidc_request_token": + expected_params = 2 + elif key == "on_oidc_refresh_token": + expected_params = 3 + else: + raise ValueError(f"Unrecognized Auth mechanisim function {key}") + if len(signature.parameters) == expected_params: + msg = f"{key} must accept {expected_params} parameters" raise ValueError(msg) props[key] = value else: diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 61fc14edc1..ac24fda76f 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -48,7 +48,7 @@ def run_test(self): props["on_oidc_request_token"] = lambda x, y: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["on_oidc_refresh_token"] = lambda x, y, z: 1 + props["on_oidc_refresh_token"] = lambda x, y, z: 1 # type:ignore[misc] del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From b42cc7d703a30d09c05c527b8345bcd49631c592 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Feb 2023 22:22:42 -0600 Subject: [PATCH 037/133] fix function validation --- pymongo/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/common.py b/pymongo/common.py index 3b0cdd32f1..06fe2fcb6b 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -447,7 +447,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni expected_params = 3 else: raise ValueError(f"Unrecognized Auth mechanisim function {key}") - if len(signature.parameters) == expected_params: + if len(signature.parameters) != expected_params: msg = f"{key} must accept {expected_params} parameters" raise ValueError(msg) props[key] = value From 834c130aebc58d0b273e48eb6b6eeafcdd80f9a2 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 2 Feb 2023 14:15:30 -0600 Subject: [PATCH 038/133] implement some prose tests --- .evergreen/config.yml | 1 - .evergreen/run-mongodb-oidc-test.sh | 16 ++-- pymongo/auth.py | 2 +- test/auth_aws/test_auth_oidc.py | 117 ++++++++++++++++++++++------ 4 files changed, 101 insertions(+), 35 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 77d03644ec..704ec6b5a9 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -786,7 +786,6 @@ functions: script: | # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" - export AWS_WEB_IDENTITY_TOKEN_FILE=${AWS_WEB_IDENTITY_TOKEN_FILE|/tmp/tokens/test1} export AWS_TOKEN_DIR=/tmp/tokens EOF - command: shell.exec diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 28317315da..453d229857 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -21,22 +21,20 @@ set +x shopt -s expand_aliases # needed for `urlencode` alias [ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" -MONGODB_URI=${MONGODB_URI:-"mongodb://localhost:27017"} -MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC&directConnection=true&authMechanismProperties=DEVICE_NAME:aws" -#MONGODB_URI="${MONGODB_URI}/test?authMechanism=MONGODB-OIDC" - -if [ "$USE_MULTIPLE_PRINCIPALS" = "true" ]; then - MONGODB_URI="${MONGODB_URI}&authMechanismProperties=PRINCIPAL_NAME:717cc021e105be9843cd2005e5a4607beae5a4960ef8098cb1247481626090f8" -fi +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} +MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" +MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" if [ -z "${AWS_TOKEN_DIR}" ]; then echo "Must specify AWS_TOKEN_DIR" exit 1 fi -export MONGODB_URI="$MONGODB_URI" +export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE" +export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE" -echo $MONGODB_URI +echo $MONGODB_URI_SINGLE +echo $MONGODB_URI_MULTIPLE if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then if echo "$MONGODB_URI" | grep -q "@"; then diff --git a/pymongo/auth.py b/pymongo/auth.py index 0acdb54c85..e0ed23ca16 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -164,7 +164,7 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): raise ValueError("authentication source must be $external or None for MONGODB-OIDC") properties = extra.get("authmechanismproperties", {}) on_oidc_request_token = properties.get("on_oidc_request_token") - on_oidc_refresh_token = properties.get("on_oidc_refresh_token", on_oidc_request_token) + on_oidc_refresh_token = properties.get("on_oidc_refresh_token", None) principal_name = properties.get("PRINCIPAL_NAME", "") device_name = properties.get("DEVICE_NAME", "") if not on_oidc_request_token and device_name != "aws": diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index aedc4aff86..6421a53331 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -17,10 +17,12 @@ import os import sys import unittest +from typing import Dict sys.path[0:0] = [""] from pymongo import MongoClient +from pymongo.auth import _oidc_cache class TestAuthOIDC(unittest.TestCase): @@ -28,43 +30,110 @@ class TestAuthOIDC(unittest.TestCase): @classmethod def setUpClass(cls): - cls.uri = os.environ["MONGODB_URI"] + cls.uri_single = os.environ["MONGODB_URI_SINGLE"] + cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"] + cls.token_dir = os.environ["AWS_TOKEN_DIR"] + + def test_connect_aws_device_workflow(self): + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") + props = dict(DEVICE_NAME="aws") + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() - def test_connect_environment_var(self): - aws_token_dir = os.environ["AWS_TOKEN_DIR"] + def test_connect_authorization_code_workflow(self): + token_file = os.path.join(self.token_dir, "test_user1") - def get_auth_token(info, timeout): - with open(os.path.join(aws_token_dir, "test_user1_expires")) as fid: + def request_token(info, timeout): + with open(token_file) as fid: token = fid.read() return dict(access_token=token) - def refresh_auth_token(server_info, auth_info, timeout): - with open(os.path.join(aws_token_dir, "test_user1")) as fid: + props: Dict = dict(on_oidc_request_token=request_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + _oidc_cache.clear() + props["PRINCIPAL_NAME"] = "test_user1" + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + _oidc_cache.clear() + props["PRINCIPAL_NAME"] = "test_user2" + token_file = os.path.join(self.token_dir, "test_user2") + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_bad_callbacks(self): + _oidc_cache.clear() + + def request_token_null(info, timeout): + return None + + props: Dict = dict(on_oidc_request_token=request_token_null) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def request_token_no_token(info, timeout): + return dict() + + _oidc_cache.clear() + props: Dict = dict(on_oidc_request_token=request_token_no_token) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def request_refresh_null(info, creds, timeout): + return None + + token_file = os.path.join(self.token_dir, "test_user1") + + def request_token(info, timeout): + with open(token_file) as fid: token = fid.read() return dict(access_token=token) - props = dict( - on_oidc_request_token=get_auth_token, - on_oidc_refresh_token=refresh_auth_token, - principal_name="test_user1", + _oidc_cache.clear() + props: Dict = dict( + on_oidc_request_token=request_token, on_oidc_refresh_token=request_refresh_null ) - client = MongoClient(self.uri, authmechanismproperties=props) + client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() + client.close() - import time + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() - time.sleep(60) + def request_refresh_no_token(info, creds, timeout): + return dict() + + _oidc_cache.clear() + props["on_oidc_refresh_token"] = request_refresh_no_token + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() - orders = client.test.orders - inventory = client.test.inventory - with client.start_session() as session: - with session.start_transaction(): - orders.insert_one({"sku": "abc123", "qty": 100}, session=session) - inventory.update_one( - {"sku": "abc123", "qty": {"$gte": 100}}, - {"$inc": {"qty": -100}}, - session=session, - ) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() if __name__ == "__main__": From 818911da6d9de749cac1cbb372aa0481eaeac48b Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 2 Feb 2023 18:05:45 -0600 Subject: [PATCH 039/133] add unified tests for reconnect --- test/auth/{ => legacy}/connection-string.json | 0 test/auth/unified/reauthenticate.json | 436 ++++++++++++++++++ test/test_auth_spec.py | 10 +- test/unified_format.py | 3 +- 4 files changed, 447 insertions(+), 2 deletions(-) rename test/auth/{ => legacy}/connection-string.json (100%) create mode 100644 test/auth/unified/reauthenticate.json diff --git a/test/auth/connection-string.json b/test/auth/legacy/connection-string.json similarity index 100% rename from test/auth/connection-string.json rename to test/auth/legacy/connection-string.json diff --git a/test/auth/unified/reauthenticate.json b/test/auth/unified/reauthenticate.json new file mode 100644 index 0000000000..8402a77081 --- /dev/null +++ b/test/auth/unified/reauthenticate.json @@ -0,0 +1,436 @@ +{ + "description": "reauthenticate", + "schemaVersion": "1.13", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ] + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": { } + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ] + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": { } + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ] + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [{ + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [{ + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "createEntities", + "object": "testRunner", + "arguments": { + "entities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ] + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [{ + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [{ + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index ac24fda76f..613add1b1a 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -22,6 +22,7 @@ sys.path[0:0] = [""] from test import unittest +from test.unified_format import generate_test_classes from pymongo import MongoClient @@ -101,7 +102,7 @@ def run_test(self): def create_tests(): - for filename in glob.glob(os.path.join(_TEST_PATH, "*.json")): + for filename in glob.glob(os.path.join(_TEST_PATH, "legacy", "*.json")): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with open(filename) as auth_tests: test_cases = json.load(auth_tests)["tests"] @@ -116,5 +117,12 @@ def create_tests(): create_tests() +globals().update( + generate_test_classes( + os.path.join(_TEST_PATH, "unified"), + module=__name__, + ) +) + if __name__ == "__main__": unittest.main() diff --git a/test/unified_format.py b/test/unified_format.py index 5afc746859..09f890f290 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -862,7 +862,8 @@ class UnifiedSpecTestMixinV1(IntegrationTest): a class attribute ``TEST_SPEC``. """ - SCHEMA_VERSION = Version.from_string("1.12") + # TODO: check for other changes in 1.13 + SCHEMA_VERSION = Version.from_string("1.13") RUN_ON_LOAD_BALANCER = True RUN_ON_SERVERLESS = True TEST_SPEC: Any From 3ccd8501cf890aafa9536f798767822e99546fd9 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 7 Feb 2023 09:24:43 -0600 Subject: [PATCH 040/133] update tests --- test/auth/unified/reauthenticate.json | 436 ------------------ .../unified/reauthenticate_with_retry.json | 201 ++++++++ .../unified/reauthenticate_without_retry.json | 201 ++++++++ test/unified_format.py | 3 +- 4 files changed, 403 insertions(+), 438 deletions(-) delete mode 100644 test/auth/unified/reauthenticate.json create mode 100644 test/auth/unified/reauthenticate_with_retry.json create mode 100644 test/auth/unified/reauthenticate_without_retry.json diff --git a/test/auth/unified/reauthenticate.json b/test/auth/unified/reauthenticate.json deleted file mode 100644 index 8402a77081..0000000000 --- a/test/auth/unified/reauthenticate.json +++ /dev/null @@ -1,436 +0,0 @@ -{ - "description": "reauthenticate", - "schemaVersion": "1.13", - "runOnRequirements": [ - { - "minServerVersion": "6.3", - "auth": true - } - ], - "tests": [ - { - "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", - "operations": [ - { - "name": "createEntities", - "object": "testRunner", - "arguments": { - "entities": [ - { - "client": { - "id": "client0", - "uriOptions": { - "retryReads": false - }, - "observeEvents": [ - "commandStartedEvent", - "commandSucceededEvent", - "commandFailedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "db" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "collName" - } - } - ] - } - }, - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "find" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "find", - "arguments": { - "filter": { } - }, - "object": "collection0", - "expectResult": [] - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": { - } - } - } - }, - { - "commandFailedEvent": { - "commandName": "find" - } - }, - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": { - } - } - } - }, - { - "commandSucceededEvent": { - "commandName": "find" - } - } - ] - } - ] - }, - { - "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", - "operations": [ - { - "name": "createEntities", - "object": "testRunner", - "arguments": { - "entities": [ - { - "client": { - "id": "client0", - "uriOptions": { - "retryReads": true - }, - "observeEvents": [ - "commandStartedEvent", - "commandSucceededEvent", - "commandFailedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "db" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "collName" - } - } - ] - } - }, - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "find" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "find", - "arguments": { - "filter": { } - }, - "object": "collection0", - "expectResult": [] - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": { - } - } - } - }, - { - "commandFailedEvent": { - "commandName": "find" - } - }, - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": { - } - } - } - }, - { - "commandSucceededEvent": { - "commandName": "find" - } - } - ] - } - ] - }, - { - "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", - "operations": [ - { - "name": "createEntities", - "object": "testRunner", - "arguments": { - "entities": [ - { - "client": { - "id": "client0", - "uriOptions": { - "retryWrites": false - }, - "observeEvents": [ - "commandStartedEvent", - "commandSucceededEvent", - "commandFailedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "db" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "collName" - } - } - ] - } - }, - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "insert" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "insertOne", - "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "x": 1 - } - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [{ - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandFailedEvent": { - "commandName": "insert" - } - }, - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [{ - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandSucceededEvent": { - "commandName": "insert" - } - } - ] - } - ] - }, - { - "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", - "operations": [ - { - "name": "createEntities", - "object": "testRunner", - "arguments": { - "entities": [ - { - "client": { - "id": "client0", - "uriOptions": { - "retryWrites": true - }, - "observeEvents": [ - "commandStartedEvent", - "commandSucceededEvent", - "commandFailedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "db" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "collName" - } - } - ] - } - }, - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "insert" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "insertOne", - "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "x": 1 - } - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [{ - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandFailedEvent": { - "commandName": "insert" - } - }, - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [{ - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandSucceededEvent": { - "commandName": "insert" - } - } - ] - } - ] - } - ] -} \ No newline at end of file diff --git a/test/auth/unified/reauthenticate_with_retry.json b/test/auth/unified/reauthenticate_with_retry.json new file mode 100644 index 0000000000..036e56a432 --- /dev/null +++ b/test/auth/unified/reauthenticate_with_retry.json @@ -0,0 +1,201 @@ +{ + "description":"reauthenticate_with_retry", + "schemaVersion":"1.12", + "runOnRequirements":[ + { + "minServerVersion":"6.3", + "auth":true + } + ], + "createEntities":[ + { + "client":{ + "id":"client0", + "uriOptions":{ + "retryReads":true, + "retryWrites":true + }, + "observeEvents":[ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database":{ + "id":"database0", + "client":"client0", + "databaseName":"db" + } + }, + { + "collection":{ + "id":"collection0", + "database":"database0", + "collectionName":"collName" + } + } + ], + "initialData":[ + { + "collectionName":"collName", + "databaseName":"db", + "documents":[ + + ] + } + ], + "tests":[ + { + "description":"Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations":[ + { + "name":"failPoint", + "object":"testRunner", + "arguments":{ + "client":"client0", + "failPoint":{ + "configureFailPoint":"failCommand", + "mode":{ + "times":1 + }, + "data":{ + "failCommands":[ + "find" + ], + "errorCode":391 + } + } + } + }, + { + "name":"find", + "arguments":{ + "filter":{ + + } + }, + "object":"collection0", + "expectResult":[ + + ] + } + ], + "expectEvents":[ + { + "client":"client0", + "events":[ + { + "commandStartedEvent":{ + "command":{ + "find":"collName", + "filter":{ + + } + } + } + }, + { + "commandFailedEvent":{ + "commandName":"find" + } + }, + { + "commandStartedEvent":{ + "command":{ + "find":"collName", + "filter":{ + + } + } + } + }, + { + "commandSucceededEvent":{ + "commandName":"find" + } + } + ] + } + ] + }, + { + "description":"Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations":[ + { + "name":"failPoint", + "object":"testRunner", + "arguments":{ + "client":"client0", + "failPoint":{ + "configureFailPoint":"failCommand", + "mode":{ + "times":1 + }, + "data":{ + "failCommands":[ + "insert" + ], + "errorCode":391 + } + } + } + }, + { + "name":"insertOne", + "object":"collection0", + "arguments":{ + "document":{ + "_id":1, + "x":1 + } + } + } + ], + "expectEvents":[ + { + "client":"client0", + "events":[ + { + "commandStartedEvent":{ + "command":{ + "insert":"collName", + "documents":[ + { + "_id":1, + "x":1 + } + ] + } + } + }, + { + "commandFailedEvent":{ + "commandName":"insert" + } + }, + { + "commandStartedEvent":{ + "command":{ + "insert":"collName", + "documents":[ + { + "_id":1, + "x":1 + } + ] + } + } + }, + { + "commandSucceededEvent":{ + "commandName":"insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/test/auth/unified/reauthenticate_without_retry.json b/test/auth/unified/reauthenticate_without_retry.json new file mode 100644 index 0000000000..0c0b47d77a --- /dev/null +++ b/test/auth/unified/reauthenticate_without_retry.json @@ -0,0 +1,201 @@ +{ + "description":"reauthenticate_without_retry", + "schemaVersion":"1.12", + "runOnRequirements":[ + { + "minServerVersion":"6.3", + "auth":true + } + ], + "createEntities":[ + { + "client":{ + "id":"client0", + "uriOptions":{ + "retryReads":false, + "retryWrites":false + }, + "observeEvents":[ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database":{ + "id":"database0", + "client":"client0", + "databaseName":"db" + } + }, + { + "collection":{ + "id":"collection0", + "database":"database0", + "collectionName":"collName" + } + } + ], + "initialData":[ + { + "collectionName":"collName", + "databaseName":"db", + "documents":[ + + ] + } + ], + "tests":[ + { + "description":"Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations":[ + { + "name":"failPoint", + "object":"testRunner", + "arguments":{ + "client":"client0", + "failPoint":{ + "configureFailPoint":"failCommand", + "mode":{ + "times":1 + }, + "data":{ + "failCommands":[ + "find" + ], + "errorCode":391 + } + } + } + }, + { + "name":"find", + "arguments":{ + "filter":{ + + } + }, + "object":"collection0", + "expectResult":[ + + ] + } + ], + "expectEvents":[ + { + "client":"client0", + "events":[ + { + "commandStartedEvent":{ + "command":{ + "find":"collName", + "filter":{ + + } + } + } + }, + { + "commandFailedEvent":{ + "commandName":"find" + } + }, + { + "commandStartedEvent":{ + "command":{ + "find":"collName", + "filter":{ + + } + } + } + }, + { + "commandSucceededEvent":{ + "commandName":"find" + } + } + ] + } + ] + }, + { + "description":"Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations":[ + { + "name":"failPoint", + "object":"testRunner", + "arguments":{ + "client":"client0", + "failPoint":{ + "configureFailPoint":"failCommand", + "mode":{ + "times":1 + }, + "data":{ + "failCommands":[ + "insert" + ], + "errorCode":391 + } + } + } + }, + { + "name":"insertOne", + "object":"collection0", + "arguments":{ + "document":{ + "_id":1, + "x":1 + } + } + } + ], + "expectEvents":[ + { + "client":"client0", + "events":[ + { + "commandStartedEvent":{ + "command":{ + "insert":"collName", + "documents":[ + { + "_id":1, + "x":1 + } + ] + } + } + }, + { + "commandFailedEvent":{ + "commandName":"insert" + } + }, + { + "commandStartedEvent":{ + "command":{ + "insert":"collName", + "documents":[ + { + "_id":1, + "x":1 + } + ] + } + } + }, + { + "commandSucceededEvent":{ + "commandName":"insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/test/unified_format.py b/test/unified_format.py index 09f890f290..5afc746859 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -862,8 +862,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest): a class attribute ``TEST_SPEC``. """ - # TODO: check for other changes in 1.13 - SCHEMA_VERSION = Version.from_string("1.13") + SCHEMA_VERSION = Version.from_string("1.12") RUN_ON_LOAD_BALANCER = True RUN_ON_SERVERLESS = True TEST_SPEC: Any From a59ba5371daa83ba3cf8631238c682b03f12b4f0 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 7 Feb 2023 19:43:13 -0600 Subject: [PATCH 041/133] finish prose tests --- pymongo/auth.py | 8 +-- test/auth_aws/test_auth_oidc.py | 103 +++++++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 7 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index e0ed23ca16..b384bb54e0 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -520,6 +520,7 @@ def _authenticate_mongo_cr(credentials, sock_info): _oidc_cache: Dict[str, _OIDCCache] = {} _OIDC_TOKEN_BUFFER_MINUTES = 5 +_OIDC_CALLBACK_TIMEOUT_SECONDS = 5 * 60 _OIDC_CACHE_TIMEOUT_MINUTES = 60 * 5 @@ -600,12 +601,7 @@ def _authenticate_oidc(credentials, sock_info): if (exp_utc - now_utc).total_seconds() >= buffer_seconds: current_valid_token = True - # CSOT: apply timeout to callback. - timeout = _csot.remaining() - if timeout is None: - timeout = sock_info.opts.connect_timeout - elif timeout <= 0: - raise TimeoutError("timed out") + timeout = _OIDC_CALLBACK_TIMEOUT_SECONDS if not current_valid_token: with cache_value.lock: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 6421a53331..411b78728c 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from pymongo import MongoClient -from pymongo.auth import _oidc_cache +from pymongo.auth import OperationFailure, _oidc_cache class TestAuthOIDC(unittest.TestCase): @@ -135,6 +135,107 @@ def request_refresh_no_token(info, creds, timeout): client.test.test.find_one() client.close() + def test_caching(self): + request_called = 0 + refresh_called = 0 + + # . Clear the cache. + _oidc_cache.clear() + # . Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. + # . Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. + token_file = os.path.join(self.token_dir, "test_user1") + + def request_token(info, timeout): + nonlocal request_called + assert "authorizationEndpoint" in info + assert "tokenEndpoint" in info + assert "clientId" in info + assert timeout == 60 * 5 + with open(token_file) as fid: + token = fid.read() + request_called += 1 + return dict(access_token=token, expires_in_seconds=60) + + def refresh_token(info, creds, timeout): + nonlocal refresh_called + assert "authorizationEndpoint" in info + assert "tokenEndpoint" in info + assert "clientId" in info + assert timeout == 60 * 5 + assert "access_token" in creds + refresh_called += 1 + with open(token_file) as fid: + token = fid.read() + return dict(access_token=token, expires_in_seconds=60) + + _oidc_cache.clear() + props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) + + # . Ensure that a ``find`` operation adds credentials to the cache. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + assert len(_oidc_cache) == 1 + + # . Create a new client with the same request callback and a refresh callback. + # . Ensure that a ``find`` operation results in a call to the refresh callback. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + assert refresh_called == 1 + assert len(_oidc_cache) == 1 + + # . Clear the cache. + _oidc_cache.clear() + + # . Create a new client with a request callback callback. + # . Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. + del props["on_oidc_refresh_token"] + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # . Ensure that a ``find`` operation adds credentials to the cache. + request_called = 0 + client.test.test.find_one() + client.close() + assert request_called == 1 + assert len(_oidc_cache) == 1 + + # . Create a new client with the same request callback. + # . Ensure that a ``find`` operation results in a call to the request callback. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + assert request_called == 2 + assert len(_oidc_cache) == 1 + + # . Create a new client with a refresh callback that gives invalid credentials. + def bad_refresh(info, creds, timeout): + return dict(access_token="bad") + + props["on_oidc_refresh_token"] = bad_refresh + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # . Ensure that a ``find`` operation results in an error. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + client.close() + + # . Ensure that the cache has been cleared. + assert len(_oidc_cache) == 0 + + # . Clear the cache. + # . Create a new client using the AWS device workflow. + # . Ensure that a ``find`` operation does not add credentials to the cache. + _oidc_cache.clear() + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") + props = dict(DEVICE_NAME="aws") + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + assert len(_oidc_cache) == 0 + if __name__ == "__main__": unittest.main() From 5f59c31c9df333f6cf3bf881d8c6ec246e34017e Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 13:20:01 -0600 Subject: [PATCH 042/133] update docs and prose test on reauthentication --- pymongo/auth.py | 11 +++-- pymongo/mongo_client.py | 4 +- pymongo/pool.py | 6 +-- test/auth_aws/test_auth_oidc.py | 79 +++++++++++++++++++++++++-------- 4 files changed, 73 insertions(+), 27 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index b384bb54e0..7faa1ee89d 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -524,7 +524,7 @@ def _authenticate_mongo_cr(credentials, sock_info): _OIDC_CACHE_TIMEOUT_MINUTES = 60 * 5 -def _authenticate_oidc(credentials, sock_info): +def _authenticate_oidc(credentials, sock_info, reauthenticate): """Authenticate using MONGODB-OIDC.""" properties: _OIDCProperties = credentials.mechanism_properties @@ -594,7 +594,7 @@ def _authenticate_oidc(credentials, sock_info): conversation_id = response["conversationId"] current_valid_token = False - if cache_value.token_exp_utc is not None: + if cache_value.token_exp_utc is not None and not reauthenticate: now_utc = datetime.now(timezone.utc) exp_utc = cache_value.token_exp_utc buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 @@ -748,8 +748,11 @@ def speculate_command(self): } -def authenticate(credentials, sock_info): +def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - auth_func(credentials, sock_info) + if mechanism == "MONGODB-OIDC": + auth_func(credentials, sock_info, reauthenticate) # type:ignore + else: + auth_func(credentials, sock_info) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 2fdf4adaa0..46876e863e 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1399,7 +1399,7 @@ def is_retrying(): return func(session, sock_info, retryable) except OperationFailure as exc: if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(force=True) + sock_info.authenticate(reauthenticate=True) return func(session, sock_info, retryable) raise except ServerSelectionTimeoutError: @@ -1470,7 +1470,7 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True return func(session, server, sock_info, read_pref) except OperationFailure as exc: if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(force=True) + sock_info.authenticate(reauthenticate=True) return func(session, server, sock_info, read_pref) raise except ServerSelectionTimeoutError: diff --git a/pymongo/pool.py b/pymongo/pool.py index e319da84cb..80b440ae6d 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -857,17 +857,17 @@ def write_command(self, request_id, msg, codec_options): helpers._check_command_response(result, self.max_wire_version) return result - def authenticate(self, force=False): + def authenticate(self, reauthenticate=False): """Authenticate to the server if needed. Can raise ConnectionFailure or OperationFailure. """ # CMAP spec says to publish the ready event only after authenticating # the connection. - if not self.ready or force: + if not self.ready or reauthenticate: creds = self.opts._credentials if creds: - auth.authenticate(creds, self) + auth.authenticate(creds, self, reauthenticate=reauthenticate) self.ready = True if self.enabled_for_cmap: self.listeners.publish_connection_ready(self.address, self.id) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 411b78728c..eb5fa1eb3b 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -23,6 +23,7 @@ from pymongo import MongoClient from pymongo.auth import OperationFailure, _oidc_cache +from pymongo.server_selectors import readable_server_selector class TestAuthOIDC(unittest.TestCase): @@ -139,10 +140,10 @@ def test_caching(self): request_called = 0 refresh_called = 0 - # . Clear the cache. + # Clear the cache. _oidc_cache.clear() - # . Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. - # . Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. + # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. + # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. token_file = os.path.join(self.token_dir, "test_user1") def request_token(info, timeout): @@ -171,15 +172,15 @@ def refresh_token(info, creds, timeout): _oidc_cache.clear() props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) - # . Ensure that a ``find`` operation adds credentials to the cache. + # Ensure that a ``find`` operation adds credentials to the cache. client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() assert len(_oidc_cache) == 1 - # . Create a new client with the same request callback and a refresh callback. - # . Ensure that a ``find`` operation results in a call to the refresh callback. + # Create a new client with the same request callback and a refresh callback. + # Ensure that a ``find`` operation results in a call to the refresh callback. client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -187,47 +188,47 @@ def refresh_token(info, creds, timeout): assert refresh_called == 1 assert len(_oidc_cache) == 1 - # . Clear the cache. + # Clear the cache. _oidc_cache.clear() - # . Create a new client with a request callback callback. - # . Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. + # Create a new client with a request callback callback. + # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. del props["on_oidc_refresh_token"] client = MongoClient(self.uri_single, authMechanismProperties=props) - # . Ensure that a ``find`` operation adds credentials to the cache. + # Ensure that a ``find`` operation adds credentials to the cache. request_called = 0 client.test.test.find_one() client.close() assert request_called == 1 assert len(_oidc_cache) == 1 - # . Create a new client with the same request callback. - # . Ensure that a ``find`` operation results in a call to the request callback. + # Create a new client with the same request callback. + # Ensure that a ``find`` operation results in a call to the request callback. client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() assert request_called == 2 assert len(_oidc_cache) == 1 - # . Create a new client with a refresh callback that gives invalid credentials. + # Create a new client with a refresh callback that gives invalid credentials. def bad_refresh(info, creds, timeout): return dict(access_token="bad") props["on_oidc_refresh_token"] = bad_refresh client = MongoClient(self.uri_single, authMechanismProperties=props) - # . Ensure that a ``find`` operation results in an error. + # Ensure that a ``find`` operation results in an error. with self.assertRaises(OperationFailure): client.test.test.find_one() client.close() - # . Ensure that the cache has been cleared. + # Ensure that the cache has been cleared. assert len(_oidc_cache) == 0 - # . Clear the cache. - # . Create a new client using the AWS device workflow. - # . Ensure that a ``find`` operation does not add credentials to the cache. + # Clear the cache. + # Create a new client using the AWS device workflow. + # Ensure that a ``find`` operation does not add credentials to the cache. _oidc_cache.clear() os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") props = dict(DEVICE_NAME="aws") @@ -236,6 +237,48 @@ def bad_refresh(info, creds, timeout): client.close() assert len(_oidc_cache) == 0 + def test_reauthenticate(self): + token_file = os.path.join(self.token_dir, "test_user1") + refresh_called = 0 + + # Clear the cache + _oidc_cache.clear() + + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + def request_token(info, timeout): + with open(token_file) as fid: + token = fid.read() + return dict(access_token=token, expires_in_seconds=1000) + + def refresh_token(info, creds, timeout): + nonlocal refresh_called + with open(token_file) as fid: + token = fid.read() + refresh_called += 1 + return dict(access_token=token, expires_in_seconds=1000) + + # Create a client with the callbacks. + props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + # Perform another find operation. + client.test.test.find_one() + + # Assert that the refresh callback has not been called. + self.assertEqual(refresh_called, 0) + + # Force a reauthenication + server = client._get_topology().select_server(readable_server_selector) + with server._pool.get_socket() as sock_info: # type:ignore + sock_info.authenticate(True) + + # Assert that the refresh callback has been called. + self.assertEqual(refresh_called, 1) + client.close() + if __name__ == "__main__": unittest.main() From 2cc0a62fb0667c0927fe7d8d539ecdc7e9728fe1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 13:47:12 -0600 Subject: [PATCH 043/133] debug --- pymongo/auth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymongo/auth.py b/pymongo/auth.py index 7faa1ee89d..2e999f4972 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -238,6 +238,7 @@ def _authenticate_scram(credentials, sock_info, mechanism): if ctx and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate + raise ValueError("this is the one") else: nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) res = sock_info.command(source, cmd) From 7479a2432628a9e120b2236a7e591ea61c6937a5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 14:12:57 -0600 Subject: [PATCH 044/133] debug --- pymongo/auth.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 2e999f4972..c0367d0bbc 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -235,10 +235,11 @@ def _authenticate_scram(credentials, sock_info, mechanism): _hmac = hmac.HMAC ctx = sock_info.auth_ctx + speculated = False if ctx and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate - raise ValueError("this is the one") + speculated = True else: nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) res = sock_info.command(source, cmd) @@ -281,7 +282,10 @@ def _authenticate_scram(credentials, sock_info, mechanism): ("payload", Binary(client_final)), ] ) - res = sock_info.command(source, cmd) + try: + res = sock_info.command(source, cmd) + except Exception as e: + raise ValueError(f"speculated? {speculated}") parsed = _parse_scram_response(res["payload"]) if not hmac.compare_digest(parsed[b"v"], server_sig): From caabfb6be26540ef06c6484e97ebedf030e6f916 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 14:26:57 -0600 Subject: [PATCH 045/133] try clearing auth_ctx --- pymongo/auth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymongo/auth.py b/pymongo/auth.py index c0367d0bbc..b7d94cd956 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -757,6 +757,8 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] + if reauthenticate: + sock_info.auth_ctx = None if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) # type:ignore else: From b236f2a55f745c91667cad873b24008566d07855 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 14:33:49 -0600 Subject: [PATCH 046/133] try another way to re-auth --- pymongo/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index b7d94cd956..5c5f2fc3d3 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -757,8 +757,8 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate: - sock_info.auth_ctx = None + if reauthenticate and sock_info.pool_ref().handshake: + sock_info.hello() if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) # type:ignore else: From eab4d833bb344f621f7d7fe7b9f5a632303b140b Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 14:53:24 -0600 Subject: [PATCH 047/133] try another way to re-auth --- pymongo/auth.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 5c5f2fc3d3..e1e37ed870 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -282,10 +282,7 @@ def _authenticate_scram(credentials, sock_info, mechanism): ("payload", Binary(client_final)), ] ) - try: - res = sock_info.command(source, cmd) - except Exception as e: - raise ValueError(f"speculated? {speculated}") + res = sock_info.command(source, cmd) parsed = _parse_scram_response(res["payload"]) if not hmac.compare_digest(parsed[b"v"], server_sig): From 933be5f888c784e533d9835f80b89c3a4f98198c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 15:03:11 -0600 Subject: [PATCH 048/133] cleanup --- pymongo/auth.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index e1e37ed870..7ac034aa93 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -235,11 +235,9 @@ def _authenticate_scram(credentials, sock_info, mechanism): _hmac = hmac.HMAC ctx = sock_info.auth_ctx - speculated = False if ctx and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate - speculated = True else: nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) res = sock_info.command(source, cmd) @@ -754,8 +752,8 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate and sock_info.pool_ref().handshake: - sock_info.hello() + if reauthenticate: + sock_info.performed_handshake = False if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) # type:ignore else: From 20aecb7da4c9b0f727cb7f0be90031d83d4dcb7c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 15:20:03 -0600 Subject: [PATCH 049/133] try again --- pymongo/auth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 7ac034aa93..72e603dfeb 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -752,8 +752,9 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate: + if reauthenticate and sock_info.performed_handshake: sock_info.performed_handshake = False + sock_info.hello() if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) # type:ignore else: From 7ed72fe8d5830d08f6545263a635553e6d84b117 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Feb 2023 15:29:23 -0600 Subject: [PATCH 050/133] try again --- pymongo/auth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 72e603dfeb..e29eb432dc 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -753,7 +753,6 @@ def authenticate(credentials, sock_info, reauthenticate=False): mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] if reauthenticate and sock_info.performed_handshake: - sock_info.performed_handshake = False sock_info.hello() if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) # type:ignore From c4c64bc978e8a70f18aaa54993de4080287312b1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 9 Feb 2023 13:55:18 -0600 Subject: [PATCH 051/133] use username as principal name --- pymongo/auth.py | 6 ++---- pymongo/common.py | 1 - test/auth/legacy/connection-string.json | 20 ++++++-------------- test/auth_aws/test_auth_oidc.py | 10 ++++++---- 4 files changed, 14 insertions(+), 23 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index e29eb432dc..60d3e69866 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -107,7 +107,7 @@ def __hash__(self): _OIDCProperties = namedtuple( "_OIDCProperties", - ["on_oidc_request_token", "on_oidc_refresh_token", "principal_name", "device_name"], + ["on_oidc_request_token", "on_oidc_refresh_token", "device_name"], ) """Mechanism properties for MONGODB-OIDC authentication.""" @@ -165,7 +165,6 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): properties = extra.get("authmechanismproperties", {}) on_oidc_request_token = properties.get("on_oidc_request_token") on_oidc_refresh_token = properties.get("on_oidc_refresh_token", None) - principal_name = properties.get("PRINCIPAL_NAME", "") device_name = properties.get("DEVICE_NAME", "") if not on_oidc_request_token and device_name != "aws": raise ConfigurationError( @@ -174,7 +173,6 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): oidc_props = _OIDCProperties( on_oidc_request_token=on_oidc_request_token, on_oidc_refresh_token=on_oidc_refresh_token, - principal_name=principal_name, device_name=device_name, ) return MongoCredential(mech, "$external", user, passwd, oidc_props, None) @@ -557,7 +555,7 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): # Handle authorization code credentials. address = sock_info.address - principal_name = properties.principal_name + principal_name = credentials.username cache_key = f"{principal_name}{address[0]}{address[1]}" cache_value = _oidc_cache.get(cache_key) diff --git a/pymongo/common.py b/pymongo/common.py index 06fe2fcb6b..149966e558 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -422,7 +422,6 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN", - "PRINCIPAL_NAME", "DEVICE_NAME", ] ) diff --git a/test/auth/legacy/connection-string.json b/test/auth/legacy/connection-string.json index 38cee87771..167d7dd290 100644 --- a/test/auth/legacy/connection-string.json +++ b/test/auth/legacy/connection-string.json @@ -492,18 +492,17 @@ } }, { - "description": "should recognise the mechanism and principalName with request callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PRINCIPAL_NAME:principalName", + "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", "callback": ["oidcRequest"], "valid": true, "credential": { - "username": null, + "username": "principalName", "password": null, "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true, - "PRINCIPAL_NAME": "principalName" + "REQUEST_TOKEN_CALLBACK": true } } }, @@ -535,13 +534,6 @@ } } }, - { - "description": "should throw an exception if username is specified (MONGODB-OIDC)", - "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], - "valid": false, - "credential": null - }, { "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", @@ -550,8 +542,8 @@ "credential": null }, { - "description": "should throw an exception if principalName and deviceName are specified (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PRINCIPAL_NAME:principalName,DEVICE_NAME:gcp", + "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&DEVICE_NAME:gcp", "valid": false, "credential": null }, diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index eb5fa1eb3b..1ccc9ff731 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -65,15 +65,17 @@ def request_token(info, timeout): client.close() _oidc_cache.clear() - props["PRINCIPAL_NAME"] = "test_user1" - client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client = MongoClient( + self.uri_multiple, username="test_user1", authmechanismproperties=props + ) client.test.test.find_one() client.close() _oidc_cache.clear() - props["PRINCIPAL_NAME"] = "test_user2" token_file = os.path.join(self.token_dir, "test_user2") - client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client = MongoClient( + self.uri_multiple, username="test_user2", authmechanismproperties=props + ) client.test.test.find_one() client.close() From 3acf277f16846495afa0b7a61f7f89349e90c4bd Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 9 Feb 2023 16:24:30 -0600 Subject: [PATCH 052/133] add read and write reauth tests using failcommand --- .evergreen/run-mongodb-oidc-test.sh | 2 + test/auth_aws/test_auth_oidc.py | 67 ++++++++++++++++++++++++++--- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 453d229857..8a98552223 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -32,9 +32,11 @@ fi export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE" export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE" +export MONGODB_URI="$MONGODB_URI" echo $MONGODB_URI_SINGLE echo $MONGODB_URI_MULTIPLE +echo $MONGODB_URI if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then if echo "$MONGODB_URI" | grep -q "@"; then diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 1ccc9ff731..6df1fa7e97 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -17,13 +17,14 @@ import os import sys import unittest +from contextlib import contextmanager from typing import Dict sys.path[0:0] = [""] +from bson import SON from pymongo import MongoClient from pymongo.auth import OperationFailure, _oidc_cache -from pymongo.server_selectors import readable_server_selector class TestAuthOIDC(unittest.TestCase): @@ -33,8 +34,20 @@ class TestAuthOIDC(unittest.TestCase): def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"] + cls.uri_admin = os.environ["MONGODB_URI"] cls.token_dir = os.environ["AWS_TOKEN_DIR"] + @contextmanager + def fail_point(self, command_args): + cmd_on = SON([("configureFailPoint", "failCommand")]) + cmd_on.update(command_args) + client = MongoClient(self.uri_admin) + client.admin.command(cmd_on) + try: + yield + finally: + client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + def test_connect_aws_device_workflow(self): os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") props = dict(DEVICE_NAME="aws") @@ -239,7 +252,7 @@ def bad_refresh(info, creds, timeout): client.close() assert len(_oidc_cache) == 0 - def test_reauthenticate(self): + def test_reauthenticate_read(self): token_file = os.path.join(self.token_dir, "test_user1") refresh_called = 0 @@ -266,16 +279,56 @@ def refresh_token(info, creds, timeout): # Perform a find operation. client.test.test.find_one() - # Perform another find operation. + + # Assert that the refresh callback has not been called. + self.assertEqual(refresh_called, 0) + + with self.fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}} + ): + # Perform a find operation. + client.test.test.find_one() + + # Assert that the refresh callback has been called. + self.assertEqual(refresh_called, 1) + client.close() + + def test_reauthenticate_write(self): + token_file = os.path.join(self.token_dir, "test_user1") + refresh_called = 0 + + # Clear the cache + _oidc_cache.clear() + + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + def request_token(info, timeout): + with open(token_file) as fid: + token = fid.read() + return dict(access_token=token, expires_in_seconds=1000) + + def refresh_token(info, creds, timeout): + nonlocal refresh_called + with open(token_file) as fid: + token = fid.read() + refresh_called += 1 + return dict(access_token=token, expires_in_seconds=1000) + + # Create a client with the callbacks. + props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. client.test.test.find_one() # Assert that the refresh callback has not been called. self.assertEqual(refresh_called, 0) - # Force a reauthenication - server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: # type:ignore - sock_info.authenticate(True) + with self.fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": 391}} + ): + # Perform an insert operation. + client.test.test.insert_one({}) # Assert that the refresh callback has been called. self.assertEqual(refresh_called, 1) From 0da3037f8ff27ce73f5be6c8dead2b9a1f650e5f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 9 Feb 2023 17:51:39 -0600 Subject: [PATCH 053/133] lint --- pymongo/auth.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 60d3e69866..3438ef1808 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -30,7 +30,6 @@ import bson from bson.binary import Binary from bson.son import SON -from pymongo import _csot from pymongo.auth_aws import _authenticate_aws from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -753,6 +752,6 @@ def authenticate(credentials, sock_info, reauthenticate=False): if reauthenticate and sock_info.performed_handshake: sock_info.hello() if mechanism == "MONGODB-OIDC": - auth_func(credentials, sock_info, reauthenticate) # type:ignore + auth_func(credentials, sock_info, reauthenticate) else: auth_func(credentials, sock_info) From 370ef9428bb181da09e3978e035001aef088dfd2 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 11:36:43 -0600 Subject: [PATCH 054/133] add event listeners for reauth tests --- test/auth_aws/test_auth_oidc.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 6df1fa7e97..4b24e61054 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -22,6 +22,8 @@ sys.path[0:0] = [""] +from test.utils import EventListener + from bson import SON from pymongo import MongoClient from pymongo.auth import OperationFailure, _oidc_cache @@ -255,6 +257,7 @@ def bad_refresh(info, creds, timeout): def test_reauthenticate_read(self): token_file = os.path.join(self.token_dir, "test_user1") refresh_called = 0 + listener = EventListener() # Clear the cache _oidc_cache.clear() @@ -275,7 +278,9 @@ def refresh_token(info, creds, timeout): # Create a client with the callbacks. props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) # Perform a find operation. client.test.test.find_one() @@ -283,12 +288,22 @@ def refresh_token(info, creds, timeout): # Assert that the refresh callback has not been called. self.assertEqual(refresh_called, 0) + listener.reset() + with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}} ): # Perform a find operation. client.test.test.find_one() + started_events = [i.command_name for i in listener.started_events] + succeeded_events = [i.command_name for i in listener.succeeded_events] + failed_events = [i.command_name for i in listener.failed_events] + + assert started_events == ["find", "saslStart", "find"] + assert succeeded_events == ["saslStart", "find"] + assert failed_events == ["find"] + # Assert that the refresh callback has been called. self.assertEqual(refresh_called, 1) client.close() @@ -296,6 +311,7 @@ def refresh_token(info, creds, timeout): def test_reauthenticate_write(self): token_file = os.path.join(self.token_dir, "test_user1") refresh_called = 0 + listener = EventListener() # Clear the cache _oidc_cache.clear() @@ -316,7 +332,9 @@ def refresh_token(info, creds, timeout): # Create a client with the callbacks. props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) - client = MongoClient(self.uri_single, authmechanismproperties=props) + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) # Perform a find operation. client.test.test.find_one() @@ -324,12 +342,22 @@ def refresh_token(info, creds, timeout): # Assert that the refresh callback has not been called. self.assertEqual(refresh_called, 0) + listener.reset() + with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": 391}} ): # Perform an insert operation. client.test.test.insert_one({}) + started_events = [i.command_name for i in listener.started_events] + succeeded_events = [i.command_name for i in listener.succeeded_events] + failed_events = [i.command_name for i in listener.failed_events] + + assert started_events == ["insert", "saslStart", "insert"] + assert succeeded_events == ["saslStart", "insert"] + assert failed_events == ["insert"] + # Assert that the refresh callback has been called. self.assertEqual(refresh_called, 1) client.close() From 0abc433ee30d9b2ceb7a6a6c0db512418368c4b6 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 15:00:16 -0600 Subject: [PATCH 055/133] add changelog and example --- doc/changelog.rst | 1 + doc/examples/authentication.rst | 117 ++++++++++++++++++++++++++++++++ pymongo/auth.py | 19 +----- 3 files changed, 119 insertions(+), 18 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 6a6e6fef2d..465a0f2ae4 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in Version 4.4 ----------------------- +- Added support for :ref:`OIDC ` authentication with MongoDB Enterprise 7.0+. - Added support for passing a list containing (key, direction) pairs or keys to :meth:`~pymongo.collection.Collection.create_index`. - **BETA** Added support for range queries on client side field level encrypted collections. diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index a984d17fc0..f25418fba1 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -384,3 +384,120 @@ would be:: .. _Assume Role: https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html .. _EC2 instance: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2.html .. _environment variables: https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html#configuration-envvars-runtime + + +.. _oidc_sasl: + +MONGODB-OIDC +------------ +.. versionadded:: 4.4 + +The MONGODB-OIDC authentication mechanism is available in MongoDB Enterprise 7.0+. + +AWS OIDC Support +~~~~~~~~~~~~~~~~ + +PyMongo supports automatic authentication when AWS OIDC credentials are +available, by installing pymongo with the +``aws`` extra:: + + $ python -m pip install 'pymongo[aws]' + +A sample URI would be: + + >>> from pymongo import MongoClient + >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:aws") + >>> client = MongoClient(uri) + +The driver will use the authentication token from the file given by the +``AWS_WEB_IDENTITY_TOKEN_FILE`` environment variable provided by AWS to +authenticate with the server. + +Authentication Code Flow Support +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PyMongo supports using the Authentication Code Flow for OIDC, when +callbacks are given to the ``MongoClient``. When multiple identity providers +are configured on the server, a ``username`` must be provided, which is the +Principal Name used on the provider. The ``on_oidc_request_callback`` +is intended to accept information about the Identity Provider, and return +credentials that are used to authenticate with the server, usually through +a browser interaction with the user. The callback must be of the form:: + + def request_callback(ProviderInfo, timeout_seconds) -> TokenResult: + ... + return dict(access_token=...) + +Where ``ProviderInfo`` is a dictionary of the following form:: + + authorization_endpoint: + description: >- + URL where the IDP may be contacted for end user + authentication and authorization code generation. + type: string + optional: true # Req if deviceAuthorizeEndpoint not present + token_endpoint: + description: >- + URL where the IDP may be contacted for authorization + code <=> ID/access token exchange. + type: string + optional: true # Req if deviceAuthorizeEndpoint not present + device_authorization_endpoint: + description: >- + URL where the IDP may be contacted for device + authentication and authorization code generation. + type: string + optional: true # Req if authorizeEndpoint not present + client_id: + description: "Unique client ID for this OIDC client" + type: string + client_ecret: + description: "Secret used when communicating with IDP" + type: string + optional: true + request_scopes: + description: "Additional scopes to request from IDP" + type: array + optional: true + +And ``TokenResult`` is a dictionary of the following form:: + + access_token: + description: "The OIDC access token" + type: string + expires_in_seconds: + description: "The expiration time in seconds from the current time" + type: int + optional: true + refresh_token: + description: "The OIDC refresh token" + type: str + optional: true + +And ``timeout_seconds`` will always be 3000 (5 minutes). An example +client would be:: + + >>> from pymongo import MongoClient + >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") + >>> client = MongoClient(uri, on_oidc_request_callback=my_callback) + +If the identity provider supports refresh, a refresh callback can also +be provided. If a refresh callback is provided, it will be called +if ``expires_in_seconds`` was given in the request response and is +within 5 minutes, or the server raises a ``ReAuthenticationRequired`` +error during an operation. The refresh callback must of the form:: + + + def request_callback(ProviderInfo, TokenResult, timeout_seconds) -> TokenResult: + ... + return dict(access_token=...) + +Where ``ProviderInfo``, ``timeout_seconds`` and the return value are of +the same form as the request callback, and the ``TokenResult`` parameter +is the result of the request callback, which will contain the ``refresh_token`` +if it was provided. An example using both callbacks would be:: + + >>> from pymongo import MongoClient + >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") + >>> client = MongoClient(uri, on_oidc_request_callback=my_request_callback, + ... on_oidc_refresh_callback=my_refresh_callback) diff --git a/pymongo/auth.py b/pymongo/auth.py index 3438ef1808..6ecfb28698 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -497,24 +497,7 @@ def _authenticate_mongo_cr(credentials, sock_info): sock_info.command(source, query) -""" -interface OIDCRequestTokenParams { - authorizeEndpoint?: string; - tokenEndpoint?: string; - deviceAuthorizeEndpoint?: string; - clientId: string; - clientSecret?: string; - requestScopes?: string[]; -} - -interface OIDCRequestTokenResult { - accessToken: string - expiresInSeconds?: number - refreshToken?: string -} -""" - - +# MONGO-OIDC private variables. _oidc_cache: Dict[str, _OIDCCache] = {} _OIDC_TOKEN_BUFFER_MINUTES = 5 _OIDC_CALLBACK_TIMEOUT_SECONDS = 5 * 60 From 93a4884623cbb7e6fb67defad65bf2df2fca3220 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 15:02:20 -0600 Subject: [PATCH 056/133] clean up example --- doc/examples/authentication.rst | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index f25418fba1..e965e798ea 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -417,9 +417,7 @@ Authentication Code Flow Support ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PyMongo supports using the Authentication Code Flow for OIDC, when -callbacks are given to the ``MongoClient``. When multiple identity providers -are configured on the server, a ``username`` must be provided, which is the -Principal Name used on the provider. The ``on_oidc_request_callback`` +callbacks are given to the ``MongoClient``. The ``on_oidc_request_callback`` is intended to accept information about the Identity Provider, and return credentials that are used to authenticate with the server, usually through a browser interaction with the user. The callback must be of the form:: @@ -474,7 +472,7 @@ And ``TokenResult`` is a dictionary of the following form:: type: str optional: true -And ``timeout_seconds`` will always be 3000 (5 minutes). An example +And ``timeout_seconds`` will always be 300 (5 minutes). An example client would be:: >>> from pymongo import MongoClient @@ -501,3 +499,12 @@ if it was provided. An example using both callbacks would be:: >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") >>> client = MongoClient(uri, on_oidc_request_callback=my_request_callback, ... on_oidc_refresh_callback=my_refresh_callback) + +Note: when multiple identity providers +are configured on the server, a ``username`` must be provided, which is the +Principal Name used on the provider. For example:: + + >>> from pymongo import MongoClient + >>> uri = "mongodb://my_username@localhost/?authMechanism=MONGODB-OIDC") + >>> client = MongoClient(uri, on_oidc_request_callback=my_request_callback, + ... on_oidc_refresh_callback=my_refresh_callback) From b0d1e3b9614a0769378856e161551f8b2a1de977 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 15:02:49 -0600 Subject: [PATCH 057/133] clean up example --- doc/examples/authentication.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index e965e798ea..6574f7ce5c 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -433,19 +433,19 @@ Where ``ProviderInfo`` is a dictionary of the following form:: URL where the IDP may be contacted for end user authentication and authorization code generation. type: string - optional: true # Req if deviceAuthorizeEndpoint not present + optional: true # Req if device_authorization_endpoint not present token_endpoint: description: >- URL where the IDP may be contacted for authorization code <=> ID/access token exchange. type: string - optional: true # Req if deviceAuthorizeEndpoint not present + optional: true # Req if device_authorization_endpoint not present device_authorization_endpoint: description: >- URL where the IDP may be contacted for device authentication and authorization code generation. type: string - optional: true # Req if authorizeEndpoint not present + optional: true # Req if authorization_endpoint not present client_id: description: "Unique client ID for this OIDC client" type: string From 256915ab35a8b0bfc1403b8c259fac0a6e98a0e1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 17:03:09 -0600 Subject: [PATCH 058/133] clean up reauth handling --- pymongo/auth.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 6ecfb28698..a9549df5a9 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -732,9 +732,10 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate and sock_info.performed_handshake: - sock_info.hello() if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) else: + if reauthenticate and sock_info.performed_handshake: + # Existing hello response is stale, call it again. + sock_info.hello() auth_func(credentials, sock_info) From 51c663935eeb014a53037ea9d6cbbb903693fda4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 17:08:34 -0600 Subject: [PATCH 059/133] clean up server response --- pymongo/auth.py | 12 +++++++++++- pymongo/common.py | 7 +++++++ test/auth_aws/test_auth_oidc.py | 12 ++++++------ test/utils.py | 2 +- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index a9549df5a9..2e45fabef7 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -570,7 +570,17 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): ] ) response = sock_info.command("$external", cmd) - cache_value.server_resp = bson.decode(response["payload"]) + + # Convert the server response to be more pythonic. + # Avoid circular import + from pymongo.common import camel_to_snake + + orig_server_resp = bson.decode(response["payload"]) + server_resp = dict() + for key, value in orig_server_resp.items(): + server_resp[camel_to_snake(key)] = value + cache_value.server_resp = server_resp + conversation_id = response["conversationId"] current_valid_token = False diff --git a/pymongo/common.py b/pymongo/common.py index 149966e558..66183fac35 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -17,6 +17,7 @@ import datetime import inspect +import re import warnings from collections import OrderedDict, abc from typing import ( @@ -170,6 +171,12 @@ def raise_config_error(key: str, dummy: Any) -> NoReturn: } +def camel_to_snake(camel): + # Regex to convert CamelCase to snake_case. + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() + + def validate_boolean(option: str, value: Any) -> bool: """Validates that 'value' is True or False.""" if isinstance(value, bool): diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 4b24e61054..69f0898abf 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -165,9 +165,9 @@ def test_caching(self): def request_token(info, timeout): nonlocal request_called - assert "authorizationEndpoint" in info - assert "tokenEndpoint" in info - assert "clientId" in info + assert "authorization_endpoint" in info + assert "token_endpoint" in info + assert "client_id" in info assert timeout == 60 * 5 with open(token_file) as fid: token = fid.read() @@ -176,9 +176,9 @@ def request_token(info, timeout): def refresh_token(info, creds, timeout): nonlocal refresh_called - assert "authorizationEndpoint" in info - assert "tokenEndpoint" in info - assert "clientId" in info + assert "authorization_endpoint" in info + assert "token_endpoint" in info + assert "client_id" in info assert timeout == 60 * 5 assert "access_token" in creds refresh_called += 1 diff --git a/test/utils.py b/test/utils.py index 842e9e3a7b..22588d0db5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -667,7 +667,7 @@ def get_command_line(client): return command_line -def camel_to_snake(camel): +def camel_to_snake(camel: str) -> str: # Regex to convert CamelCase to snake_case. snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() From b78e4047d54b7c4f184a245d49246c0e0addf3eb Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 17:10:47 -0600 Subject: [PATCH 060/133] clean up config --- .evergreen/config.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 704ec6b5a9..935aba29c6 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -75,7 +75,7 @@ functions: export TMPDIR="$MONGO_ORCHESTRATION_HOME/db" # Installation of cryptography requires a rust compiler on some machines - export PATH="/opt/mongodbtoolchain/v4/bin/:$MONGODB_BINARIES:/home/admin/.cargo/bin:$PATH" + export PATH="$MONGODB_BINARIES:/home/admin/.cargo/bin:$PATH" export PROJECT="${project}" EOT # See what we've done @@ -285,7 +285,6 @@ functions: fi ${PREPARE_SHELL} - MONGODB_VERSION=${VERSION} \ TOPOLOGY=${TOPOLOGY} \ AUTH=${AUTH} \ @@ -569,12 +568,14 @@ functions: "iam_auth_ecs_subnet_a": "${iam_auth_ecs_subnet_a}", "iam_auth_ecs_subnet_b": "${iam_auth_ecs_subnet_b}", "iam_auth_ecs_security_group": "${iam_auth_ecs_security_group}", + "iam_auth_assume_aws_account" : "${iam_auth_assume_aws_account}", "iam_auth_assume_aws_secret_access_key" : "${iam_auth_assume_aws_secret_access_key}", "iam_auth_assume_role_name" : "${iam_auth_assume_role_name}", "iam_auth_ec2_instance_account" : "${iam_auth_ec2_instance_account}", "iam_auth_ec2_instance_secret_access_key" : "${iam_auth_ec2_instance_secret_access_key}", "iam_auth_ec2_instance_profile" : "${iam_auth_ec2_instance_profile}", + "iam_auth_assume_web_role_name": "${iam_auth_assume_web_role_name}", "iam_web_identity_issuer": "${iam_web_identity_issuer}", "iam_web_identity_rsa_key": "${iam_web_identity_rsa_key}", From ce250deb891224ed4d4f193e0541c5de7f1e137f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 10 Feb 2023 17:11:50 -0600 Subject: [PATCH 061/133] typing --- pymongo/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 2e45fabef7..71626fb6f8 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -575,7 +575,7 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): # Avoid circular import from pymongo.common import camel_to_snake - orig_server_resp = bson.decode(response["payload"]) + orig_server_resp: Dict = bson.decode(response["payload"]) server_resp = dict() for key, value in orig_server_resp.items(): server_resp[camel_to_snake(key)] = value From 3339837a2979de04b343ad10a5496d2acb0219fa Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 13 Feb 2023 10:37:03 -0600 Subject: [PATCH 062/133] clarify the supported auth types --- doc/examples/authentication.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index 6574f7ce5c..e023fec981 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -413,14 +413,14 @@ The driver will use the authentication token from the file given by the ``AWS_WEB_IDENTITY_TOKEN_FILE`` environment variable provided by AWS to authenticate with the server. -Authentication Code Flow Support -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -PyMongo supports using the Authentication Code Flow for OIDC, when -callbacks are given to the ``MongoClient``. The ``on_oidc_request_callback`` -is intended to accept information about the Identity Provider, and return -credentials that are used to authenticate with the server, usually through -a browser interaction with the user. The callback must be of the form:: +Callback-based OIDC Support +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PyMongo supports user-provided callbacks for OIDC, which are are given to the +``MongoClient``. The ``on_oidc_request_callback`` is intended to accept +information about the Identity Provider, and return credentials that are used +to authenticate with the server, usually through a browser interaction with +the user. The callback must be of the form:: def request_callback(ProviderInfo, timeout_seconds) -> TokenResult: ... From 2c337317f5dc5eb842de5a314e96ad433fe74c5c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 13 Feb 2023 13:27:03 -0600 Subject: [PATCH 063/133] use mongosh --- .evergreen/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 935aba29c6..c8864d89e2 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -778,7 +778,7 @@ functions: exit 0 fi cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc - mongo setup_oidc.js + mongosh setup_oidc.js - command: shell.exec type: test params: From 529cf6339ef33ee35ab2c121ad639e01b7925780 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 14 Feb 2023 14:54:34 -0600 Subject: [PATCH 064/133] add test for multiple principals and no username --- test/auth_aws/test_auth_oidc.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 69f0898abf..e04a595b59 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -50,7 +50,7 @@ def fail_point(self, command_args): finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") - def test_connect_aws_device_workflow(self): + def test_connect_aws(self): os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") props = dict(DEVICE_NAME="aws") client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -66,7 +66,7 @@ def test_connect_aws_device_workflow(self): client.test.test.find_one() client.close() - def test_connect_authorization_code_workflow(self): + def test_connect_callbacks(self): token_file = os.path.join(self.token_dir, "test_user1") def request_token(info, timeout): @@ -94,6 +94,11 @@ def request_token(info, timeout): client.test.test.find_one() client.close() + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + with self.assertRaises(OperationFailure): + client.test.test.find_one() + client.close() + def test_bad_callbacks(self): _oidc_cache.clear() From 5f0fde0bf0e79330c357924f11f1f11f14d83dbe Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 17 Feb 2023 13:48:37 -0600 Subject: [PATCH 065/133] add principal name to callbacks, and include callbacks in cache key --- pymongo/auth.py | 15 ++++++---- pymongo/common.py | 4 +-- test/auth_aws/test_auth_oidc.py | 53 +++++++++++++++++++++++---------- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 71626fb6f8..f460b6e24a 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -538,7 +538,10 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): # Handle authorization code credentials. address = sock_info.address principal_name = credentials.username - cache_key = f"{principal_name}{address[0]}{address[1]}" + request_cb = properties.on_oidc_request_token + refresh_cb = properties.on_oidc_refresh_token + + cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" cache_value = _oidc_cache.get(cache_key) conversation_id = None @@ -595,13 +598,13 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): if not current_valid_token: with cache_value.lock: - if cache_value.token_result is None or properties.on_oidc_refresh_token is None: - cache_value.token_result = properties.on_oidc_request_token( - cache_value.server_resp, timeout + if cache_value.token_result is None or refresh_cb is None: + cache_value.token_result = request_cb( + principal_name, cache_value.server_resp, timeout ) else: - cache_value.token_result = properties.on_oidc_refresh_token( - cache_value.server_resp, cache_value.token_result, timeout + cache_value.token_result = refresh_cb( + principal_name, cache_value.server_resp, cache_value.token_result, timeout ) cache_exp_utc = datetime.now(timezone.utc) + timedelta( minutes=_OIDC_CACHE_TIMEOUT_MINUTES diff --git a/pymongo/common.py b/pymongo/common.py index 66183fac35..f38dba4a7a 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -448,9 +448,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni elif inspect.isfunction(value): signature = inspect.signature(value) if key == "on_oidc_request_token": - expected_params = 2 - elif key == "on_oidc_refresh_token": expected_params = 3 + elif key == "on_oidc_refresh_token": + expected_params = 4 else: raise ValueError(f"Unrecognized Auth mechanisim function {key}") if len(signature.parameters) != expected_params: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index e04a595b59..ade4b7602d 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -69,7 +69,7 @@ def test_connect_aws(self): def test_connect_callbacks(self): token_file = os.path.join(self.token_dir, "test_user1") - def request_token(info, timeout): + def request_token(principal, info, timeout): with open(token_file) as fid: token = fid.read() return dict(access_token=token) @@ -102,7 +102,7 @@ def request_token(info, timeout): def test_bad_callbacks(self): _oidc_cache.clear() - def request_token_null(info, timeout): + def request_token_null(principal, info, timeout): return None props: Dict = dict(on_oidc_request_token=request_token_null) @@ -111,7 +111,7 @@ def request_token_null(info, timeout): client.test.test.find_one() client.close() - def request_token_no_token(info, timeout): + def request_token_no_token(principal, info, timeout): return dict() _oidc_cache.clear() @@ -121,12 +121,12 @@ def request_token_no_token(info, timeout): client.test.test.find_one() client.close() - def request_refresh_null(info, creds, timeout): + def request_refresh_null(principal, info, creds, timeout): return None token_file = os.path.join(self.token_dir, "test_user1") - def request_token(info, timeout): + def request_token(principal, info, timeout): with open(token_file) as fid: token = fid.read() return dict(access_token=token) @@ -144,7 +144,7 @@ def request_token(info, timeout): client.test.test.find_one() client.close() - def request_refresh_no_token(info, creds, timeout): + def request_refresh_no_token(principal, info, creds, timeout): return dict() _oidc_cache.clear() @@ -168,7 +168,7 @@ def test_caching(self): # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. token_file = os.path.join(self.token_dir, "test_user1") - def request_token(info, timeout): + def request_token(principal, info, timeout): nonlocal request_called assert "authorization_endpoint" in info assert "token_endpoint" in info @@ -179,7 +179,7 @@ def request_token(info, timeout): request_called += 1 return dict(access_token=token, expires_in_seconds=60) - def refresh_token(info, creds, timeout): + def refresh_token(principal, info, creds, timeout): nonlocal refresh_called assert "authorization_endpoint" in info assert "token_endpoint" in info @@ -233,17 +233,40 @@ def refresh_token(info, creds, timeout): assert request_called == 2 assert len(_oidc_cache) == 1 + # Create a new client with a different request callback. + def request_token_2(principal, info, timeout): + return request_token(principal, info, timeout) + + props["on_oidc_request_token"] = request_token_2 + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # Ensure that a ``find`` operation adds a new entry to the cache. + client.test.test.find_one() + client.close() + assert request_called == 3 + assert len(_oidc_cache) == 2 + + # Clear the cache + _oidc_cache.clear() + # Create a new client with a refresh callback that gives invalid credentials. - def bad_refresh(info, creds, timeout): + def bad_refresh(principal, info, creds, timeout): return dict(access_token="bad") + # Add a token to the cache that will expire soon. props["on_oidc_refresh_token"] = bad_refresh client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() - # Ensure that a ``find`` operation results in an error. + # Create a new client with the same callbacks. + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # Ensure that another ``find`` operation results in an error. with self.assertRaises(OperationFailure): client.test.test.find_one() - client.close() + + client.close() # Ensure that the cache has been cleared. assert len(_oidc_cache) == 0 @@ -269,12 +292,12 @@ def test_reauthenticate_read(self): # Create request and refresh callbacks that return valid credentials # that will not expire soon. - def request_token(info, timeout): + def request_token(principal, info, timeout): with open(token_file) as fid: token = fid.read() return dict(access_token=token, expires_in_seconds=1000) - def refresh_token(info, creds, timeout): + def refresh_token(principal, info, creds, timeout): nonlocal refresh_called with open(token_file) as fid: token = fid.read() @@ -323,12 +346,12 @@ def test_reauthenticate_write(self): # Create request and refresh callbacks that return valid credentials # that will not expire soon. - def request_token(info, timeout): + def request_token(principal, info, timeout): with open(token_file) as fid: token = fid.read() return dict(access_token=token, expires_in_seconds=1000) - def refresh_token(info, creds, timeout): + def refresh_token(principal, info, creds, timeout): nonlocal refresh_called with open(token_file) as fid: token = fid.read() From 765dde6f8dea3a381f3bd6a10e80e1510f7c2c22 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 17 Feb 2023 13:54:37 -0600 Subject: [PATCH 066/133] fix test --- test/test_auth_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 613add1b1a..4ab5982814 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -46,10 +46,10 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["on_oidc_request_token"] = lambda x, y: 1 + props["on_oidc_request_token"] = lambda x, y, z: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["on_oidc_refresh_token"] = lambda x, y, z: 1 # type:ignore[misc] + props["on_oidc_refresh_token"] = lambda a, b, c, d: 1 # type:ignore[misc] del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From 6186ed99c6c9c3aa8610ab763a590f97476130a3 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 21 Feb 2023 13:35:55 -0600 Subject: [PATCH 067/133] update cache exp on access --- pymongo/auth.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymongo/auth.py b/pymongo/auth.py index f460b6e24a..5c6831a2ca 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -545,6 +545,10 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): cache_value = _oidc_cache.get(cache_key) conversation_id = None + if cache_value is not None: + cache_value.cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=_OIDC_CACHE_TIMEOUT_MINUTES + ) if cache_value is None: lock = threading.Lock() From c2e91e216e6c07e8b38931cab043ea89728bd169 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 22 Feb 2023 12:15:31 -0600 Subject: [PATCH 068/133] switch to OIDC_TOKEN_DIR --- .evergreen/config.yml | 4 ++-- .evergreen/run-mongodb-oidc-test.sh | 4 ++-- test/auth_aws/test_auth_oidc.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index c8864d89e2..48886f4325 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -759,7 +759,7 @@ functions: export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} - export AWS_TOKEN_DIR=/tmp/tokens + export OIDC_TOKEN_DIR=/tmp/tokens . ./activate_venv.sh python oidc_write_orchestration.py @@ -787,7 +787,7 @@ functions: script: | # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" - export AWS_TOKEN_DIR=/tmp/tokens + export OIDC_TOKEN_DIR=/tmp/tokens EOF - command: shell.exec type: test diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 8a98552223..d4a61f2e1c 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -25,8 +25,8 @@ MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" -if [ -z "${AWS_TOKEN_DIR}" ]; then - echo "Must specify AWS_TOKEN_DIR" +if [ -z "${OIDC_TOKEN_DIR}" ]; then + echo "Must specify OIDC_TOKEN_DIR" exit 1 fi diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index ade4b7602d..67434a496f 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -37,7 +37,7 @@ def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"] cls.uri_admin = os.environ["MONGODB_URI"] - cls.token_dir = os.environ["AWS_TOKEN_DIR"] + cls.token_dir = os.environ["OIDC_TOKEN_DIR"] @contextmanager def fail_point(self, command_args): From ca6dd7d7ec33ffff1f312555753032e0fd2bc827 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 22 Feb 2023 12:27:32 -0600 Subject: [PATCH 069/133] switch to activate-authoidcvenv.sh --- .evergreen/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 48886f4325..d895cfb63c 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -761,7 +761,7 @@ functions: export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} export OIDC_TOKEN_DIR=/tmp/tokens - . ./activate_venv.sh + . ./activate-authoidcvenv.sh python oidc_write_orchestration.py python oidc_get_tokens.py From 5f730ecf041b147c225b3842fb85b32d8e21c42e Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 22 Feb 2023 12:54:03 -0600 Subject: [PATCH 070/133] upgrade pip and setuptools --- .evergreen/run-mongodb-oidc-test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index d4a61f2e1c..97262ac76e 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -70,6 +70,7 @@ authtest () { else . venvoidc/bin/activate fi + python -m pip install -U pip setuptools python -m pip install '.[aws]' python test/auth_aws/test_auth_oidc.py -v deactivate From f1998d6407c3d5bccefb8b9da55c2fb936c9f208 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 22 Feb 2023 13:03:33 -0600 Subject: [PATCH 071/133] remove reauth write prose test --- test/auth_aws/test_auth_oidc.py | 54 --------------------------------- 1 file changed, 54 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 67434a496f..a7ca9b4aa8 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -336,60 +336,6 @@ def refresh_token(principal, info, creds, timeout): self.assertEqual(refresh_called, 1) client.close() - def test_reauthenticate_write(self): - token_file = os.path.join(self.token_dir, "test_user1") - refresh_called = 0 - listener = EventListener() - - # Clear the cache - _oidc_cache.clear() - - # Create request and refresh callbacks that return valid credentials - # that will not expire soon. - def request_token(principal, info, timeout): - with open(token_file) as fid: - token = fid.read() - return dict(access_token=token, expires_in_seconds=1000) - - def refresh_token(principal, info, creds, timeout): - nonlocal refresh_called - with open(token_file) as fid: - token = fid.read() - refresh_called += 1 - return dict(access_token=token, expires_in_seconds=1000) - - # Create a client with the callbacks. - props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) - client = MongoClient( - self.uri_single, event_listeners=[listener], authmechanismproperties=props - ) - - # Perform a find operation. - client.test.test.find_one() - - # Assert that the refresh callback has not been called. - self.assertEqual(refresh_called, 0) - - listener.reset() - - with self.fail_point( - {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": 391}} - ): - # Perform an insert operation. - client.test.test.insert_one({}) - - started_events = [i.command_name for i in listener.started_events] - succeeded_events = [i.command_name for i in listener.succeeded_events] - failed_events = [i.command_name for i in listener.failed_events] - - assert started_events == ["insert", "saslStart", "insert"] - assert succeeded_events == ["saslStart", "insert"] - assert failed_events == ["insert"] - - # Assert that the refresh callback has been called. - self.assertEqual(refresh_called, 1) - client.close() - if __name__ == "__main__": unittest.main() From ba2253b2e4dfeb07bd54dd337478a73a8f59739a Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 23 Feb 2023 10:32:52 -0600 Subject: [PATCH 072/133] remove temp file --- test_oidc.py | 119 --------------------------------------------------- 1 file changed, 119 deletions(-) delete mode 100644 test_oidc.py diff --git a/test_oidc.py b/test_oidc.py deleted file mode 100644 index 30f76af4ef..0000000000 --- a/test_oidc.py +++ /dev/null @@ -1,119 +0,0 @@ -import os -import threading -import time -import webbrowser -from http.server import BaseHTTPRequestHandler, HTTPServer -from queue import Queue - -from requests_oauth2client import AuthorizationRequest, OAuth2Client - -from pymongo import MongoClient - -client_secret = os.getenv("IDP_CLIENT_SECRET") - -auth_data = dict( - authorizeEndpoint="https://corp.mongodb.com/oauth2/v1/authorize", - tokenEndpoint="https://corp.mongodb.com/oauth2/v1/token", - issuer="https://corp.mongodb.com", - clientId="0oadp0hpl7q3UIehP297", - clientSecret=client_secret, -) - - -LOCAL_PORT = 8888 -REDIRECT_URI = f"http://localhost:{LOCAL_PORT}/authorization-code/callback" -RESPONSE_QUEUE = Queue() -INIT_CALLED = 0 -REFRESH_CALLED = 0 - - -class MyRequestHandler(BaseHTTPRequestHandler): - def do_GET(self): - RESPONSE_QUEUE.put(self.path) - self.send_response(200) - - -def run_server(): - server = HTTPServer(("localhost", LOCAL_PORT), MyRequestHandler) - try: - server.serve_forever() - except KeyboardInterrupt: - pass - - server.server_close() - - -# Start a server on 8888 and expose a callback endpoint -# the tunnel address will be 8889 - - -def get_auth_token(auth_data): - print("Getting auth token") - global INIT_CALLED - INIT_CALLED += 1 - client_id = auth_data["clientId"] - client_secret = auth_data["clientSecret"] - token_endpoint = auth_data["tokenEndpoint"] - authorization_endpoint = auth_data["authorizeEndpoint"] - request = AuthorizationRequest( - authorization_endpoint, - client_id, - scope="openid", - redirect_uri=REDIRECT_URI, - code_challenge_method="S256", - ) - webbrowser.open(str(request)) - response_uri = RESPONSE_QUEUE.get() - response = request.validate_callback(response_uri) - client = OAuth2Client(token_endpoint, auth=(client_id, client_secret)) - token_response = client.token_request( - { - "grant_type": "authorization_code", - "code": response.code, - "redirect_uri": REDIRECT_URI, - "code_verifier": response.code_verifier, - } - ) - print("token:") - print(str(token_response.id_token)) - return dict(access_token=str(token_response.id_token), expires_in_seconds=5 * 60 + 3) - - -def refresh_auth_token(auth_data, orig_data): - global REFRESH_CALLED - REFRESH_CALLED += 1 - print("Refreshing auth token") - access_token = orig_data["access_token"] - return dict(access_token=access_token, expires_in_seconds=10 * 60) - - -thread = threading.Thread(target=run_server, daemon=True) -thread.start() - -# print(get_auth_token(auth_data)) - -# AWS device workflow test. -if "AWS_WEB_IDENTITY_TOKEN_FILE" in os.environ: - props = dict() - client = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") - print(client.test.command("ping")) - -# Browser workflow test. -else: - # Test token expiration and refresh - props = dict(on_oidc_request_token=get_auth_token, on_oidc_refresh_token=refresh_auth_token) - client = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") - print(client.test.command("ping")) - assert INIT_CALLED == 1 - print("Sleeping...") - time.sleep(4) - client2 = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") - print(client2.test.command("ping")) - assert INIT_CALLED == 1 - assert REFRESH_CALLED == 1 - print("Sleeping...") - time.sleep(2) - client3 = MongoClient(port=8889, authmechanismproperties=props, authmechanism="MONGODB-OIDC") - print(client3.test.command("ping")) - assert INIT_CALLED == 1 - assert REFRESH_CALLED == 1 From b02304101ba20e3bac8a781e23cd8b6be2224f07 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 24 Feb 2023 07:52:47 -0600 Subject: [PATCH 073/133] use main drivers-evergreen-tools --- .evergreen/config.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index d895cfb63c..78fe17cc4e 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -97,8 +97,7 @@ functions: # If this was a patch build, doing a fresh clone would not actually test the patch cp -R ${PROJECT_DIRECTORY}/ $DRIVERS_TOOLS else - # TODO: switch back to master before merging - git clone --branch DRIVERS-2415 https://github.com/blink1073/drivers-evergreen-tools.git $DRIVERS_TOOLS + git clone https://github.com/mongodb-labs/drivers-evergreen-tools.git $DRIVERS_TOOLS fi echo "{ \"releases\": { \"default\": \"$MONGODB_BINARIES\" }}" > $MONGO_ORCHESTRATION_HOME/orchestration.config From 3b6288b39afcf9c9a94b85b643e49368b9213a26 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 27 Feb 2023 10:36:30 -0600 Subject: [PATCH 074/133] install xml and coverage --- .evergreen/run-tests.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index d495e2671a..a2c578f8b9 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -218,6 +218,8 @@ if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then fi $PYTHON setup.py clean +$PYTHON -m pip install unittest-xml-reporting coverage + if [ -z "$GREEN_FRAMEWORK" ]; then if [ -z "$C_EXTENSIONS" ] && [ "$PYTHON_IMPL" = "CPython" ]; then # Fail if the C extensions fail to build. From a574237d86edc4680b1cba08ded0f8347675cafa Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 27 Feb 2023 13:01:26 -0600 Subject: [PATCH 075/133] better reauth support --- pymongo/auth.py | 3 +++ test/auth_aws/test_auth_oidc.py | 11 +++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 5c6831a2ca..2007556619 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -653,6 +653,9 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): response = sock_info.command("$external", cmd) except Exception: _oidc_cache.pop(cache_key, None) + # Allow for one retry on reauthenticate when using server step 2. + if reauthenticate and conversation_id is None: + return _authenticate_oidc(credentials, sock_info, False) raise if not response["done"]: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index a7ca9b4aa8..a78d1618db 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -319,7 +319,10 @@ def refresh_token(principal, info, creds, timeout): listener.reset() with self.fail_point( - {"mode": {"times": 1}, "data": {"failCommands": ["find"], "errorCode": 391}} + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } ): # Perform a find operation. client.test.test.find_one() @@ -328,9 +331,9 @@ def refresh_token(principal, info, creds, timeout): succeeded_events = [i.command_name for i in listener.succeeded_events] failed_events = [i.command_name for i in listener.failed_events] - assert started_events == ["find", "saslStart", "find"] - assert succeeded_events == ["saslStart", "find"] - assert failed_events == ["find"] + assert started_events == ["find", "saslStart", "saslStart", "saslContinue", "find"] + assert succeeded_events == ["saslStart", "saslContinue", "find"] + assert failed_events == ["find", "saslStart"] # Assert that the refresh callback has been called. self.assertEqual(refresh_called, 1) From 3607f723b90bb318f697034f389913b0709f0404 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 27 Feb 2023 13:07:49 -0600 Subject: [PATCH 076/133] Revert "install xml and coverage" This reverts commit 3b6288b39afcf9c9a94b85b643e49368b9213a26. --- .evergreen/run-tests.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index a2c578f8b9..d495e2671a 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -218,8 +218,6 @@ if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then fi $PYTHON setup.py clean -$PYTHON -m pip install unittest-xml-reporting coverage - if [ -z "$GREEN_FRAMEWORK" ]; then if [ -z "$C_EXTENSIONS" ] && [ "$PYTHON_IMPL" = "CPython" ]; then # Fail if the C extensions fail to build. From 75a707baabdd55d1f6330b52c39483ed10cc9d57 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 28 Feb 2023 08:51:24 -0600 Subject: [PATCH 077/133] implement speculativeAuth for OIDC --- pymongo/auth.py | 257 +++++++++++++++++++------------- pymongo/pool.py | 2 +- test/auth_aws/test_auth_oidc.py | 6 +- 3 files changed, 159 insertions(+), 106 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 2007556619..6b9971fd61 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -497,101 +497,87 @@ def _authenticate_mongo_cr(credentials, sock_info): sock_info.command(source, query) -# MONGO-OIDC private variables. +# MONGODB-OIDC private variables. _oidc_cache: Dict[str, _OIDCCache] = {} _OIDC_TOKEN_BUFFER_MINUTES = 5 _OIDC_CALLBACK_TIMEOUT_SECONDS = 5 * 60 _OIDC_CACHE_TIMEOUT_MINUTES = 60 * 5 -def _authenticate_oidc(credentials, sock_info, reauthenticate): - """Authenticate using MONGODB-OIDC.""" - properties: _OIDCProperties = credentials.mechanism_properties - - # Clear out old items in the cache. - now_utc = datetime.now(timezone.utc) - to_remove = [] - for key, value in _oidc_cache.items(): - if value.cache_exp_utc < now_utc: - to_remove.append(key) - for key in to_remove: - del _oidc_cache[key] - - # Handle aws device credentials. - if properties.device_name == "aws": - aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] - with open(aws_identity_file) as fid: - token = fid.read().strip() - payload = dict(jwt=token) - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ] - ) - response = sock_info.command("$external", cmd) - if not response["done"]: - raise OperationFailure("SASL conversation failed to complete.") - return - +def _oidc_get_cache_key(credentials, address): # Handle authorization code credentials. - address = sock_info.address + address = address principal_name = credentials.username + properties = credentials.mechanism_properties request_cb = properties.on_oidc_request_token refresh_cb = properties.on_oidc_refresh_token + return f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" - cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" - cache_value = _oidc_cache.get(cache_key) - conversation_id = None - if cache_value is not None: - cache_value.cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=_OIDC_CACHE_TIMEOUT_MINUTES - ) +def _authenticate_oidc(credentials, sock_info, reauthenticate): + """Authenticate using MONGODB-OIDC.""" + ctx = sock_info.auth_ctx + cmd = None + cache_key = _oidc_get_cache_key(credentials, sock_info.address) - if cache_value is None: - lock = threading.Lock() - cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=_OIDC_CACHE_TIMEOUT_MINUTES) - cache_value = _OIDCCache( - lock=lock, - token_result=None, - server_resp=None, - token_exp_utc=None, - cache_exp_utc=cache_exp_utc, - ) - _oidc_cache[cache_key] = cache_value + if ctx and ctx.speculate_succeeded(): + resp = ctx.speculative_authenticate + else: + cmd = _authenticate_oidc_start(credentials, sock_info.address) + try: + resp = sock_info.command(credentials.source, cmd) + except Exception: + _oidc_cache.pop(cache_key, None) + # Allow for one retry on reauthenticate when using server step 2. + if reauthenticate: + return _authenticate_oidc(credentials, sock_info, False) + raise + + if resp["done"]: + return - # Send the SASL start with the optional principal name. - payload = dict() + # Convert the server response to be more pythonic. + # Avoid circular import + from pymongo.common import camel_to_snake - if principal_name: - payload["n"] = principal_name + orig_server_resp: Dict = bson.decode(resp["payload"]) + server_resp = dict() + for key, value in orig_server_resp.items(): + server_resp[camel_to_snake(key)] = value - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ("autoAuthorize", 1), - ] - ) - response = sock_info.command("$external", cmd) + if "token_endpoint" in server_resp: + _oidc_cache[cache_key].server_resp = server_resp - # Convert the server response to be more pythonic. - # Avoid circular import - from pymongo.common import camel_to_snake + conversation_id = resp["conversationId"] + token = _oidc_get_current_token(credentials, sock_info.address) + bin_payload = Binary(bson.encode(dict(jwt=token))) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", conversation_id), + ("payload", bin_payload), + ] + ) + response = sock_info.command("$external", cmd) + if not response["done"]: + _oidc_cache.pop(cache_key, None) + raise OperationFailure("SASL conversation failed to complete.") - orig_server_resp: Dict = bson.decode(response["payload"]) - server_resp = dict() - for key, value in orig_server_resp.items(): - server_resp[camel_to_snake(key)] = value - cache_value.server_resp = server_resp - conversation_id = response["conversationId"] +def _oidc_get_current_token(credentials, address, use_callbacks=True): + properties: _OIDCProperties = credentials.mechanism_properties + cache_key = _oidc_get_cache_key(credentials, address) + cache_value = _oidc_cache[cache_key] + principal_name = credentials.username + + request_cb = properties.on_oidc_request_token + refresh_cb = properties.on_oidc_refresh_token + if not use_callbacks: + request_cb = None + refresh_cb = None current_valid_token = False - if cache_value.token_exp_utc is not None and not reauthenticate: + if cache_value.token_exp_utc is not None: now_utc = datetime.now(timezone.utc) exp_utc = cache_value.token_exp_utc buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 @@ -600,13 +586,16 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): timeout = _OIDC_CALLBACK_TIMEOUT_SECONDS + if not use_callbacks and not current_valid_token: + return dict() + if not current_valid_token: with cache_value.lock: if cache_value.token_result is None or refresh_cb is None: cache_value.token_result = request_cb( principal_name, cache_value.server_resp, timeout ) - else: + elif request_cb is not None: cache_value.token_result = refresh_cb( principal_name, cache_value.server_resp, cache_value.token_result, timeout ) @@ -630,37 +619,90 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): exp_utc = now_utc + timedelta(seconds=expires_in) cache_value.token_exp_utc = exp_utc - bin_payload = Binary(bson.encode(dict(jwt=token))) + return token + + +def _invalidate_oidc_token(credentials, address): + cache_key = _oidc_get_cache_key(credentials, address) + cache_value = _oidc_cache.get(cache_key) + if cache_value: + cache_value.token_exp_utc = None + + +def _authenticate_oidc_start(credentials, address, use_callbacks=True): + properties: _OIDCProperties = credentials.mechanism_properties - if conversation_id is not None: + # Clear out old items in the cache. + now_utc = datetime.now(timezone.utc) + to_remove = [] + for key, value in _oidc_cache.items(): + if value.cache_exp_utc < now_utc: + to_remove.append(key) + for key in to_remove: + del _oidc_cache[key] + + # Handle aws device credentials. + if properties.device_name == "aws": + aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] + with open(aws_identity_file) as fid: + token = fid.read().strip() + payload = dict(jwt=token) cmd = SON( [ - ("saslContinue", 1), - ("conversationId", conversation_id), - ("payload", bin_payload), + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), ] ) - else: + return cmd + + cache_key = _oidc_get_cache_key(credentials, address) + cache_value = _oidc_cache.get(cache_key) + principal_name = credentials.username + + if cache_value is not None: + cache_value.cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=_OIDC_CACHE_TIMEOUT_MINUTES + ) + + if cache_value is None: + lock = threading.Lock() + cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=_OIDC_CACHE_TIMEOUT_MINUTES) + cache_value = _OIDCCache( + lock=lock, + token_result=None, + server_resp=None, + token_exp_utc=None, + cache_exp_utc=cache_exp_utc, + ) + _oidc_cache[cache_key] = cache_value + + if cache_value.server_resp is None: + # Send the SASL start with the optional principal name. + payload = dict() + + if principal_name: + payload["n"] = principal_name + cmd = SON( [ ("saslStart", 1), ("mechanism", "MONGODB-OIDC"), - ("payload", bin_payload), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), ] ) + return cmd - try: - response = sock_info.command("$external", cmd) - except Exception: - _oidc_cache.pop(cache_key, None) - # Allow for one retry on reauthenticate when using server step 2. - if reauthenticate and conversation_id is None: - return _authenticate_oidc(credentials, sock_info, False) - raise - - if not response["done"]: - _oidc_cache.pop(cache_key, None) - raise OperationFailure("SASL conversation failed to complete.") + token = _oidc_get_current_token(credentials, address, use_callbacks) + bin_payload = Binary(bson.encode(dict(jwt=token))) + return SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", bin_payload), + ] + ) def _authenticate_default(credentials, sock_info): @@ -696,15 +738,16 @@ def _authenticate_default(credentials, sock_info): class _AuthContext(object): - def __init__(self, credentials): + def __init__(self, credentials, address): self.credentials = credentials self.speculative_authenticate = None + self.address = address @staticmethod - def from_credentials(creds): + def from_credentials(creds, address): spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism) if spec_cls: - return spec_cls(creds) + return spec_cls(creds, address) return None def speculate_command(self): @@ -718,8 +761,8 @@ def speculate_succeeded(self): class _ScramContext(_AuthContext): - def __init__(self, credentials, mechanism): - super(_ScramContext, self).__init__(credentials) + def __init__(self, credentials, mechanism, address): + super(_ScramContext, self).__init__(credentials, address) self.scram_data = None self.mechanism = mechanism @@ -740,10 +783,18 @@ def speculate_command(self): return cmd +class _OIDCContext(_AuthContext): + def speculate_command(self): + cmd = _authenticate_oidc_start(self.credentials, self.address) + cmd["db"] = self.credentials.source + return cmd + + _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { "MONGODB-X509": _X509Context, "SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"), "SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), + "MONGODB-OIDC": _OIDCContext, "DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), } @@ -752,10 +803,12 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] + if reauthenticate and mechanism == "MONGODB-OIDC": + _invalidate_oidc_token(credentials, sock_info.address) + if reauthenticate and sock_info.performed_handshake: + # Existing hello response is stale, call it again. + sock_info.hello() if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) else: - if reauthenticate and sock_info.performed_handshake: - # Existing hello response is stale, call it again. - sock_info.hello() auth_func(credentials, sock_info) diff --git a/pymongo/pool.py b/pymongo/pool.py index 80b440ae6d..d9aa6e7a38 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -647,7 +647,7 @@ def _hello(self, cluster_time, topology_version, heartbeat_frequency): if creds: if creds.mechanism == "DEFAULT" and creds.username: cmd["saslSupportedMechs"] = creds.source + "." + creds.username - auth_ctx = auth._AuthContext.from_credentials(creds) + auth_ctx = auth._AuthContext.from_credentials(creds, self.address) if auth_ctx: cmd["speculativeAuthenticate"] = auth_ctx.speculate_command() else: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index a78d1618db..fd31ccf858 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -331,9 +331,9 @@ def refresh_token(principal, info, creds, timeout): succeeded_events = [i.command_name for i in listener.succeeded_events] failed_events = [i.command_name for i in listener.failed_events] - assert started_events == ["find", "saslStart", "saslStart", "saslContinue", "find"] - assert succeeded_events == ["saslStart", "saslContinue", "find"] - assert failed_events == ["find", "saslStart"] + assert started_events == ["find", "find", "find"], started_events + assert succeeded_events == ["find"], succeeded_events + assert failed_events == ["find", "find"], failed_events # Assert that the refresh callback has been called. self.assertEqual(refresh_called, 1) From 35fc668f4d3189514b5101320659c700576432ec Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 28 Feb 2023 08:58:35 -0600 Subject: [PATCH 078/133] fix errors --- pymongo/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 6b9971fd61..8f5c486054 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -478,7 +478,7 @@ def _authenticate_x509(credentials, sock_info): # MONGODB-X509 is done after the speculative auth step. return - cmd = _X509Context(credentials).speculate_command() + cmd = _X509Context(credentials, sock_info.address).speculate_command() sock_info.command("$external", cmd) @@ -761,7 +761,7 @@ def speculate_succeeded(self): class _ScramContext(_AuthContext): - def __init__(self, credentials, mechanism, address): + def __init__(self, credentials, address, mechanism): super(_ScramContext, self).__init__(credentials, address) self.scram_data = None self.mechanism = mechanism From 3ceda8a7df2367262ea3a95b0b29ad79d88c7d26 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 28 Feb 2023 10:36:34 -0600 Subject: [PATCH 079/133] clarify reauth behavior --- pymongo/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 8f5c486054..a0d2410f3c 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -528,8 +528,8 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): resp = sock_info.command(credentials.source, cmd) except Exception: _oidc_cache.pop(cache_key, None) - # Allow for one retry on reauthenticate when using server step 2. - if reauthenticate: + # Allow for one retry on reauthenticate when callbacks are in use. + if reauthenticate and not credentials.mechanism_properties.device_name: return _authenticate_oidc(credentials, sock_info, False) raise From 6976ef85b149cb584efa4d6217e777617404f6ae Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Mar 2023 15:29:25 -0600 Subject: [PATCH 080/133] fix speculativeAuth and add prose tests --- pymongo/auth.py | 12 ++-- test/auth_aws/test_auth_oidc.py | 100 +++++++++++++++++++++++++++++++- 2 files changed, 105 insertions(+), 7 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index a0d2410f3c..f3b1b5b45c 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -587,7 +587,7 @@ def _oidc_get_current_token(credentials, address, use_callbacks=True): timeout = _OIDC_CALLBACK_TIMEOUT_SECONDS if not use_callbacks and not current_valid_token: - return dict() + return None if not current_valid_token: with cache_value.lock: @@ -695,6 +695,8 @@ def _authenticate_oidc_start(credentials, address, use_callbacks=True): return cmd token = _oidc_get_current_token(credentials, address, use_callbacks) + if not token: + return None bin_payload = Binary(bson.encode(dict(jwt=token))) return SON( [ @@ -785,7 +787,9 @@ def speculate_command(self): class _OIDCContext(_AuthContext): def speculate_command(self): - cmd = _authenticate_oidc_start(self.credentials, self.address) + cmd = _authenticate_oidc_start(self.credentials, self.address, False) + if cmd is None: + return cmd["db"] = self.credentials.source return cmd @@ -806,8 +810,8 @@ def authenticate(credentials, sock_info, reauthenticate=False): if reauthenticate and mechanism == "MONGODB-OIDC": _invalidate_oidc_token(credentials, sock_info.address) if reauthenticate and sock_info.performed_handshake: - # Existing hello response is stale, call it again. - sock_info.hello() + # Existing auth_ctx is stale, remove it. + sock_info.auth_ctx = None if mechanism == "MONGODB-OIDC": auth_func(credentials, sock_info, reauthenticate) else: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index fd31ccf858..15e33843d5 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -282,6 +282,94 @@ def bad_refresh(principal, info, creds, timeout): client.close() assert len(_oidc_cache) == 0 + def test_speculative_auth_succeeds(self): + # Clear the cache + _oidc_cache.clear() + token_file = os.path.join(self.token_dir, "test_user1") + + def request_token(principal, info, timeout): + with open(token_file) as fid: + token = fid.read() + return dict(access_token=token, expires_in_seconds=1000) + + # Create a client with a request callback that returns a valid token + # that will not expire soon. + props: Dict = dict(on_oidc_request_token=request_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Close the client. + client.close() + + # Create a new client. + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Close the client. + client.close() + + def test_speculative_auth_fails(self): + # Clear the cache + _oidc_cache.clear() + token_file = os.path.join(self.token_dir, "test_user1") + + def request_token(principal, info, timeout): + with open(token_file) as fid: + token = fid.read() + return dict(access_token=token, expires_in_seconds=60) + + # Create a client with a request callback that returns a valid token + # that will expire soon. + props: Dict = dict(on_oidc_request_token=request_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Close the client. + client.close() + + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + + # Close the client. + client.close() + def test_reauthenticate_read(self): token_file = os.path.join(self.token_dir, "test_user1") refresh_called = 0 @@ -331,9 +419,15 @@ def refresh_token(principal, info, creds, timeout): succeeded_events = [i.command_name for i in listener.succeeded_events] failed_events = [i.command_name for i in listener.failed_events] - assert started_events == ["find", "find", "find"], started_events - assert succeeded_events == ["find"], succeeded_events - assert failed_events == ["find", "find"], failed_events + assert started_events == [ + "find", + "saslStart", + "saslStart", + "saslContinue", + "find", + ], started_events + assert succeeded_events == ["saslStart", "saslContinue", "find"], succeeded_events + assert failed_events == ["find", "saslStart"], failed_events # Assert that the refresh callback has been called. self.assertEqual(refresh_called, 1) From 43a5269dcc9c70af318e7a9c36728c5fa62a8e10 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 1 Mar 2023 18:14:20 -0600 Subject: [PATCH 081/133] update docstring --- test/auth_aws/test_auth_oidc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 15e33843d5..bf7625b5a8 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -363,7 +363,7 @@ def request_token(principal, info, timeout): "data": {"failCommands": ["saslStart"], "errorCode": 18}, } ): - # Perform a find operation. + # Perform a find operation that fails. with self.assertRaises(OperationFailure): client.test.test.find_one() From 40ed8939e13e5ec2d0d0293a518fbdf90cf81b9b Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 7 Mar 2023 10:43:48 -0600 Subject: [PATCH 082/133] rename DEVICE_NAME to PRINCIPAL_NAME --- doc/examples/authentication.rst | 2 +- pymongo/auth.py | 16 ++++++++-------- pymongo/common.py | 2 +- test/auth/legacy/connection-string.json | 12 ++++++------ test/auth_aws/test_auth_oidc.py | 4 ++-- test/test_auth_spec.py | 4 ++-- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index e023fec981..8aaa55f0c5 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -406,7 +406,7 @@ available, by installing pymongo with the A sample URI would be: >>> from pymongo import MongoClient - >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:aws") + >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws") >>> client = MongoClient(uri) The driver will use the authentication token from the file given by the diff --git a/pymongo/auth.py b/pymongo/auth.py index f3b1b5b45c..7dfb6d7ddc 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -106,7 +106,7 @@ def __hash__(self): _OIDCProperties = namedtuple( "_OIDCProperties", - ["on_oidc_request_token", "on_oidc_refresh_token", "device_name"], + ["on_oidc_request_token", "on_oidc_refresh_token", "provider_name"], ) """Mechanism properties for MONGODB-OIDC authentication.""" @@ -164,15 +164,15 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): properties = extra.get("authmechanismproperties", {}) on_oidc_request_token = properties.get("on_oidc_request_token") on_oidc_refresh_token = properties.get("on_oidc_refresh_token", None) - device_name = properties.get("DEVICE_NAME", "") - if not on_oidc_request_token and device_name != "aws": + provider_name = properties.get("PROVIDER_NAME", "") + if not on_oidc_request_token and provider_name != "aws": raise ConfigurationError( - "authentication with MONGODB-OIDC requires providing an on_oidc_request_token or a device_name of 'aws'" + "authentication with MONGODB-OIDC requires providing an on_oidc_request_token or a provider_name of 'aws'" ) oidc_props = _OIDCProperties( on_oidc_request_token=on_oidc_request_token, on_oidc_refresh_token=on_oidc_refresh_token, - device_name=device_name, + provider_name=provider_name, ) return MongoCredential(mech, "$external", user, passwd, oidc_props, None) @@ -529,7 +529,7 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): except Exception: _oidc_cache.pop(cache_key, None) # Allow for one retry on reauthenticate when callbacks are in use. - if reauthenticate and not credentials.mechanism_properties.device_name: + if reauthenticate and not credentials.mechanism_properties.provider_name: return _authenticate_oidc(credentials, sock_info, False) raise @@ -641,8 +641,8 @@ def _authenticate_oidc_start(credentials, address, use_callbacks=True): for key in to_remove: del _oidc_cache[key] - # Handle aws device credentials. - if properties.device_name == "aws": + # Handle aws provider credentials. + if properties.provider_name == "aws": aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: token = fid.read().strip() diff --git a/pymongo/common.py b/pymongo/common.py index f38dba4a7a..cdb2d66285 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -429,7 +429,7 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN", - "DEVICE_NAME", + "PROVIDER_NAME", ] ) diff --git a/test/auth/legacy/connection-string.json b/test/auth/legacy/connection-string.json index 167d7dd290..ca979010af 100644 --- a/test/auth/legacy/connection-string.json +++ b/test/auth/legacy/connection-string.json @@ -508,7 +508,7 @@ }, { "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:aws", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", "valid": true, "credential": { "username": null, @@ -516,13 +516,13 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "DEVICE_NAME": "aws" + "PROVIDER_NAME": "aws" } } }, { "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=DEVICE_NAME:aws", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", "valid": true, "credential": { "username": null, @@ -530,7 +530,7 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "DEVICE_NAME": "aws" + "PROVIDER_NAME": "aws" } } }, @@ -543,13 +543,13 @@ }, { "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", - "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&DEVICE_NAME:gcp", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", "valid": false, "credential": null }, { "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=DEVICE_NAME:unexisted", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", "valid": false, "credential": null }, diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index bf7625b5a8..3d19be2c14 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -52,7 +52,7 @@ def fail_point(self, command_args): def test_connect_aws(self): os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") - props = dict(DEVICE_NAME="aws") + props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() @@ -276,7 +276,7 @@ def bad_refresh(principal, info, creds, timeout): # Ensure that a ``find`` operation does not add credentials to the cache. _oidc_cache.clear() os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") - props = dict(DEVICE_NAME="aws") + props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 4ab5982814..36969beb83 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -80,8 +80,8 @@ def run_test(self): self.assertEqual( actual.aws_session_token, expected["AWS_SESSION_TOKEN"] ) - elif "DEVICE_NAME" in expected: - self.assertEqual(actual.device_name, expected["DEVICE_NAME"]) + elif "PROVIDER_NAME" in expected: + self.assertEqual(actual.PROVIDER_NAME, expected["PROVIDER_NAME"]) elif "on_oidc_request_token" in expected: self.assertEqual( actual.on_oidc_request_token, expected["on_oidc_request_token"] From a46ba288dde3b466cb294ff67adafc23bea795c6 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 7 Mar 2023 11:28:49 -0600 Subject: [PATCH 083/133] fix test runner --- test/test_auth_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 36969beb83..6dc2b5a253 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -81,7 +81,7 @@ def run_test(self): actual.aws_session_token, expected["AWS_SESSION_TOKEN"] ) elif "PROVIDER_NAME" in expected: - self.assertEqual(actual.PROVIDER_NAME, expected["PROVIDER_NAME"]) + self.assertEqual(actual.provider_name, expected["PROVIDER_NAME"]) elif "on_oidc_request_token" in expected: self.assertEqual( actual.on_oidc_request_token, expected["on_oidc_request_token"] From c7df390ff3dce21560cb0585d45065413aeb46c1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 9 Mar 2023 11:52:59 -0600 Subject: [PATCH 084/133] better reauth support --- pymongo/auth.py | 8 +++++--- pymongo/helpers.py | 1 - test/auth_aws/test_auth_oidc.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 7dfb6d7ddc..05f5cb57d5 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -519,6 +519,7 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): ctx = sock_info.auth_ctx cmd = None cache_key = _oidc_get_cache_key(credentials, sock_info.address) + in_cache = cache_key in _oidc_cache if ctx and ctx.speculate_succeeded(): resp = ctx.speculative_authenticate @@ -526,10 +527,11 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): cmd = _authenticate_oidc_start(credentials, sock_info.address) try: resp = sock_info.command(credentials.source, cmd) - except Exception: + except Exception as e: _oidc_cache.pop(cache_key, None) - # Allow for one retry on reauthenticate when callbacks are in use. - if reauthenticate and not credentials.mechanism_properties.provider_name: + # Allow for one retry on reauthenticate when callbacks are in use + # and there was no cache. + if reauthenticate and not credentials.mechanism_properties.provider_name and in_cache: return _authenticate_oidc(credentials, sock_info, False) raise diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 8851a63c94..28978dc036 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -65,7 +65,6 @@ 89, # NetworkTimeout 9001, # SocketException 262, # ExceededTimeLimit - 391, ] ) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 3d19be2c14..101b1d392f 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -433,6 +433,27 @@ def refresh_token(principal, info, creds, timeout): self.assertEqual(refresh_called, 1) client.close() + # Create a new client with the callbacks. + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Clear the cache. + _oidc_cache.clear() + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } + ): + # Perform a find operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + + client.close() + if __name__ == "__main__": unittest.main() From 02ce52d49b7ffaa3f80f1f9e474685a6929d2826 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 13 Mar 2023 17:20:44 -0500 Subject: [PATCH 085/133] add test numbering and new prose test --- test/auth_aws/test_auth_oidc.py | 61 +++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 101b1d392f..579c99c91a 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -53,14 +53,18 @@ def fail_point(self, command_args): def test_connect_aws(self): os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") props = dict(PROVIDER_NAME="aws") + + # Test 1 client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() + # Test 2 client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() client.close() + # Test 3 os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() @@ -74,11 +78,13 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token) + # Test 1 props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() + # Test 2 _oidc_cache.clear() client = MongoClient( self.uri_multiple, username="test_user1", authmechanismproperties=props @@ -86,6 +92,7 @@ def request_token(principal, info, timeout): client.test.test.find_one() client.close() + # Test 3 _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user2") client = MongoClient( @@ -105,6 +112,7 @@ def test_bad_callbacks(self): def request_token_null(principal, info, timeout): return None + # Test 1 props: Dict = dict(on_oidc_request_token=request_token_null) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): @@ -114,6 +122,7 @@ def request_token_null(principal, info, timeout): def request_token_no_token(principal, info, timeout): return dict() + # Test 2 _oidc_cache.clear() props: Dict = dict(on_oidc_request_token=request_token_no_token) client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -131,6 +140,7 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token) + # Test 3 _oidc_cache.clear() props: Dict = dict( on_oidc_request_token=request_token, on_oidc_refresh_token=request_refresh_null @@ -147,12 +157,14 @@ def request_token(principal, info, timeout): def request_refresh_no_token(principal, info, creds, timeout): return dict() + # Test 4 _oidc_cache.clear() props["on_oidc_refresh_token"] = request_refresh_no_token client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() + # Test 5 client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -162,6 +174,7 @@ def test_caching(self): request_called = 0 refresh_called = 0 + # Test 1 # Clear the cache. _oidc_cache.clear() # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. @@ -191,6 +204,7 @@ def refresh_token(principal, info, creds, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=60) + # Test 2 _oidc_cache.clear() props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) @@ -201,6 +215,7 @@ def refresh_token(principal, info, creds, timeout): assert len(_oidc_cache) == 1 + # Test 3 # Create a new client with the same request callback and a refresh callback. # Ensure that a ``find`` operation results in a call to the refresh callback. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -210,6 +225,7 @@ def refresh_token(principal, info, creds, timeout): assert refresh_called == 1 assert len(_oidc_cache) == 1 + # Test 4 # Clear the cache. _oidc_cache.clear() @@ -225,6 +241,7 @@ def refresh_token(principal, info, creds, timeout): assert request_called == 1 assert len(_oidc_cache) == 1 + # Test 5 # Create a new client with the same request callback. # Ensure that a ``find`` operation results in a call to the request callback. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -246,6 +263,7 @@ def request_token_2(principal, info, timeout): assert request_called == 3 assert len(_oidc_cache) == 2 + # Test 6. # Clear the cache _oidc_cache.clear() @@ -271,6 +289,7 @@ def bad_refresh(principal, info, creds, timeout): # Ensure that the cache has been cleared. assert len(_oidc_cache) == 0 + # Test 7. # Clear the cache. # Create a new client using the AWS device workflow. # Ensure that a ``find`` operation does not add credentials to the cache. @@ -292,6 +311,7 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=1000) + # Test 1 # Create a client with a request callback that returns a valid token # that will not expire soon. props: Dict = dict(on_oidc_request_token=request_token) @@ -310,6 +330,7 @@ def request_token(principal, info, timeout): # Close the client. client.close() + # Test 2 # Create a new client. client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -336,6 +357,7 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=60) + # Test 1 # Create a client with a request callback that returns a valid token # that will expire soon. props: Dict = dict(on_oidc_request_token=request_token) @@ -354,6 +376,7 @@ def request_token(principal, info, timeout): # Close the client. client.close() + # Test 2 client = MongoClient(self.uri_single, authmechanismproperties=props) # Set a fail point for saslStart commands. @@ -375,6 +398,7 @@ def test_reauthenticate_read(self): refresh_called = 0 listener = EventListener() + # Test 1 # Clear the cache _oidc_cache.clear() @@ -433,6 +457,7 @@ def refresh_token(principal, info, creds, timeout): self.assertEqual(refresh_called, 1) client.close() + # Test 2 # Create a new client with the callbacks. client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -454,6 +479,42 @@ def refresh_token(principal, info, creds, timeout): client.close() + # . Ensure there is a cache. + + # Test 3 + # Clear the cache. + _oidc_cache.clear() + + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) + + # Perform a find operation. + client.test.test.find_one() + + # Close the client. + client.close() + + # Create a new client with the same callbacks. + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) + + # Set a fail point for ``saslStart`` commands of the form + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } + ): + # Perform a find operation that succeeds. + client.test.test.find_one() + + # Close the client. + client.close() + if __name__ == "__main__": unittest.main() From 28a804b5bdc22bf7b697583634027423f5a1610f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 13 Mar 2023 17:25:26 -0500 Subject: [PATCH 086/133] lint --- pymongo/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 05f5cb57d5..ae9205e0ab 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -527,7 +527,7 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): cmd = _authenticate_oidc_start(credentials, sock_info.address) try: resp = sock_info.command(credentials.source, cmd) - except Exception as e: + except Exception: _oidc_cache.pop(cache_key, None) # Allow for one retry on reauthenticate when callbacks are in use # and there was no cache. From 19dda748013cb4b7487f814639d68cb6c96b77d7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 14 Mar 2023 10:59:21 -0500 Subject: [PATCH 087/133] wip clean up tests --- test/auth_aws/test_auth_oidc.py | 71 +++++++++++++++++---------------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 579c99c91a..4e5bc76027 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -50,26 +50,6 @@ def fail_point(self, command_args): finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") - def test_connect_aws(self): - os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") - props = dict(PROVIDER_NAME="aws") - - # Test 1 - client = MongoClient(self.uri_single, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - # Test 2 - client = MongoClient(self.uri_multiple, authmechanismproperties=props) - client.test.test.find_one() - client.close() - - # Test 3 - os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") - client = MongoClient(self.uri_multiple, authmechanismproperties=props) - client.test.test.find_one() - client.close() - def test_connect_callbacks(self): token_file = os.path.join(self.token_dir, "test_user1") @@ -78,13 +58,19 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token) - # Test 1 + # Single Principal Implicit Username props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() - # Test 2 + # Single Principal Explicit Username + props: Dict = dict(on_oidc_request_token=request_token) + client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) + client.test.test.find_one() + client.close() + + # Multiple Principal User 1 _oidc_cache.clear() client = MongoClient( self.uri_multiple, username="test_user1", authmechanismproperties=props @@ -92,7 +78,7 @@ def request_token(principal, info, timeout): client.test.test.find_one() client.close() - # Test 3 + # Multiple Principal User 2 _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user2") client = MongoClient( @@ -101,18 +87,39 @@ def request_token(principal, info, timeout): client.test.test.find_one() client.close() + # Multiple No User client = MongoClient(self.uri_multiple, authmechanismproperties=props) with self.assertRaises(OperationFailure): client.test.test.find_one() client.close() + def test_connect_aws(self): + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") + props = dict(PROVIDER_NAME="aws") + + # Single principal + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + # Multiple principal user 1 + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + # Multiple principal user 2 + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + def test_bad_callbacks(self): _oidc_cache.clear() def request_token_null(principal, info, timeout): return None - # Test 1 + # Request Callback returns null props: Dict = dict(on_oidc_request_token=request_token_null) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): @@ -122,7 +129,7 @@ def request_token_null(principal, info, timeout): def request_token_no_token(principal, info, timeout): return dict() - # Test 2 + # Refresh Callback returns null _oidc_cache.clear() props: Dict = dict(on_oidc_request_token=request_token_no_token) client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -140,7 +147,7 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token) - # Test 3 + # Request callback returns unexpected result _oidc_cache.clear() props: Dict = dict( on_oidc_request_token=request_token, on_oidc_refresh_token=request_refresh_null @@ -157,14 +164,12 @@ def request_token(principal, info, timeout): def request_refresh_no_token(principal, info, creds, timeout): return dict() - # Test 4 + # Refresh callback returns unexpected result _oidc_cache.clear() props["on_oidc_refresh_token"] = request_refresh_no_token client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() - - # Test 5 client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -174,7 +179,7 @@ def test_caching(self): request_called = 0 refresh_called = 0 - # Test 1 + # Cache succeeds # Clear the cache. _oidc_cache.clear() # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. @@ -204,7 +209,7 @@ def refresh_token(principal, info, creds, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=60) - # Test 2 + # Cache expires _oidc_cache.clear() props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) @@ -311,7 +316,6 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=1000) - # Test 1 # Create a client with a request callback that returns a valid token # that will not expire soon. props: Dict = dict(on_oidc_request_token=request_token) @@ -330,7 +334,6 @@ def request_token(principal, info, timeout): # Close the client. client.close() - # Test 2 # Create a new client. client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -357,7 +360,6 @@ def request_token(principal, info, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=60) - # Test 1 # Create a client with a request callback that returns a valid token # that will expire soon. props: Dict = dict(on_oidc_request_token=request_token) @@ -376,7 +378,6 @@ def request_token(principal, info, timeout): # Close the client. client.close() - # Test 2 client = MongoClient(self.uri_single, authmechanismproperties=props) # Set a fail point for saslStart commands. From 9c33a022519d18f183171c1ad0c7128d514b2eab Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 14 Mar 2023 13:58:35 -0500 Subject: [PATCH 088/133] wip refractor with test headings --- test/auth_aws/test_auth_oidc.py | 226 +++++++++++++++++--------------- 1 file changed, 117 insertions(+), 109 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 4e5bc76027..f1e1fefa7f 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -39,6 +39,38 @@ def setUpClass(cls): cls.uri_admin = os.environ["MONGODB_URI"] cls.token_dir = os.environ["OIDC_TOKEN_DIR"] + def setup(self): + _oidc_cache.clear() + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") + + def create_request_callback(self, username="test_user1", expires_in_seconds=None): + + token_file = os.path.join(self.token_dir, username) + + def request_token(principal, info, timeout): + with open(token_file) as fid: + token = fid.read() + resp = dict(access_token=token) + if expires_in_seconds is not None: + resp["expires_in_seconds"] = expires_in_seconds + return resp + + return request_token + + def create_refresh_callback(self, username="test_user1", expires_in_seconds=None): + + token_file = os.path.join(self.token_dir, username) + + def refresh_token(principal, info, creds, timeout): + with open(token_file) as fid: + token = fid.read() + resp = dict(access_token=token) + if expires_in_seconds is not None: + resp["expires_in_seconds"] = expires_in_seconds + return resp + + return refresh_token + @contextmanager def fail_point(self, command_args): cmd_on = SON([("configureFailPoint", "failCommand")]) @@ -50,144 +82,127 @@ def fail_point(self, command_args): finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") - def test_connect_callbacks(self): - token_file = os.path.join(self.token_dir, "test_user1") - - def request_token(principal, info, timeout): - with open(token_file) as fid: - token = fid.read() - return dict(access_token=token) - - # Single Principal Implicit Username + def test_connect_callbacks_single_implicit_username(self): + request_token = self.create_request_callback() props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() - # Single Principal Explicit Username + def test_connect_callbacks_single_explicit_username(self): + request_token = self.create_request_callback() props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) client.test.test.find_one() client.close() - # Multiple Principal User 1 - _oidc_cache.clear() + def test_connect_callbacks_multiple_principal_user1(self): + request_token = self.create_request_callback() + props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient( self.uri_multiple, username="test_user1", authmechanismproperties=props ) client.test.test.find_one() client.close() - # Multiple Principal User 2 - _oidc_cache.clear() - token_file = os.path.join(self.token_dir, "test_user2") + def test_connect_callbacks_multiple_principal_user2(self): + request_token = self.create_request_callback("test_user2") + props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient( self.uri_multiple, username="test_user2", authmechanismproperties=props ) client.test.test.find_one() client.close() - # Multiple No User + def test_connect_callbacks_multiple_no_username(self): + request_token = self.create_request_callback() + props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_multiple, authmechanismproperties=props) with self.assertRaises(OperationFailure): client.test.test.find_one() client.close() - def test_connect_aws(self): - os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") + def test_connect_aws_single_principal(self): props = dict(PROVIDER_NAME="aws") - - # Single principal client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() - # Multiple principal user 1 + def test_connect_aws_multiple_principal_user1(self): + props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() client.close() - # Multiple principal user 2 + def test_connect_aws_multiple_principal_user2(self): os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") + props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() client.close() - def test_bad_callbacks(self): - _oidc_cache.clear() - + def test_request_callback_returns_null(self): def request_token_null(principal, info, timeout): return None - # Request Callback returns null props: Dict = dict(on_oidc_request_token=request_token_null) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() - def request_token_no_token(principal, info, timeout): - return dict() + def test_refresh_callback_returns_null(self): + request_cb = self.create_request_callback(expires_in_seconds=60) - # Refresh Callback returns null - _oidc_cache.clear() - props: Dict = dict(on_oidc_request_token=request_token_no_token) - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - - def request_refresh_null(principal, info, creds, timeout): + def refresh_token_null(principal, info, creds, timeout): return None - token_file = os.path.join(self.token_dir, "test_user1") - - def request_token(principal, info, timeout): - with open(token_file) as fid: - token = fid.read() - return dict(access_token=token) - - # Request callback returns unexpected result - _oidc_cache.clear() props: Dict = dict( - on_oidc_request_token=request_token, on_oidc_refresh_token=request_refresh_null + on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_token_null ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() + with self.assertRaises(ValueError): + client.test.test.find_one() client.close() + def test_request_callback_invalid_result(self): + def request_token_invalid(principal, info, timeout): + return None + + props: Dict = dict(on_oidc_request_token=request_token_invalid) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() - def request_refresh_no_token(principal, info, creds, timeout): + def test_refresh_callback_invalid_result(self): + request_cb = self.create_request_callback(expires_in_seconds=60) + + def refresh_cb_no_token(principal, info, cred, timeout): return dict() - # Refresh callback returns unexpected result - _oidc_cache.clear() - props["on_oidc_refresh_token"] = request_refresh_no_token + props: Dict = dict( + on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb_no_token + ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() - client.close() - client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() - def test_caching(self): + def test_cache_with_refresh(self): request_called = 0 refresh_called = 0 - # Cache succeeds - # Clear the cache. - _oidc_cache.clear() # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. + # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. token_file = os.path.join(self.token_dir, "test_user1") def request_token(principal, info, timeout): nonlocal request_called + # Validate the info. assert "authorization_endpoint" in info assert "token_endpoint" in info assert "client_id" in info @@ -199,6 +214,7 @@ def request_token(principal, info, timeout): def refresh_token(principal, info, creds, timeout): nonlocal refresh_called + # Validate the info. assert "authorization_endpoint" in info assert "token_endpoint" in info assert "client_id" in info @@ -209,8 +225,6 @@ def refresh_token(principal, info, creds, timeout): token = fid.read() return dict(access_token=token, expires_in_seconds=60) - # Cache expires - _oidc_cache.clear() props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) # Ensure that a ``find`` operation adds credentials to the cache. @@ -220,7 +234,6 @@ def refresh_token(principal, info, creds, timeout): assert len(_oidc_cache) == 1 - # Test 3 # Create a new client with the same request callback and a refresh callback. # Ensure that a ``find`` operation results in a call to the refresh callback. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -230,13 +243,18 @@ def refresh_token(principal, info, creds, timeout): assert refresh_called == 1 assert len(_oidc_cache) == 1 - # Test 4 - # Clear the cache. - _oidc_cache.clear() - + def test_cache_with_no_refresh(self): # Create a new client with a request callback callback. # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. - del props["on_oidc_refresh_token"] + request_called = 0 + inner_cb = self.create_request_callback(expires_in_seconds=60) + + def request_cb(principal, info, timeout): + nonlocal request_called + request_called += 1 + return inner_cb(principal, info, timeout) + + props = dict(on_oidc_request_token=request_cb) client = MongoClient(self.uri_single, authMechanismProperties=props) # Ensure that a ``find`` operation adds credentials to the cache. @@ -246,7 +264,9 @@ def refresh_token(principal, info, creds, timeout): assert request_called == 1 assert len(_oidc_cache) == 1 - # Test 5 + def test_cache_key_includes_callback(self): + # Cache key includes callback + # Create a new client with the same request callback. # Ensure that a ``find`` operation results in a call to the request callback. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -268,7 +288,8 @@ def request_token_2(principal, info, timeout): assert request_called == 3 assert len(_oidc_cache) == 2 - # Test 6. + # Error clears cache + # Clear the cache _oidc_cache.clear() @@ -294,7 +315,8 @@ def bad_refresh(principal, info, creds, timeout): # Ensure that the cache has been cleared. assert len(_oidc_cache) == 0 - # Test 7. + # AWS Automatic workflow does not use cache + # Clear the cache. # Create a new client using the AWS device workflow. # Ensure that a ``find`` operation does not add credentials to the cache. @@ -369,22 +391,7 @@ def request_token(principal, info, timeout): with self.fail_point( { "mode": {"times": 2}, - "data": {"failCommands": ["saslStart"], "errorCode": 18}, - } - ): - # Perform a find operation. - client.test.test.find_one() - - # Close the client. - client.close() - - client = MongoClient(self.uri_single, authmechanismproperties=props) - - # Set a fail point for saslStart commands. - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["saslStart"], "errorCode": 18}, + "data": {"failCommands": ["isMaster", "saslStart"], "errorCode": 18}, } ): # Perform a find operation that fails. @@ -399,7 +406,8 @@ def test_reauthenticate_read(self): refresh_called = 0 listener = EventListener() - # Test 1 + # Succeeds + # Clear the cache _oidc_cache.clear() @@ -458,31 +466,8 @@ def refresh_token(principal, info, creds, timeout): self.assertEqual(refresh_called, 1) client.close() - # Test 2 - # Create a new client with the callbacks. - client = MongoClient(self.uri_single, authmechanismproperties=props) - - # Perform a find operation. - client.test.test.find_one() - - # Clear the cache. - _oidc_cache.clear() - - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, - } - ): - # Perform a find operation that fails. - with self.assertRaises(OperationFailure): - client.test.test.find_one() - - client.close() - - # . Ensure there is a cache. + # Retries and Succeeds with Cache - # Test 3 # Clear the cache. _oidc_cache.clear() @@ -516,6 +501,29 @@ def refresh_token(principal, info, creds, timeout): # Close the client. client.close() + # Fails with no Cache + + # Create a new client with the callbacks. + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Clear the cache. + _oidc_cache.clear() + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } + ): + # Perform a find operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + + client.close() + if __name__ == "__main__": unittest.main() From c21771f03772ff0d760b456b925fd04a42d61ed0 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 14 Mar 2023 15:33:18 -0500 Subject: [PATCH 089/133] refactor prose tests --- test/auth_aws/test_auth_oidc.py | 188 ++++++++++++++------------------ 1 file changed, 82 insertions(+), 106 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index f1e1fefa7f..915b6d8b73 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -39,34 +39,49 @@ def setUpClass(cls): cls.uri_admin = os.environ["MONGODB_URI"] cls.token_dir = os.environ["OIDC_TOKEN_DIR"] - def setup(self): + def setUp(self): + self.request_called = 0 + self.refresh_called = 0 _oidc_cache.clear() os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") - def create_request_callback(self, username="test_user1", expires_in_seconds=None): + def create_request_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) def request_token(principal, info, timeout): + # Validate the info. + assert "authorization_endpoint" in info or "device_authorization_endpoint" in info + assert "token_endpoint" in info + assert "client_id" in info + assert timeout == 60 * 5 with open(token_file) as fid: token = fid.read() resp = dict(access_token=token) if expires_in_seconds is not None: resp["expires_in_seconds"] = expires_in_seconds + self.request_called += 1 return resp return request_token - def create_refresh_callback(self, username="test_user1", expires_in_seconds=None): + def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) def refresh_token(principal, info, creds, timeout): with open(token_file) as fid: token = fid.read() + # Validate the info. + assert "authorization_endpoint" in info + assert "token_endpoint" in info + assert "client_id" in info + assert timeout == 60 * 5 + assert "access_token" in creds resp = dict(access_token=token) if expires_in_seconds is not None: resp["expires_in_seconds"] = expires_in_seconds + self.refresh_called += 1 return resp return refresh_token @@ -83,21 +98,21 @@ def fail_point(self, command_args): client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") def test_connect_callbacks_single_implicit_username(self): - request_token = self.create_request_callback() + request_token = self.create_request_cb() props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_callbacks_single_explicit_username(self): - request_token = self.create_request_callback() + request_token = self.create_request_cb() props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_callbacks_multiple_principal_user1(self): - request_token = self.create_request_callback() + request_token = self.create_request_cb() props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient( self.uri_multiple, username="test_user1", authmechanismproperties=props @@ -106,7 +121,7 @@ def test_connect_callbacks_multiple_principal_user1(self): client.close() def test_connect_callbacks_multiple_principal_user2(self): - request_token = self.create_request_callback("test_user2") + request_token = self.create_request_cb("test_user2") props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient( self.uri_multiple, username="test_user2", authmechanismproperties=props @@ -115,7 +130,7 @@ def test_connect_callbacks_multiple_principal_user2(self): client.close() def test_connect_callbacks_multiple_no_username(self): - request_token = self.create_request_callback() + request_token = self.create_request_cb() props: Dict = dict(on_oidc_request_token=request_token) client = MongoClient(self.uri_multiple, authmechanismproperties=props) with self.assertRaises(OperationFailure): @@ -152,7 +167,7 @@ def request_token_null(principal, info, timeout): client.close() def test_refresh_callback_returns_null(self): - request_cb = self.create_request_callback(expires_in_seconds=60) + request_cb = self.create_request_cb(expires_in_seconds=60) def refresh_token_null(principal, info, creds, timeout): return None @@ -162,6 +177,9 @@ def refresh_token_null(principal, info, creds, timeout): ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() @@ -177,7 +195,7 @@ def request_token_invalid(principal, info, timeout): client.close() def test_refresh_callback_invalid_result(self): - request_cb = self.create_request_callback(expires_in_seconds=60) + request_cb = self.create_request_cb(expires_in_seconds=60) def refresh_cb_no_token(principal, info, cred, timeout): return dict() @@ -187,45 +205,21 @@ def refresh_cb_no_token(principal, info, cred, timeout): ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() client.close() def test_cache_with_refresh(self): - request_called = 0 - refresh_called = 0 - # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. - token_file = os.path.join(self.token_dir, "test_user1") - - def request_token(principal, info, timeout): - nonlocal request_called - # Validate the info. - assert "authorization_endpoint" in info - assert "token_endpoint" in info - assert "client_id" in info - assert timeout == 60 * 5 - with open(token_file) as fid: - token = fid.read() - request_called += 1 - return dict(access_token=token, expires_in_seconds=60) - - def refresh_token(principal, info, creds, timeout): - nonlocal refresh_called - # Validate the info. - assert "authorization_endpoint" in info - assert "token_endpoint" in info - assert "client_id" in info - assert timeout == 60 * 5 - assert "access_token" in creds - refresh_called += 1 - with open(token_file) as fid: - token = fid.read() - return dict(access_token=token, expires_in_seconds=60) + request_cb = self.create_request_cb(expires_in_seconds=60) + refresh_cb = self.create_refresh_cb() - props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) + props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) # Ensure that a ``find`` operation adds credentials to the cache. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -240,19 +234,13 @@ def refresh_token(principal, info, creds, timeout): client.test.test.find_one() client.close() - assert refresh_called == 1 + assert self.refresh_called == 1 assert len(_oidc_cache) == 1 def test_cache_with_no_refresh(self): # Create a new client with a request callback callback. # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. - request_called = 0 - inner_cb = self.create_request_callback(expires_in_seconds=60) - - def request_cb(principal, info, timeout): - nonlocal request_called - request_called += 1 - return inner_cb(principal, info, timeout) + request_cb = self.create_request_cb() props = dict(on_oidc_request_token=request_cb) client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -261,23 +249,30 @@ def request_cb(principal, info, timeout): request_called = 0 client.test.test.find_one() client.close() - assert request_called == 1 + assert self.request_called == 1 assert len(_oidc_cache) == 1 - def test_cache_key_includes_callback(self): - # Cache key includes callback - # Create a new client with the same request callback. # Ensure that a ``find`` operation results in a call to the request callback. client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() - assert request_called == 2 + assert self.request_called == 2 assert len(_oidc_cache) == 1 + def test_cache_key_includes_callback(self): + request_cb = self.create_request_cb() + + props: Dict = dict(on_oidc_request_token=request_cb) + + # Ensure that a ``find`` operation adds a new entry to the cache. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + # Create a new client with a different request callback. def request_token_2(principal, info, timeout): - return request_token(principal, info, timeout) + return request_cb(principal, info, timeout) props["on_oidc_request_token"] = request_token_2 client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -285,20 +280,17 @@ def request_token_2(principal, info, timeout): # Ensure that a ``find`` operation adds a new entry to the cache. client.test.test.find_one() client.close() - assert request_called == 3 assert len(_oidc_cache) == 2 - # Error clears cache - - # Clear the cache - _oidc_cache.clear() + def test_cache_clears_on_error(self): + request_cb = self.create_request_cb() - # Create a new client with a refresh callback that gives invalid credentials. - def bad_refresh(principal, info, creds, timeout): + # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. + def refresh_cb(principal, info, creds, timeout): return dict(access_token="bad") # Add a token to the cache that will expire soon. - props["on_oidc_refresh_token"] = bad_refresh + props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -315,20 +307,16 @@ def bad_refresh(principal, info, creds, timeout): # Ensure that the cache has been cleared. assert len(_oidc_cache) == 0 - # AWS Automatic workflow does not use cache - - # Clear the cache. + def test_cache_is_not_used_in_aws_automatic_workflow(self): # Create a new client using the AWS device workflow. # Ensure that a ``find`` operation does not add credentials to the cache. - _oidc_cache.clear() - os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() assert len(_oidc_cache) == 0 - def test_speculative_auth_succeeds(self): + def test_speculative_auth_success(self): # Clear the cache _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user1") @@ -372,7 +360,7 @@ def request_token(principal, info, timeout): # Close the client. client.close() - def test_speculative_auth_fails(self): + def test_speculative_auth_failure(self): # Clear the cache _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user1") @@ -401,32 +389,16 @@ def request_token(principal, info, timeout): # Close the client. client.close() - def test_reauthenticate_read(self): - token_file = os.path.join(self.token_dir, "test_user1") - refresh_called = 0 + def test_reauthenticate_succeeds(self): listener = EventListener() - # Succeeds - - # Clear the cache - _oidc_cache.clear() - # Create request and refresh callbacks that return valid credentials # that will not expire soon. - def request_token(principal, info, timeout): - with open(token_file) as fid: - token = fid.read() - return dict(access_token=token, expires_in_seconds=1000) - - def refresh_token(principal, info, creds, timeout): - nonlocal refresh_called - with open(token_file) as fid: - token = fid.read() - refresh_called += 1 - return dict(access_token=token, expires_in_seconds=1000) + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(on_oidc_request_token=request_token, on_oidc_refresh_token=refresh_token) + props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -435,7 +407,7 @@ def refresh_token(principal, info, creds, timeout): client.test.test.find_one() # Assert that the refresh callback has not been called. - self.assertEqual(refresh_called, 0) + self.assertEqual(self.refresh_called, 0) listener.reset() @@ -463,16 +435,19 @@ def refresh_token(principal, info, creds, timeout): assert failed_events == ["find", "saslStart"], failed_events # Assert that the refresh callback has been called. - self.assertEqual(refresh_called, 1) + self.assertEqual(self.refresh_called, 1) client.close() - # Retries and Succeeds with Cache - - # Clear the cache. - _oidc_cache.clear() + def test_reauthenticate_retries_and_succees_with_cache(self): + listener = EventListener() # Create request and refresh callbacks that return valid credentials # that will not expire soon. + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -480,14 +455,6 @@ def refresh_token(principal, info, creds, timeout): # Perform a find operation. client.test.test.find_one() - # Close the client. - client.close() - - # Create a new client with the same callbacks. - client = MongoClient( - self.uri_single, event_listeners=[listener], authmechanismproperties=props - ) - # Set a fail point for ``saslStart`` commands of the form with self.fail_point( { @@ -501,10 +468,19 @@ def refresh_token(principal, info, creds, timeout): # Close the client. client.close() - # Fails with no Cache + def test_reauthenticate_fails_with_no_cache(self): + listener = EventListener() - # Create a new client with the callbacks. - client = MongoClient(self.uri_single, authmechanismproperties=props) + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) # Perform a find operation. client.test.test.find_one() From bad4527ad4c85a8efa99a8446341a16ee0f6f9ee Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 14 Mar 2023 17:41:04 -0500 Subject: [PATCH 090/133] remove speculative auth error test --- test/auth_aws/test_auth_oidc.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 915b6d8b73..05345719a1 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -360,35 +360,6 @@ def request_token(principal, info, timeout): # Close the client. client.close() - def test_speculative_auth_failure(self): - # Clear the cache - _oidc_cache.clear() - token_file = os.path.join(self.token_dir, "test_user1") - - def request_token(principal, info, timeout): - with open(token_file) as fid: - token = fid.read() - return dict(access_token=token, expires_in_seconds=60) - - # Create a client with a request callback that returns a valid token - # that will expire soon. - props: Dict = dict(on_oidc_request_token=request_token) - client = MongoClient(self.uri_single, authmechanismproperties=props) - - # Set a fail point for saslStart commands. - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["isMaster", "saslStart"], "errorCode": 18}, - } - ): - # Perform a find operation that fails. - with self.assertRaises(OperationFailure): - client.test.test.find_one() - - # Close the client. - client.close() - def test_reauthenticate_succeeds(self): listener = EventListener() From 7beac7ae096ec0ca1e5b38ef8cb9a2e7178e74df Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 20 Mar 2023 13:18:20 -0500 Subject: [PATCH 091/133] Refactor to address review --- pymongo/auth.py | 248 +--------------------------- pymongo/auth_oidc.py | 279 ++++++++++++++++++++++++++++++++ test/auth_aws/test_auth_oidc.py | 4 +- test/utils.py | 7 +- 4 files changed, 288 insertions(+), 250 deletions(-) create mode 100644 pymongo/auth_oidc.py diff --git a/pymongo/auth.py b/pymongo/auth.py index ae9205e0ab..ddae161c7b 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -19,18 +19,15 @@ import hmac import os import socket -import threading from base64 import standard_b64decode, standard_b64encode from collections import namedtuple -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import Callable, Dict, Mapping, Optional +from typing import Callable, Mapping from urllib.parse import quote -import bson from bson.binary import Binary from bson.son import SON from pymongo.auth_aws import _authenticate_aws +from pymongo.auth_oidc import _authenticate_oidc, _OIDCContextMixin, _OIDCProperties from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -104,22 +101,6 @@ def __hash__(self): """Mechanism properties for MONGODB-AWS authentication.""" -_OIDCProperties = namedtuple( - "_OIDCProperties", - ["on_oidc_request_token", "on_oidc_refresh_token", "provider_name"], -) -"""Mechanism properties for MONGODB-OIDC authentication.""" - - -@dataclass -class _OIDCCache: - token_result: Optional[Dict] - token_exp_utc: Optional[datetime] - cache_exp_utc: datetime - server_resp: Optional[Dict] - lock: threading.Lock - - def _build_credentials_tuple(mech, source, user, passwd, extra, database): """Build and return a mechanism specific credentials tuple.""" if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: @@ -497,218 +478,6 @@ def _authenticate_mongo_cr(credentials, sock_info): sock_info.command(source, query) -# MONGODB-OIDC private variables. -_oidc_cache: Dict[str, _OIDCCache] = {} -_OIDC_TOKEN_BUFFER_MINUTES = 5 -_OIDC_CALLBACK_TIMEOUT_SECONDS = 5 * 60 -_OIDC_CACHE_TIMEOUT_MINUTES = 60 * 5 - - -def _oidc_get_cache_key(credentials, address): - # Handle authorization code credentials. - address = address - principal_name = credentials.username - properties = credentials.mechanism_properties - request_cb = properties.on_oidc_request_token - refresh_cb = properties.on_oidc_refresh_token - return f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" - - -def _authenticate_oidc(credentials, sock_info, reauthenticate): - """Authenticate using MONGODB-OIDC.""" - ctx = sock_info.auth_ctx - cmd = None - cache_key = _oidc_get_cache_key(credentials, sock_info.address) - in_cache = cache_key in _oidc_cache - - if ctx and ctx.speculate_succeeded(): - resp = ctx.speculative_authenticate - else: - cmd = _authenticate_oidc_start(credentials, sock_info.address) - try: - resp = sock_info.command(credentials.source, cmd) - except Exception: - _oidc_cache.pop(cache_key, None) - # Allow for one retry on reauthenticate when callbacks are in use - # and there was no cache. - if reauthenticate and not credentials.mechanism_properties.provider_name and in_cache: - return _authenticate_oidc(credentials, sock_info, False) - raise - - if resp["done"]: - return - - # Convert the server response to be more pythonic. - # Avoid circular import - from pymongo.common import camel_to_snake - - orig_server_resp: Dict = bson.decode(resp["payload"]) - server_resp = dict() - for key, value in orig_server_resp.items(): - server_resp[camel_to_snake(key)] = value - - if "token_endpoint" in server_resp: - _oidc_cache[cache_key].server_resp = server_resp - - conversation_id = resp["conversationId"] - token = _oidc_get_current_token(credentials, sock_info.address) - bin_payload = Binary(bson.encode(dict(jwt=token))) - cmd = SON( - [ - ("saslContinue", 1), - ("conversationId", conversation_id), - ("payload", bin_payload), - ] - ) - response = sock_info.command("$external", cmd) - if not response["done"]: - _oidc_cache.pop(cache_key, None) - raise OperationFailure("SASL conversation failed to complete.") - - -def _oidc_get_current_token(credentials, address, use_callbacks=True): - properties: _OIDCProperties = credentials.mechanism_properties - cache_key = _oidc_get_cache_key(credentials, address) - cache_value = _oidc_cache[cache_key] - principal_name = credentials.username - - request_cb = properties.on_oidc_request_token - refresh_cb = properties.on_oidc_refresh_token - if not use_callbacks: - request_cb = None - refresh_cb = None - - current_valid_token = False - if cache_value.token_exp_utc is not None: - now_utc = datetime.now(timezone.utc) - exp_utc = cache_value.token_exp_utc - buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 - if (exp_utc - now_utc).total_seconds() >= buffer_seconds: - current_valid_token = True - - timeout = _OIDC_CALLBACK_TIMEOUT_SECONDS - - if not use_callbacks and not current_valid_token: - return None - - if not current_valid_token: - with cache_value.lock: - if cache_value.token_result is None or refresh_cb is None: - cache_value.token_result = request_cb( - principal_name, cache_value.server_resp, timeout - ) - elif request_cb is not None: - cache_value.token_result = refresh_cb( - principal_name, cache_value.server_resp, cache_value.token_result, timeout - ) - cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=_OIDC_CACHE_TIMEOUT_MINUTES - ) - cache_value.cache_exp_utc = cache_exp_utc - - token_result = cache_value.token_result - if not isinstance(token_result, dict): - raise ValueError("OIDC callback returned invalid result") - if "access_token" not in token_result: - raise ValueError("OIDC callback did not return an access_token") - - token = token_result["access_token"] - if "expires_in_seconds" in token_result: - expires_in = int(token_result["expires_in_seconds"]) - buffer_seconds = _OIDC_TOKEN_BUFFER_MINUTES * 60 - if expires_in >= buffer_seconds: - now_utc = datetime.now(timezone.utc) - exp_utc = now_utc + timedelta(seconds=expires_in) - cache_value.token_exp_utc = exp_utc - - return token - - -def _invalidate_oidc_token(credentials, address): - cache_key = _oidc_get_cache_key(credentials, address) - cache_value = _oidc_cache.get(cache_key) - if cache_value: - cache_value.token_exp_utc = None - - -def _authenticate_oidc_start(credentials, address, use_callbacks=True): - properties: _OIDCProperties = credentials.mechanism_properties - - # Clear out old items in the cache. - now_utc = datetime.now(timezone.utc) - to_remove = [] - for key, value in _oidc_cache.items(): - if value.cache_exp_utc < now_utc: - to_remove.append(key) - for key in to_remove: - del _oidc_cache[key] - - # Handle aws provider credentials. - if properties.provider_name == "aws": - aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] - with open(aws_identity_file) as fid: - token = fid.read().strip() - payload = dict(jwt=token) - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ] - ) - return cmd - - cache_key = _oidc_get_cache_key(credentials, address) - cache_value = _oidc_cache.get(cache_key) - principal_name = credentials.username - - if cache_value is not None: - cache_value.cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=_OIDC_CACHE_TIMEOUT_MINUTES - ) - - if cache_value is None: - lock = threading.Lock() - cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=_OIDC_CACHE_TIMEOUT_MINUTES) - cache_value = _OIDCCache( - lock=lock, - token_result=None, - server_resp=None, - token_exp_utc=None, - cache_exp_utc=cache_exp_utc, - ) - _oidc_cache[cache_key] = cache_value - - if cache_value.server_resp is None: - # Send the SASL start with the optional principal name. - payload = dict() - - if principal_name: - payload["n"] = principal_name - - cmd = SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", Binary(bson.encode(payload))), - ("autoAuthorize", 1), - ] - ) - return cmd - - token = _oidc_get_current_token(credentials, address, use_callbacks) - if not token: - return None - bin_payload = Binary(bson.encode(dict(jwt=token))) - return SON( - [ - ("saslStart", 1), - ("mechanism", "MONGODB-OIDC"), - ("payload", bin_payload), - ] - ) - - def _authenticate_default(credentials, sock_info): if sock_info.max_wire_version >= 7: if sock_info.negotiated_mechs: @@ -787,13 +556,8 @@ def speculate_command(self): return cmd -class _OIDCContext(_AuthContext): - def speculate_command(self): - cmd = _authenticate_oidc_start(self.credentials, self.address, False) - if cmd is None: - return - cmd["db"] = self.credentials.source - return cmd +class _OIDCContext(_OIDCContextMixin, _AuthContext): + pass _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { @@ -809,12 +573,10 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate and mechanism == "MONGODB-OIDC": - _invalidate_oidc_token(credentials, sock_info.address) if reauthenticate and sock_info.performed_handshake: # Existing auth_ctx is stale, remove it. sock_info.auth_ctx = None if mechanism == "MONGODB-OIDC": - auth_func(credentials, sock_info, reauthenticate) + _authenticate_oidc(credentials, sock_info, reauthenticate) else: auth_func(credentials, sock_info) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py new file mode 100644 index 0000000000..beb4c290b0 --- /dev/null +++ b/pymongo/auth_oidc.py @@ -0,0 +1,279 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""""MONGODB-OIDC Authentication helpers.""" +import os +import threading +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Dict, Optional + +import bson +from bson.binary import Binary +from bson.son import SON +from pymongo.errors import OperationFailure + + +@dataclass +class _OIDCProperties: + on_oidc_request_token: Optional[Callable[..., Dict]] + on_oidc_refresh_token: Optional[Callable[..., Dict]] + provider_name: Optional[str] + + +"""Mechanism properties for MONGODB-OIDC authentication.""" + + +@dataclass +class _OIDCCache: + token_result: Optional[Dict] + token_exp_utc: Optional[datetime] + cache_exp_utc: datetime + server_resp: Optional[Dict] + lock: threading.Lock + + +class _OIDCMechanism: + cache: Dict[str, _OIDCCache] = {} + token_buffer_minutes = 5 + callback_timeout_seconds = 5 * 60 + cache_timeout_minutes = 60 * 5 + + def get_cache_key(self, credentials, address): + # Handle authorization code credentials. + address = address + principal_name = credentials.username + properties = credentials.mechanism_properties + request_cb = properties.on_oidc_request_token + refresh_cb = properties.on_oidc_refresh_token + return f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" + + def get_current_token(self, credentials, address, use_callbacks=True): + properties: _OIDCProperties = credentials.mechanism_properties + cache_key = self.get_cache_key(credentials, address) + cache_value = self.cache[cache_key] + principal_name = credentials.username + + request_cb = properties.on_oidc_request_token + refresh_cb = properties.on_oidc_refresh_token + if not use_callbacks: + request_cb = None + refresh_cb = None + + current_valid_token = False + if cache_value.token_exp_utc is not None: + now_utc = datetime.now(timezone.utc) + exp_utc = cache_value.token_exp_utc + buffer_seconds = self.token_buffer_minutes * 60 + if (exp_utc - now_utc).total_seconds() >= buffer_seconds: + current_valid_token = True + + timeout = self.callback_timeout_seconds + + if not use_callbacks and not current_valid_token: + return None + + if not current_valid_token and request_cb is not None: + with cache_value.lock: + if cache_value.token_result is None or refresh_cb is None: + cache_value.token_result = request_cb( + principal_name, cache_value.server_resp, timeout + ) + elif request_cb is not None: + cache_value.token_result = refresh_cb( + principal_name, cache_value.server_resp, cache_value.token_result, timeout + ) + cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=self.cache_timeout_minutes + ) + cache_value.cache_exp_utc = cache_exp_utc + + token_result = cache_value.token_result + if not isinstance(token_result, dict): + raise ValueError("OIDC callback returned invalid result") + if "access_token" not in token_result: + raise ValueError("OIDC callback did not return an access_token") + + token = token_result["access_token"] + if "expires_in_seconds" in token_result: + expires_in = int(token_result["expires_in_seconds"]) + buffer_seconds = self.token_buffer_minutes * 60 + if expires_in >= buffer_seconds: + now_utc = datetime.now(timezone.utc) + exp_utc = now_utc + timedelta(seconds=expires_in) + cache_value.token_exp_utc = exp_utc + + return token + + def invalidate_token(self, credentials, address): + cache_key = self.get_cache_key(credentials, address) + cache_value = self.cache.get(cache_key) + if cache_value: + cache_value.token_exp_utc = None + + def auth_start(self, credentials, address, use_callbacks=True): + properties: _OIDCProperties = credentials.mechanism_properties + + # Clear out old items in the cache. + now_utc = datetime.now(timezone.utc) + to_remove = [] + for key, value in self.cache.items(): + if value.cache_exp_utc < now_utc: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + # Handle aws provider credentials. + if properties.provider_name == "aws": + aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] + with open(aws_identity_file) as fid: + token = fid.read().strip() + payload = dict(jwt=token) + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ] + ) + return cmd + + cache_key = self.get_cache_key(credentials, address) + cache_value = self.cache.get(cache_key) + principal_name = credentials.username + + if cache_value is not None: + cache_value.cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=self.cache_timeout_minutes + ) + + if cache_value is None: + lock = threading.Lock() + cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=self.cache_timeout_minutes + ) + cache_value = _OIDCCache( + lock=lock, + token_result=None, + server_resp=None, + token_exp_utc=None, + cache_exp_utc=cache_exp_utc, + ) + self.cache[cache_key] = cache_value + + if cache_value.server_resp is None: + # Send the SASL start with the optional principal name. + payload = dict() + + if principal_name: + payload["n"] = principal_name + + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), + ] + ) + return cmd + + token = self.get_current_token(credentials, address, use_callbacks) + if not token: + return None + bin_payload = Binary(bson.encode(dict(jwt=token))) + return SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", bin_payload), + ] + ) + + def authenticate(self, credentials, sock_info, reauthenticate): + if reauthenticate: + self.invalidate_token(credentials, sock_info.address) + cache = self.cache + ctx = sock_info.auth_ctx + cmd = None + cache_key = self.get_cache_key(credentials, sock_info.address) + in_cache = cache_key in cache + + if ctx and ctx.speculate_succeeded(): + resp = ctx.speculative_authenticate + else: + cmd = self.auth_start(credentials, sock_info.address) + try: + resp = sock_info.command(credentials.source, cmd) + except Exception: + cache.pop(cache_key, None) + # Allow for one retry on reauthenticate when callbacks are in use + # and there was no cache. + if ( + reauthenticate + and not credentials.mechanism_properties.provider_name + and in_cache + ): + return _authenticate_oidc(credentials, sock_info, False) + raise + + if resp["done"]: + return + + # Convert the server response to be more pythonic. + # Avoid circular import + from pymongo.common import camel_to_snake + + orig_server_resp: Dict = bson.decode(resp["payload"]) + server_resp = dict() + for key, value in orig_server_resp.items(): + server_resp[camel_to_snake(key)] = value + + if "token_endpoint" in server_resp: + cache[cache_key].server_resp = server_resp + + conversation_id = resp["conversationId"] + token = self.get_current_token(credentials, sock_info.address) + bin_payload = Binary(bson.encode(dict(jwt=token))) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", conversation_id), + ("payload", bin_payload), + ] + ) + response = sock_info.command("$external", cmd) + if not response["done"]: + cache.pop(cache_key, None) + raise OperationFailure("SASL conversation failed to complete.") + + +_internal = _OIDCMechanism() + + +class _OIDCContextMixin: + credentials: Any + address: Any + + def speculate_command(self): + cmd = _internal.auth_start(self.credentials, self.address, False) + if cmd is None: + return + cmd["db"] = self.credentials.source # type:ignore + return cmd + + +def _authenticate_oidc(credentials, sock_info, reauthenticate): + """Authenticate using MONGODB-OIDC.""" + return _internal.authenticate(credentials, sock_info, reauthenticate) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 05345719a1..4841e724cf 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -26,7 +26,9 @@ from bson import SON from pymongo import MongoClient -from pymongo.auth import OperationFailure, _oidc_cache +from pymongo.auth_oidc import OperationFailure, _internal + +_oidc_cache = _internal.cache class TestAuthOIDC(unittest.TestCase): diff --git a/test/utils.py b/test/utils.py index 22588d0db5..88130e1d6d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -36,6 +36,7 @@ from bson.son import SON from pymongo import MongoClient, monitoring, operations, read_preferences from pymongo.collection import ReturnDocument +from pymongo.common import camel_to_snake # noqa from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -667,12 +668,6 @@ def get_command_line(client): return command_line -def camel_to_snake(camel: str) -> str: - # Regex to convert CamelCase to snake_case. - snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() - - def camel_to_upper_camel(camel): return camel[0].upper() + camel[1:] From 8c8088aeb018c09e1e077284c5cb95a87ff989c7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 20 Mar 2023 13:56:57 -0500 Subject: [PATCH 092/133] update for prose test clarifications --- pymongo/auth_oidc.py | 4 ++ test/auth_aws/test_auth_oidc.py | 98 +++++++++++++++++++++++++++++---- 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index beb4c290b0..45700a5ec0 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -105,6 +105,10 @@ def get_current_token(self, credentials, address, use_callbacks=True): if "access_token" not in token_result: raise ValueError("OIDC callback did not return an access_token") + for key in token_result: + if key not in ["access_token", "expires_in_seconds", "refresh_token"]: + raise ValueError(f"OIDC callback returned invalid result key '{key}'") + token = token_result["access_token"] if "expires_in_seconds" in token_result: expires_in = int(token_result["expires_in_seconds"]) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 4841e724cf..01991da283 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -52,11 +52,28 @@ def create_request_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) def request_token(principal, info, timeout): + # Validate the principal. + if principal is not None: + self.assertIsInstance(principal, str) + # Validate the info. - assert "authorization_endpoint" in info or "device_authorization_endpoint" in info - assert "token_endpoint" in info - assert "client_id" in info - assert timeout == 60 * 5 + if "device_authorization_endpoint" not in info: + self.assertIn("authorization_endpoint", info) + self.assertIn("token_endpoint", info) + self.assertIn("client_id", info) + for key in info: + self.assertIn( + key, + [ + "authorization_endpoint", + "token_endpoint", + "client_id", + "device_authorization_endpoint", + ], + ) + + # Validate the timeout. + self.assertEqual(timeout, 60 * 5) with open(token_file) as fid: token = fid.read() resp = dict(access_token=token) @@ -74,12 +91,34 @@ def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): def refresh_token(principal, info, creds, timeout): with open(token_file) as fid: token = fid.read() - # Validate the info. - assert "authorization_endpoint" in info - assert "token_endpoint" in info - assert "client_id" in info - assert timeout == 60 * 5 - assert "access_token" in creds + + # Validate the principal. + if principal is not None: + self.assertIsInstance(principal, str) + + # Validate the info. + if "device_authorization_endpoint" not in info: + self.assertIn("authorization_endpoint", info) + self.assertIn("token_endpoint", info) + self.assertIn("client_id", info) + for key in info: + self.assertIn( + key, + [ + "authorization_endpoint", + "token_endpoint", + "client_id", + "device_authorization_endpoint", + ], + ) + + # Validate the creds + self.assertIn("access_token", creds) + for key in creds: + self.assertIn(key, ["access_token", "expires_in_seconds", "refresh_token"]) + # Validate the timeout. + self.assertEqual(timeout, 60 * 5) + resp = dict(access_token=token) if expires_in_seconds is not None: resp["expires_in_seconds"] = expires_in_seconds @@ -158,6 +197,19 @@ def test_connect_aws_multiple_principal_user2(self): client.test.test.find_one() client.close() + def test_valid_callbacks(self): + request_cb = self.create_request_cb(expires_in_seconds=60) + refresh_cb = self.create_refresh_cb() + + props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + def test_request_callback_returns_null(self): def request_token_null(principal, info, timeout): return None @@ -188,7 +240,7 @@ def refresh_token_null(principal, info, creds, timeout): def test_request_callback_invalid_result(self): def request_token_invalid(principal, info, timeout): - return None + return dict() props: Dict = dict(on_oidc_request_token=request_token_invalid) client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -196,6 +248,15 @@ def request_token_invalid(principal, info, timeout): client.test.test.find_one() client.close() + def request_token_invalid2(principal, info, timeout): + return dict(access_token="foo", other="bar") + + props: Dict = dict(on_oidc_request_token=request_token_invalid2) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + def test_refresh_callback_invalid_result(self): request_cb = self.create_request_cb(expires_in_seconds=60) @@ -214,6 +275,21 @@ def refresh_cb_no_token(principal, info, cred, timeout): client.test.test.find_one() client.close() + def refresh_cb_extra_value(principal, info, cred, timeout): + return dict(access_token="foo", other="bar") + + props: Dict = dict( + on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb_extra_value + ) + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + def test_cache_with_refresh(self): # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. From 6c45619853be63e0f3ef2901fc52f0a48440dee1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 20 Mar 2023 14:10:54 -0500 Subject: [PATCH 093/133] lint --- pymongo/auth_oidc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 45700a5ec0..a559d7a64a 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -274,7 +274,7 @@ def speculate_command(self): cmd = _internal.auth_start(self.credentials, self.address, False) if cmd is None: return - cmd["db"] = self.credentials.source # type:ignore + cmd["db"] = self.credentials.source return cmd From 500c98fd2a951f057f34be77f456b51c246c2054 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 20 Mar 2023 14:19:27 -0500 Subject: [PATCH 094/133] allow for extra keys --- pymongo/auth_oidc.py | 4 --- test/auth_aws/test_auth_oidc.py | 62 ++++++++------------------------- 2 files changed, 15 insertions(+), 51 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index a559d7a64a..ff235b4e00 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -105,10 +105,6 @@ def get_current_token(self, credentials, address, use_callbacks=True): if "access_token" not in token_result: raise ValueError("OIDC callback did not return an access_token") - for key in token_result: - if key not in ["access_token", "expires_in_seconds", "refresh_token"]: - raise ValueError(f"OIDC callback returned invalid result key '{key}'") - token = token_result["access_token"] if "expires_in_seconds" in token_result: expires_in = int(token_result["expires_in_seconds"]) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 01991da283..1a5be1fd4c 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -61,16 +61,6 @@ def request_token(principal, info, timeout): self.assertIn("authorization_endpoint", info) self.assertIn("token_endpoint", info) self.assertIn("client_id", info) - for key in info: - self.assertIn( - key, - [ - "authorization_endpoint", - "token_endpoint", - "client_id", - "device_authorization_endpoint", - ], - ) # Validate the timeout. self.assertEqual(timeout, 60 * 5) @@ -101,21 +91,10 @@ def refresh_token(principal, info, creds, timeout): self.assertIn("authorization_endpoint", info) self.assertIn("token_endpoint", info) self.assertIn("client_id", info) - for key in info: - self.assertIn( - key, - [ - "authorization_endpoint", - "token_endpoint", - "client_id", - "device_authorization_endpoint", - ], - ) # Validate the creds self.assertIn("access_token", creds) - for key in creds: - self.assertIn(key, ["access_token", "expires_in_seconds", "refresh_token"]) + # Validate the timeout. self.assertEqual(timeout, 60 * 5) @@ -201,7 +180,20 @@ def test_valid_callbacks(self): request_cb = self.create_request_cb(expires_in_seconds=60) refresh_cb = self.create_refresh_cb() - props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + def request_cb_extra_value(principal, info, timeout): + result = request_cb(principal, info, timeout) + result["foo"] = "bar" + return result + + def refresh_cb_extra_value(principal, info, cred, timeout): + result = refresh_cb(principal, info, cred, timeout) + result["foo"] = "bar" + return result + + props: Dict = dict( + on_oidc_request_token=request_cb_extra_value, + on_oidc_refresh_token=refresh_cb_extra_value, + ) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() @@ -248,15 +240,6 @@ def request_token_invalid(principal, info, timeout): client.test.test.find_one() client.close() - def request_token_invalid2(principal, info, timeout): - return dict(access_token="foo", other="bar") - - props: Dict = dict(on_oidc_request_token=request_token_invalid2) - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - def test_refresh_callback_invalid_result(self): request_cb = self.create_request_cb(expires_in_seconds=60) @@ -275,21 +258,6 @@ def refresh_cb_no_token(principal, info, cred, timeout): client.test.test.find_one() client.close() - def refresh_cb_extra_value(principal, info, cred, timeout): - return dict(access_token="foo", other="bar") - - props: Dict = dict( - on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb_extra_value - ) - client = MongoClient(self.uri_single, authMechanismProperties=props) - client.test.test.find_one() - client.close() - - client = MongoClient(self.uri_single, authMechanismProperties=props) - with self.assertRaises(ValueError): - client.test.test.find_one() - client.close() - def test_cache_with_refresh(self): # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. From 11261b9963c9868fbb874ba4521590382c806a4c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 23 Mar 2023 16:19:51 -0500 Subject: [PATCH 095/133] validate callback results --- .evergreen/config.yml | 3 +-- pymongo/auth_oidc.py | 8 ++++++ test/auth_aws/test_auth_oidc.py | 47 ++++++++++++++++++++++++--------- 3 files changed, 43 insertions(+), 15 deletions(-) diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 78fe17cc4e..0a2c38b612 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -294,7 +294,6 @@ functions: REQUIRE_API_VERSION=${REQUIRE_API_VERSION} \ LOAD_BALANCER=${LOAD_BALANCER} \ bash ${DRIVERS_TOOLS}/.evergreen/run-orchestration.sh - # run-orchestration generates expansion file with the MONGODB_URI for the cluster - command: expansions.update params: @@ -571,6 +570,7 @@ functions: "iam_auth_assume_aws_account" : "${iam_auth_assume_aws_account}", "iam_auth_assume_aws_secret_access_key" : "${iam_auth_assume_aws_secret_access_key}", "iam_auth_assume_role_name" : "${iam_auth_assume_role_name}", + "iam_auth_ec2_instance_account" : "${iam_auth_ec2_instance_account}", "iam_auth_ec2_instance_secret_access_key" : "${iam_auth_ec2_instance_secret_access_key}", "iam_auth_ec2_instance_profile" : "${iam_auth_ec2_instance_profile}", @@ -2018,7 +2018,6 @@ tasks: - func: "run aws auth test with aws EC2 credentials" - func: "run aws auth test with aws web identity credentials" - func: "run aws ECS auth test" - - name: "aws-auth-test-rapid" commands: - func: "bootstrap mongo-orchestration" diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index ff235b4e00..f4e35dafb3 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -100,11 +100,19 @@ def get_current_token(self, credentials, address, use_callbacks=True): cache_value.cache_exp_utc = cache_exp_utc token_result = cache_value.token_result + + # Validate callback return value. if not isinstance(token_result, dict): raise ValueError("OIDC callback returned invalid result") + if "access_token" not in token_result: raise ValueError("OIDC callback did not return an access_token") + expected = ["access_token", "expires_in_seconds", "refesh_token"] + for key in token_result: + if key not in expected: + raise ValueError(f'Unexpected field in callback result "{key}"') + token = token_result["access_token"] if "expires_in_seconds" in token_result: expires_in = int(token_result["expires_in_seconds"]) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 1a5be1fd4c..9af276bee8 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -180,19 +180,9 @@ def test_valid_callbacks(self): request_cb = self.create_request_cb(expires_in_seconds=60) refresh_cb = self.create_refresh_cb() - def request_cb_extra_value(principal, info, timeout): - result = request_cb(principal, info, timeout) - result["foo"] = "bar" - return result - - def refresh_cb_extra_value(principal, info, cred, timeout): - result = refresh_cb(principal, info, cred, timeout) - result["foo"] = "bar" - return result - props: Dict = dict( - on_oidc_request_token=request_cb_extra_value, - on_oidc_refresh_token=refresh_cb_extra_value, + on_oidc_request_token=request_cb, + on_oidc_refresh_token=refresh_cb, ) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() @@ -240,7 +230,18 @@ def request_token_invalid(principal, info, timeout): client.test.test.find_one() client.close() - def test_refresh_callback_invalid_result(self): + def request_cb_extra_value(principal, info, timeout): + result = self.create_request_cb()(principal, info, timeout) + result["foo"] = "bar" + return result + + props: Dict = dict(on_oidc_request_token=request_cb_extra_value) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_refresh_callback_missing_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) def refresh_cb_no_token(principal, info, cred, timeout): @@ -258,6 +259,26 @@ def refresh_cb_no_token(principal, info, cred, timeout): client.test.test.find_one() client.close() + def test_refresh_callback_extra_data(self): + request_cb = self.create_request_cb(expires_in_seconds=60) + + def refresh_cb_extra_value(principal, info, cred, timeout): + result = self.create_refresh_cb()(principal, info, cred, timeout) + result["foo"] = "bar" + return result + + props: Dict = dict( + on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb_extra_value + ) + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + def test_cache_with_refresh(self): # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. From f4ebaf6d33d1ff94325dcbff42a5fb2d2b956cc4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 30 Mar 2023 14:54:21 -0500 Subject: [PATCH 096/133] updates for security mitigations --- pymongo/auth.py | 1 + pymongo/auth_oidc.py | 19 ++++++++++++++++--- pymongo/common.py | 1 + pymongo/mongo_client.py | 12 ++++++++++++ test/auth_aws/test_auth_oidc.py | 21 ++++++++++++++------- 5 files changed, 44 insertions(+), 10 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index ddae161c7b..f5acca8e46 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -154,6 +154,7 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): on_oidc_request_token=on_oidc_request_token, on_oidc_refresh_token=on_oidc_refresh_token, provider_name=provider_name, + allowed_hosts=extra["authOIDCAllowedHosts"], ) return MongoCredential(mech, "$external", user, passwd, oidc_props, None) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index f4e35dafb3..e9946e116d 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -13,16 +13,17 @@ # limitations under the License. """"MONGODB-OIDC Authentication helpers.""" +import fnmatch import os import threading from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional import bson from bson.binary import Binary from bson.son import SON -from pymongo.errors import OperationFailure +from pymongo.errors import OperationFailure, PyMongoError @dataclass @@ -30,6 +31,7 @@ class _OIDCProperties: on_oidc_request_token: Optional[Callable[..., Dict]] on_oidc_refresh_token: Optional[Callable[..., Dict]] provider_name: Optional[str] + allowed_hosts: List[str] """Mechanism properties for MONGODB-OIDC authentication.""" @@ -84,6 +86,17 @@ def get_current_token(self, credentials, address, use_callbacks=True): if not use_callbacks and not current_valid_token: return None + # Ensure that the desired address is allowed. + found = False + allowed_hosts = credentials.mechanism_properties.allowed_hosts + for patt in allowed_hosts: + if fnmatch.fnmatch(address[0], patt): + found = True + if not found: + raise PyMongoError( + f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" + ) + if not current_valid_token and request_cb is not None: with cache_value.lock: if cache_value.token_result is None or refresh_cb is None: @@ -248,7 +261,7 @@ def authenticate(self, credentials, sock_info, reauthenticate): for key, value in orig_server_resp.items(): server_resp[camel_to_snake(key)] = value - if "token_endpoint" in server_resp: + if "issuer" in server_resp: cache[cache_key].server_resp = server_resp conversation_id = resp["conversationId"] diff --git a/pymongo/common.py b/pymongo/common.py index cdb2d66285..3d1b4d34de 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -755,6 +755,7 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo "password": validate_string_or_none, "server_selector": validate_is_callable_or_none, "auto_encryption_opts": validate_auto_encryption_opts_or_none, + "authoidcallowedhosts": validate_list, } # Dictionary where keys are any URI option name, and values are the diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 46876e863e..d110f434f6 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -477,6 +477,12 @@ def __init__( name>'. To specify the session token for MONGODB-AWS authentication pass ``authMechanismProperties='AWS_SESSION_TOKEN:'``. + - `authOIDCAllowedHosts``: A list of patterns that are used to match + allowed server hosts for OIDC callbacks. By default a standard + set of Atlas hosts and ``localhost`` are allowed. If an attempt + is made to connect to a host that is not allowed, a + ``PyMongoError`` will be raised and the OIDC callback(s) will not + be called. .. seealso:: :doc:`/examples/authentication` @@ -719,6 +725,12 @@ def __init__( keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class + # Use a default secure value for authOIDCAllowedHosts. + keyword_opts.setdefault( + "authOIDCAllowedHosts", + ["*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost"], + ) + seeds = set() username = None password = None diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 9af276bee8..3ff2ac13e4 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -26,7 +26,8 @@ from bson import SON from pymongo import MongoClient -from pymongo.auth_oidc import OperationFailure, _internal +from pymongo.auth_oidc import _internal +from pymongo.errors import OperationFailure, PyMongoError _oidc_cache = _internal.cache @@ -57,9 +58,7 @@ def request_token(principal, info, timeout): self.assertIsInstance(principal, str) # Validate the info. - if "device_authorization_endpoint" not in info: - self.assertIn("authorization_endpoint", info) - self.assertIn("token_endpoint", info) + self.assertIn("issuer", info) self.assertIn("client_id", info) # Validate the timeout. @@ -87,9 +86,7 @@ def refresh_token(principal, info, creds, timeout): self.assertIsInstance(principal, str) # Validate the info. - if "device_authorization_endpoint" not in info: - self.assertIn("authorization_endpoint", info) - self.assertIn("token_endpoint", info) + self.assertIn("issuer", info) self.assertIn("client_id", info) # Validate the creds @@ -157,6 +154,16 @@ def test_connect_callbacks_multiple_no_username(self): client.test.test.find_one() client.close() + def test_oidc_allowed_hosts_blocked(self): + request_token = self.create_request_cb() + props: Dict = dict(on_oidc_request_token=request_token) + client = MongoClient( + self.uri_single, authOIDCAllowedHosts=[], authmechanismproperties=props + ) + with self.assertRaises(PyMongoError): + client.test.test.find_one() + client.close() + def test_connect_aws_single_principal(self): props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_single, authmechanismproperties=props) From f5848858509b3e613b03e54d3cc7f957e9379fad Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 31 Mar 2023 10:37:29 -0500 Subject: [PATCH 097/133] make allowed_hosts a mechanism property --- doc/examples/authentication.rst | 47 ++++++++++++++++--------------- pymongo/auth.py | 16 ++++++----- pymongo/auth_oidc.py | 12 ++++---- pymongo/common.py | 10 ++++--- pymongo/mongo_client.py | 12 -------- test/auth_aws/test_auth_oidc.py | 50 ++++++++++++++++----------------- 6 files changed, 70 insertions(+), 77 deletions(-) diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index 8aaa55f0c5..0f0b65a35b 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -417,7 +417,7 @@ Callback-based OIDC Support ~~~~~~~~~~~~~~~~~~~~~~~~~~~ PyMongo supports user-provided callbacks for OIDC, which are are given to the -``MongoClient``. The ``on_oidc_request_callback`` is intended to accept +``MongoClient``. The ``request_token_callback`` is intended to accept information about the Identity Provider, and return credentials that are used to authenticate with the server, usually through a browser interaction with the user. The callback must be of the form:: @@ -428,24 +428,12 @@ the user. The callback must be of the form:: Where ``ProviderInfo`` is a dictionary of the following form:: - authorization_endpoint: + issuer: description: >- - URL where the IDP may be contacted for end user - authentication and authorization code generation. + URL which describes the Authorization Server. This identifier should be + the iss of provided access tokens, and be viable for RFC8414 + metadata discovery and RFC9207 identification. type: string - optional: true # Req if device_authorization_endpoint not present - token_endpoint: - description: >- - URL where the IDP may be contacted for authorization - code <=> ID/access token exchange. - type: string - optional: true # Req if device_authorization_endpoint not present - device_authorization_endpoint: - description: >- - URL where the IDP may be contacted for device - authentication and authorization code generation. - type: string - optional: true # Req if authorization_endpoint not present client_id: description: "Unique client ID for this OIDC client" type: string @@ -477,7 +465,7 @@ client would be:: >>> from pymongo import MongoClient >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") - >>> client = MongoClient(uri, on_oidc_request_callback=my_callback) + >>> client = MongoClient(uri, request_token_callback=my_callback) If the identity provider supports refresh, a refresh callback can also be provided. If a refresh callback is provided, it will be called @@ -497,8 +485,8 @@ if it was provided. An example using both callbacks would be:: >>> from pymongo import MongoClient >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") - >>> client = MongoClient(uri, on_oidc_request_callback=my_request_callback, - ... on_oidc_refresh_callback=my_refresh_callback) + >>> client = MongoClient(uri, request_token_callback=my_request_callback, + ... refresh_token_callback=my_refresh_callback) Note: when multiple identity providers are configured on the server, a ``username`` must be provided, which is the @@ -506,5 +494,20 @@ Principal Name used on the provider. For example:: >>> from pymongo import MongoClient >>> uri = "mongodb://my_username@localhost/?authMechanism=MONGODB-OIDC") - >>> client = MongoClient(uri, on_oidc_request_callback=my_request_callback, - ... on_oidc_refresh_callback=my_refresh_callback) + >>> client = MongoClient(uri, authMechanismProperties=request_token_callback=my_request_callback, + ... refresh_token_callback=my_refresh_callback) + +Allowed Hosts +~~~~~~~~~~~~~ + +By default, only Atlas and localhost MongoDB server hostnames are allowed +to use MONGODB-OIDC. If a different set of host(s) are desired, use +the ``allowed_hosts`` option: + + >>> from pymongo import MongoClient + >>> uri = "mongodb://my_username@localhost/?authMechanism=MONGODB-OIDC") + >>> my_host = "example.com" + >>> client = MongoClient(uri, authMechanismProperties=request_token_callback=my_request_callback, + ... refresh_token_callback=my_refresh_callback, allowed_hosts=[my_host]) + +The hostnames support Unix shell-style wildcards. diff --git a/pymongo/auth.py b/pymongo/auth.py index f5acca8e46..c3edf7679a 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -143,18 +143,20 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): if source is not None and source != "$external": raise ValueError("authentication source must be $external or None for MONGODB-OIDC") properties = extra.get("authmechanismproperties", {}) - on_oidc_request_token = properties.get("on_oidc_request_token") - on_oidc_refresh_token = properties.get("on_oidc_refresh_token", None) + request_token_callback = properties.get("request_token_callback") + refresh_token_callback = properties.get("refresh_token_callback", None) provider_name = properties.get("PROVIDER_NAME", "") - if not on_oidc_request_token and provider_name != "aws": + default_allowed = ["*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost"] + allowed_hosts = properties.get("allowed_hosts", default_allowed) + if not request_token_callback and provider_name != "aws": raise ConfigurationError( - "authentication with MONGODB-OIDC requires providing an on_oidc_request_token or a provider_name of 'aws'" + "authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'" ) oidc_props = _OIDCProperties( - on_oidc_request_token=on_oidc_request_token, - on_oidc_refresh_token=on_oidc_refresh_token, + request_token_callback=request_token_callback, + refresh_token_callback=refresh_token_callback, provider_name=provider_name, - allowed_hosts=extra["authOIDCAllowedHosts"], + allowed_hosts=allowed_hosts, ) return MongoCredential(mech, "$external", user, passwd, oidc_props, None) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index e9946e116d..5eca6f6421 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -28,8 +28,8 @@ @dataclass class _OIDCProperties: - on_oidc_request_token: Optional[Callable[..., Dict]] - on_oidc_refresh_token: Optional[Callable[..., Dict]] + request_token_callback: Optional[Callable[..., Dict]] + refresh_token_callback: Optional[Callable[..., Dict]] provider_name: Optional[str] allowed_hosts: List[str] @@ -57,8 +57,8 @@ def get_cache_key(self, credentials, address): address = address principal_name = credentials.username properties = credentials.mechanism_properties - request_cb = properties.on_oidc_request_token - refresh_cb = properties.on_oidc_refresh_token + request_cb = properties.request_token_callback + refresh_cb = properties.refresh_token_callback return f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" def get_current_token(self, credentials, address, use_callbacks=True): @@ -67,8 +67,8 @@ def get_current_token(self, credentials, address, use_callbacks=True): cache_value = self.cache[cache_key] principal_name = credentials.username - request_cb = properties.on_oidc_request_token - refresh_cb = properties.on_oidc_refresh_token + request_cb = properties.request_token_callback + refresh_cb = properties.refresh_token_callback if not use_callbacks: request_cb = None refresh_cb = None diff --git a/pymongo/common.py b/pymongo/common.py index 3d1b4d34de..7c1d70d6c1 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -445,21 +445,23 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni props[key] = value elif isinstance(value, bool): props[key] = str(value).lower() + elif key in ["allowed_hosts"] and isinstance(value, list): + props[key] = value elif inspect.isfunction(value): signature = inspect.signature(value) - if key == "on_oidc_request_token": + if key == "request_token_callback": expected_params = 3 - elif key == "on_oidc_refresh_token": + elif key == "refresh_token_callback": expected_params = 4 else: - raise ValueError(f"Unrecognized Auth mechanisim function {key}") + raise ValueError(f"Unrecognized Auth mechanism function {key}") if len(signature.parameters) != expected_params: msg = f"{key} must accept {expected_params} parameters" raise ValueError(msg) props[key] = value else: raise ValueError( - "Auth mechanisim properity values must be strings or callback functions" + "Auth mechanism properity values must be strings or callback functions" ) return props diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index d110f434f6..46876e863e 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -477,12 +477,6 @@ def __init__( name>'. To specify the session token for MONGODB-AWS authentication pass ``authMechanismProperties='AWS_SESSION_TOKEN:'``. - - `authOIDCAllowedHosts``: A list of patterns that are used to match - allowed server hosts for OIDC callbacks. By default a standard - set of Atlas hosts and ``localhost`` are allowed. If an attempt - is made to connect to a host that is not allowed, a - ``PyMongoError`` will be raised and the OIDC callback(s) will not - be called. .. seealso:: :doc:`/examples/authentication` @@ -725,12 +719,6 @@ def __init__( keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts["document_class"] = doc_class - # Use a default secure value for authOIDCAllowedHosts. - keyword_opts.setdefault( - "authOIDCAllowedHosts", - ["*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost"], - ) - seeds = set() username = None password = None diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 3ff2ac13e4..6cc584983d 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -116,21 +116,21 @@ def fail_point(self, command_args): def test_connect_callbacks_single_implicit_username(self): request_token = self.create_request_cb() - props: Dict = dict(on_oidc_request_token=request_token) + props: Dict = dict(request_token_callback=request_token) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_callbacks_single_explicit_username(self): request_token = self.create_request_cb() - props: Dict = dict(on_oidc_request_token=request_token) + props: Dict = dict(request_token_callback=request_token) client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_callbacks_multiple_principal_user1(self): request_token = self.create_request_cb() - props: Dict = dict(on_oidc_request_token=request_token) + props: Dict = dict(request_token_callback=request_token) client = MongoClient( self.uri_multiple, username="test_user1", authmechanismproperties=props ) @@ -139,7 +139,7 @@ def test_connect_callbacks_multiple_principal_user1(self): def test_connect_callbacks_multiple_principal_user2(self): request_token = self.create_request_cb("test_user2") - props: Dict = dict(on_oidc_request_token=request_token) + props: Dict = dict(request_token_callback=request_token) client = MongoClient( self.uri_multiple, username="test_user2", authmechanismproperties=props ) @@ -148,7 +148,7 @@ def test_connect_callbacks_multiple_principal_user2(self): def test_connect_callbacks_multiple_no_username(self): request_token = self.create_request_cb() - props: Dict = dict(on_oidc_request_token=request_token) + props: Dict = dict(request_token_callback=request_token) client = MongoClient(self.uri_multiple, authmechanismproperties=props) with self.assertRaises(OperationFailure): client.test.test.find_one() @@ -156,10 +156,8 @@ def test_connect_callbacks_multiple_no_username(self): def test_oidc_allowed_hosts_blocked(self): request_token = self.create_request_cb() - props: Dict = dict(on_oidc_request_token=request_token) - client = MongoClient( - self.uri_single, authOIDCAllowedHosts=[], authmechanismproperties=props - ) + props: Dict = dict(request_token_callback=request_token, allowed_hosts=[]) + client = MongoClient(self.uri_single, authmechanismproperties=props) with self.assertRaises(PyMongoError): client.test.test.find_one() client.close() @@ -188,8 +186,8 @@ def test_valid_callbacks(self): refresh_cb = self.create_refresh_cb() props: Dict = dict( - on_oidc_request_token=request_cb, - on_oidc_refresh_token=refresh_cb, + request_token_callback=request_cb, + refresh_token_callback=refresh_cb, ) client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() @@ -203,7 +201,7 @@ def test_request_callback_returns_null(self): def request_token_null(principal, info, timeout): return None - props: Dict = dict(on_oidc_request_token=request_token_null) + props: Dict = dict(request_token_callback=request_token_null) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -216,7 +214,7 @@ def refresh_token_null(principal, info, creds, timeout): return None props: Dict = dict( - on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_token_null + request_token_callback=request_cb, refresh_token_callback=refresh_token_null ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() @@ -231,7 +229,7 @@ def test_request_callback_invalid_result(self): def request_token_invalid(principal, info, timeout): return dict() - props: Dict = dict(on_oidc_request_token=request_token_invalid) + props: Dict = dict(request_token_callback=request_token_invalid) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -242,7 +240,7 @@ def request_cb_extra_value(principal, info, timeout): result["foo"] = "bar" return result - props: Dict = dict(on_oidc_request_token=request_cb_extra_value) + props: Dict = dict(request_token_callback=request_cb_extra_value) client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -255,7 +253,7 @@ def refresh_cb_no_token(principal, info, cred, timeout): return dict() props: Dict = dict( - on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb_no_token + request_token_callback=request_cb, refresh_token_callback=refresh_cb_no_token ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() @@ -275,7 +273,7 @@ def refresh_cb_extra_value(principal, info, cred, timeout): return result props: Dict = dict( - on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb_extra_value + request_token_callback=request_cb, refresh_token_callback=refresh_cb_extra_value ) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() @@ -293,7 +291,7 @@ def test_cache_with_refresh(self): request_cb = self.create_request_cb(expires_in_seconds=60) refresh_cb = self.create_refresh_cb() - props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) # Ensure that a ``find`` operation adds credentials to the cache. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -316,7 +314,7 @@ def test_cache_with_no_refresh(self): # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. request_cb = self.create_request_cb() - props = dict(on_oidc_request_token=request_cb) + props = dict(request_token_callback=request_cb) client = MongoClient(self.uri_single, authMechanismProperties=props) # Ensure that a ``find`` operation adds credentials to the cache. @@ -337,7 +335,7 @@ def test_cache_with_no_refresh(self): def test_cache_key_includes_callback(self): request_cb = self.create_request_cb() - props: Dict = dict(on_oidc_request_token=request_cb) + props: Dict = dict(request_token_callback=request_cb) # Ensure that a ``find`` operation adds a new entry to the cache. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -348,7 +346,7 @@ def test_cache_key_includes_callback(self): def request_token_2(principal, info, timeout): return request_cb(principal, info, timeout) - props["on_oidc_request_token"] = request_token_2 + props["request_token_callback"] = request_token_2 client = MongoClient(self.uri_single, authMechanismProperties=props) # Ensure that a ``find`` operation adds a new entry to the cache. @@ -364,7 +362,7 @@ def refresh_cb(principal, info, creds, timeout): return dict(access_token="bad") # Add a token to the cache that will expire soon. - props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -402,7 +400,7 @@ def request_token(principal, info, timeout): # Create a client with a request callback that returns a valid token # that will not expire soon. - props: Dict = dict(on_oidc_request_token=request_token) + props: Dict = dict(request_token_callback=request_token) client = MongoClient(self.uri_single, authmechanismproperties=props) # Set a fail point for saslStart commands. @@ -443,7 +441,7 @@ def test_reauthenticate_succeeds(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -492,7 +490,7 @@ def test_reauthenticate_retries_and_succees_with_cache(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -522,7 +520,7 @@ def test_reauthenticate_fails_with_no_cache(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(on_oidc_request_token=request_cb, on_oidc_refresh_token=refresh_cb) + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) From beb2b24e89a1fc93b7c64d2411a5da9ad3ee5cd9 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 31 Mar 2023 10:46:13 -0500 Subject: [PATCH 098/133] fix auth spec test --- test/test_auth_spec.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 6dc2b5a253..0a0adc4490 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -46,10 +46,10 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["on_oidc_request_token"] = lambda x, y, z: 1 + props["request_token_callback"] = lambda x, y, z: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["on_oidc_refresh_token"] = lambda a, b, c, d: 1 # type:ignore[misc] + props["refresh_token_callback"] = lambda a, b, c, d: 1 # type:ignore[misc] del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials @@ -82,13 +82,13 @@ def run_test(self): ) elif "PROVIDER_NAME" in expected: self.assertEqual(actual.provider_name, expected["PROVIDER_NAME"]) - elif "on_oidc_request_token" in expected: + elif "request_token_callback" in expected: self.assertEqual( - actual.on_oidc_request_token, expected["on_oidc_request_token"] + actual.request_token_callback, expected["request_token_callback"] ) - elif "on_oidc_refresh_token" in expected: + elif "refresh_token_callback" in expected: self.assertEqual( - actual.on_oidc_refresh_token, expected["on_oidc_refresh_token"] + actual.refresh_token_callback, expected["refresh_token_callback"] ) else: self.fail("Unhandled property: %s" % (key,)) From 6d3ebe2e2bd509da4002a1dd35159f0247c1534c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 3 Apr 2023 11:54:56 -0500 Subject: [PATCH 099/133] Merge principal name and timeout into client info object --- pymongo/auth_oidc.py | 7 +++-- test/auth_aws/test_auth_oidc.py | 48 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 5eca6f6421..53eb138243 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -98,14 +98,13 @@ def get_current_token(self, credentials, address, use_callbacks=True): ) if not current_valid_token and request_cb is not None: + client_info = dict(principal_name=principal_name, timeout_seconds=timeout) with cache_value.lock: if cache_value.token_result is None or refresh_cb is None: - cache_value.token_result = request_cb( - principal_name, cache_value.server_resp, timeout - ) + cache_value.token_result = request_cb(client_info, cache_value.server_resp) elif request_cb is not None: cache_value.token_result = refresh_cb( - principal_name, cache_value.server_resp, cache_value.token_result, timeout + client_info, cache_value.server_resp, cache_value.token_result ) cache_exp_utc = datetime.now(timezone.utc) + timedelta( minutes=self.cache_timeout_minutes diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 6cc584983d..10724c05ed 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -52,17 +52,17 @@ def create_request_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) - def request_token(principal, info, timeout): + def request_token(client_info, server_info): # Validate the principal. - if principal is not None: - self.assertIsInstance(principal, str) + if client_info["principal_name"] is not None: + self.assertIsInstance(client_info["principal_name"], str) # Validate the info. - self.assertIn("issuer", info) - self.assertIn("client_id", info) + self.assertIn("issuer", server_info) + self.assertIn("client_id", server_info) # Validate the timeout. - self.assertEqual(timeout, 60 * 5) + self.assertEqual(client_info["timeout_seconds"], 60 * 5) with open(token_file) as fid: token = fid.read() resp = dict(access_token=token) @@ -77,23 +77,23 @@ def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) - def refresh_token(principal, info, creds, timeout): + def refresh_token(client_info, server_info, creds): with open(token_file) as fid: token = fid.read() # Validate the principal. - if principal is not None: - self.assertIsInstance(principal, str) + if client_info["principal_name"] is not None: + self.assertIsInstance(client_info["principal_name"], str) # Validate the info. - self.assertIn("issuer", info) - self.assertIn("client_id", info) + self.assertIn("issuer", server_info) + self.assertIn("client_id", server_info) # Validate the creds self.assertIn("access_token", creds) # Validate the timeout. - self.assertEqual(timeout, 60 * 5) + self.assertEqual(client_info["timeout_seconds"], 60 * 5) resp = dict(access_token=token) if expires_in_seconds is not None: @@ -198,7 +198,7 @@ def test_valid_callbacks(self): client.close() def test_request_callback_returns_null(self): - def request_token_null(principal, info, timeout): + def request_token_null(client_info, server_info): return None props: Dict = dict(request_token_callback=request_token_null) @@ -210,7 +210,7 @@ def request_token_null(principal, info, timeout): def test_refresh_callback_returns_null(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_token_null(principal, info, creds, timeout): + def refresh_token_null(client_info, server_info, creds): return None props: Dict = dict( @@ -226,7 +226,7 @@ def refresh_token_null(principal, info, creds, timeout): client.close() def test_request_callback_invalid_result(self): - def request_token_invalid(principal, info, timeout): + def request_token_invalid(client_info, server_info): return dict() props: Dict = dict(request_token_callback=request_token_invalid) @@ -235,8 +235,8 @@ def request_token_invalid(principal, info, timeout): client.test.test.find_one() client.close() - def request_cb_extra_value(principal, info, timeout): - result = self.create_request_cb()(principal, info, timeout) + def request_cb_extra_value(client_info, server_info): + result = self.create_request_cb()(client_info, server_info) result["foo"] = "bar" return result @@ -249,7 +249,7 @@ def request_cb_extra_value(principal, info, timeout): def test_refresh_callback_missing_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_cb_no_token(principal, info, cred, timeout): + def refresh_cb_no_token(client_info, server_info, cred): return dict() props: Dict = dict( @@ -267,8 +267,8 @@ def refresh_cb_no_token(principal, info, cred, timeout): def test_refresh_callback_extra_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_cb_extra_value(principal, info, cred, timeout): - result = self.create_refresh_cb()(principal, info, cred, timeout) + def refresh_cb_extra_value(client_info, server_info, cred): + result = self.create_refresh_cb()(client_info, server_info, cred) result["foo"] = "bar" return result @@ -343,8 +343,8 @@ def test_cache_key_includes_callback(self): client.close() # Create a new client with a different request callback. - def request_token_2(principal, info, timeout): - return request_cb(principal, info, timeout) + def request_token_2(client_info, server_info): + return request_cb(client_info, server_info) props["request_token_callback"] = request_token_2 client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -358,7 +358,7 @@ def test_cache_clears_on_error(self): request_cb = self.create_request_cb() # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. - def refresh_cb(principal, info, creds, timeout): + def refresh_cb(client_info, server_info, creds): return dict(access_token="bad") # Add a token to the cache that will expire soon. @@ -393,7 +393,7 @@ def test_speculative_auth_success(self): _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user1") - def request_token(principal, info, timeout): + def request_token(client_info, server_info): with open(token_file) as fid: token = fid.read() return dict(access_token=token, expires_in_seconds=1000) From e46047d2ac4b26f167e73f93038041194ea82bd4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 3 Apr 2023 12:00:40 -0500 Subject: [PATCH 100/133] fix validation --- pymongo/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/common.py b/pymongo/common.py index 7c1d70d6c1..5dd39aea5b 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -450,9 +450,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni elif inspect.isfunction(value): signature = inspect.signature(value) if key == "request_token_callback": - expected_params = 3 + expected_params = 2 elif key == "refresh_token_callback": - expected_params = 4 + expected_params = 3 else: raise ValueError(f"Unrecognized Auth mechanism function {key}") if len(signature.parameters) != expected_params: From 3454babfa27c574d5ef5c08550ad19aa021fbf7c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 3 Apr 2023 12:06:49 -0500 Subject: [PATCH 101/133] fix auth spec test --- test/test_auth_spec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 0a0adc4490..3548e85d20 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -46,10 +46,10 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["request_token_callback"] = lambda x, y, z: 1 + props["request_token_callback"] = lambda x, y: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["refresh_token_callback"] = lambda a, b, c, d: 1 # type:ignore[misc] + props["refresh_token_callback"] = lambda a, b, c: 1 # type:ignore[misc] del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From 8e8912d518a1811c8b229e5c2024afff9df1a1c3 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 4 Apr 2023 16:57:21 -0500 Subject: [PATCH 102/133] strip sasl commands --- test/auth_aws/test_auth_oidc.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 10724c05ed..bb02453909 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -463,19 +463,24 @@ def test_reauthenticate_succeeds(self): # Perform a find operation. client.test.test.find_one() - started_events = [i.command_name for i in listener.started_events] - succeeded_events = [i.command_name for i in listener.succeeded_events] - failed_events = [i.command_name for i in listener.failed_events] + started_events = [ + i.command_name for i in listener.started_events if not i.command_name.startswith("sasl") + ] + succeeded_events = [ + i.command_name + for i in listener.succeeded_events + if not i.command_name.startswith("sasl") + ] + failed_events = [ + i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl") + ] assert started_events == [ "find", - "saslStart", - "saslStart", - "saslContinue", "find", ], started_events - assert succeeded_events == ["saslStart", "saslContinue", "find"], succeeded_events - assert failed_events == ["find", "saslStart"], failed_events + assert succeeded_events == ["find"], succeeded_events + assert failed_events == ["find"], failed_events # Assert that the refresh callback has been called. self.assertEqual(self.refresh_called, 1) From 2133125522702421380c062f58eeeef5c2157353 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 5 Apr 2023 11:03:05 -0500 Subject: [PATCH 103/133] clean up handling of allowed hosts --- pymongo/auth.py | 8 +++++++- test/auth_aws/test_auth_oidc.py | 14 +++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index c3edf7679a..8ba48e25fc 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -146,7 +146,13 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): request_token_callback = properties.get("request_token_callback") refresh_token_callback = properties.get("refresh_token_callback", None) provider_name = properties.get("PROVIDER_NAME", "") - default_allowed = ["*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost"] + default_allowed = [ + "*.mongodb.net", + "*.mongodb-dev.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + ] allowed_hosts = properties.get("allowed_hosts", default_allowed) if not request_token_callback and provider_name != "aws": raise ConfigurationError( diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index bb02453909..80a1c7e468 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -154,7 +154,7 @@ def test_connect_callbacks_multiple_no_username(self): client.test.test.find_one() client.close() - def test_oidc_allowed_hosts_blocked(self): + def test_allowed_hosts_blocked(self): request_token = self.create_request_cb() props: Dict = dict(request_token_callback=request_token, allowed_hosts=[]) client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -162,6 +162,12 @@ def test_oidc_allowed_hosts_blocked(self): client.test.test.find_one() client.close() + props: Dict = dict(request_token_callback=request_token, allowed_hosts=["localhost1"]) + client = MongoClient(self.uri_single, authmechanismproperties=props) + with self.assertRaises(PyMongoError): + client.test.test.find_one() + client.close() + def test_connect_aws_single_principal(self): props = dict(PROVIDER_NAME="aws") client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -181,6 +187,12 @@ def test_connect_aws_multiple_principal_user2(self): client.test.test.find_one() client.close() + def test_connect_aws_allowed_hosts_ignored(self): + props = dict(PROVIDER_NAME="aws", allowed_hosts=[]) + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + def test_valid_callbacks(self): request_cb = self.create_request_cb(expires_in_seconds=60) refresh_cb = self.create_refresh_cb() From e9024ec107e19fa4195275c84651bb9d41a24ce3 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 6 Apr 2023 14:10:36 -0500 Subject: [PATCH 104/133] add another allowed_host and update docs --- doc/changelog.rst | 2 +- doc/examples/authentication.rst | 21 +-------------------- pymongo/auth.py | 1 + 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 465a0f2ae4..4567713d6d 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,7 +4,7 @@ Changelog Changes in Version 4.4 ----------------------- -- Added support for :ref:`OIDC ` authentication with MongoDB Enterprise 7.0+. +- Added support for :ref:`OIDC ` authentication with MongoDB Enterprise 7.0+ in Public Preview. - Added support for passing a list containing (key, direction) pairs or keys to :meth:`~pymongo.collection.Collection.create_index`. - **BETA** Added support for range queries on client side field level encrypted collections. diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index 0f0b65a35b..6e89e93269 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -392,26 +392,7 @@ MONGODB-OIDC ------------ .. versionadded:: 4.4 -The MONGODB-OIDC authentication mechanism is available in MongoDB Enterprise 7.0+. - -AWS OIDC Support -~~~~~~~~~~~~~~~~ - -PyMongo supports automatic authentication when AWS OIDC credentials are -available, by installing pymongo with the -``aws`` extra:: - - $ python -m pip install 'pymongo[aws]' - -A sample URI would be: - - >>> from pymongo import MongoClient - >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws") - >>> client = MongoClient(uri) - -The driver will use the authentication token from the file given by the -``AWS_WEB_IDENTITY_TOKEN_FILE`` environment variable provided by AWS to -authenticate with the server. +The MONGODB-OIDC authentication mechanism is available in MongoDB Enterprise 7.0+ as a Public Preview. Callback-based OIDC Support ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pymongo/auth.py b/pymongo/auth.py index 8ba48e25fc..1afe91077e 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -152,6 +152,7 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): "*.mongodbgov.net", "localhost", "127.0.0.1", + "::1", ] allowed_hosts = properties.get("allowed_hosts", default_allowed) if not request_token_callback and provider_name != "aws": From f522d0dbf5a9702323e2e25df61318e18a40d8af Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 17 Apr 2023 13:23:32 -0500 Subject: [PATCH 105/133] refactor and adjust allowed_hosts handling --- pymongo/auth_oidc.py | 256 ++++++++++++++++---------------- test/auth_aws/test_auth_oidc.py | 12 +- 2 files changed, 135 insertions(+), 133 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 53eb138243..22db35cac5 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -13,10 +13,9 @@ # limitations under the License. """"MONGODB-OIDC Authentication helpers.""" -import fnmatch import os import threading -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional @@ -24,6 +23,7 @@ from bson.binary import Binary from bson.son import SON from pymongo.errors import OperationFailure, PyMongoError +from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE @dataclass @@ -36,36 +36,69 @@ class _OIDCProperties: """Mechanism properties for MONGODB-OIDC authentication.""" +TOKEN_BUFFER_MINUTES = 5 +CALLBACK_TIMEOUT_SECONDS = 5 * 60 +CACHE_TIMEOUT_MINUTES = 60 * 5 -@dataclass -class _OIDCCache: - token_result: Optional[Dict] - token_exp_utc: Optional[datetime] - cache_exp_utc: datetime - server_resp: Optional[Dict] - lock: threading.Lock - - -class _OIDCMechanism: - cache: Dict[str, _OIDCCache] = {} - token_buffer_minutes = 5 - callback_timeout_seconds = 5 * 60 - cache_timeout_minutes = 60 * 5 - - def get_cache_key(self, credentials, address): - # Handle authorization code credentials. - address = address - principal_name = credentials.username - properties = credentials.mechanism_properties - request_cb = properties.request_token_callback - refresh_cb = properties.refresh_token_callback - return f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" - def get_current_token(self, credentials, address, use_callbacks=True): - properties: _OIDCProperties = credentials.mechanism_properties - cache_key = self.get_cache_key(credentials, address) - cache_value = self.cache[cache_key] - principal_name = credentials.username +_CACHE: Dict[str, "_OIDCAuthenticator"] = {} + + +def _get_authenticator(credentials, address): + # Clear out old items in the cache. + now_utc = datetime.now(timezone.utc) + to_remove = [] + for key, value in _CACHE.items(): + if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc: + to_remove.append(key) + for key in to_remove: + del _CACHE[key] + + # Extract values. + principal_name = credentials.username + properties = credentials.mechanism_properties + request_cb = properties.request_token_callback + refresh_cb = properties.refresh_token_callback + + # Validate that the address is allowed. + if not properties.provider_name: + found = False + allowed_hosts = properties.allowed_hosts + for patt in allowed_hosts: + if patt == address[0]: + found = True + elif patt.startswith("*") and address[0].endswith(patt[1:]): + found = True + if not found: + raise PyMongoError( + f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" + ) + + # Get or create the cache item. + cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" + _CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties)) + + return _CACHE[cache_key] + + +def _get_cache_exp(): + return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) + + +@dataclass +class _OIDCAuthenticator: + username: str + properties: _OIDCProperties + idp_info: Optional[Dict] = field(default=None) + idp_resp: Optional[Dict] = field(default=None) + reauth_time: Optional[datetime] = field(default=None) + token_exp_utc: Optional[datetime] = field(default=None) + cache_exp_utc: datetime = field(default_factory=_get_cache_exp) + lock: threading.Lock = field(default_factory=threading.Lock) + + def get_current_token(self, use_callbacks=True): + properties = self.properties + principal_name = self.username request_cb = properties.request_token_callback refresh_cb = properties.refresh_token_callback @@ -74,44 +107,28 @@ def get_current_token(self, credentials, address, use_callbacks=True): refresh_cb = None current_valid_token = False - if cache_value.token_exp_utc is not None: + if self.token_exp_utc is not None: now_utc = datetime.now(timezone.utc) - exp_utc = cache_value.token_exp_utc - buffer_seconds = self.token_buffer_minutes * 60 + exp_utc = self.token_exp_utc + buffer_seconds = TOKEN_BUFFER_MINUTES * 60 if (exp_utc - now_utc).total_seconds() >= buffer_seconds: current_valid_token = True - timeout = self.callback_timeout_seconds + timeout = CALLBACK_TIMEOUT_SECONDS if not use_callbacks and not current_valid_token: return None - # Ensure that the desired address is allowed. - found = False - allowed_hosts = credentials.mechanism_properties.allowed_hosts - for patt in allowed_hosts: - if fnmatch.fnmatch(address[0], patt): - found = True - if not found: - raise PyMongoError( - f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" - ) - if not current_valid_token and request_cb is not None: client_info = dict(principal_name=principal_name, timeout_seconds=timeout) - with cache_value.lock: - if cache_value.token_result is None or refresh_cb is None: - cache_value.token_result = request_cb(client_info, cache_value.server_resp) - elif request_cb is not None: - cache_value.token_result = refresh_cb( - client_info, cache_value.server_resp, cache_value.token_result - ) - cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=self.cache_timeout_minutes - ) - cache_value.cache_exp_utc = cache_exp_utc - - token_result = cache_value.token_result + if self.idp_resp is None or refresh_cb is None: + self.idp_resp = request_cb(client_info, self.idp_info) + elif request_cb is not None: + self.idp_resp = refresh_cb(client_info, self.idp_info, self.idp_resp) + cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) + self.cache_exp_utc = cache_exp_utc + + token_result = self.idp_resp # Validate callback return value. if not isinstance(token_result, dict): @@ -126,33 +143,19 @@ def get_current_token(self, credentials, address, use_callbacks=True): raise ValueError(f'Unexpected field in callback result "{key}"') token = token_result["access_token"] + if "expires_in_seconds" in token_result: expires_in = int(token_result["expires_in_seconds"]) - buffer_seconds = self.token_buffer_minutes * 60 + buffer_seconds = TOKEN_BUFFER_MINUTES * 60 if expires_in >= buffer_seconds: now_utc = datetime.now(timezone.utc) exp_utc = now_utc + timedelta(seconds=expires_in) - cache_value.token_exp_utc = exp_utc + self.token_exp_utc = exp_utc return token - def invalidate_token(self, credentials, address): - cache_key = self.get_cache_key(credentials, address) - cache_value = self.cache.get(cache_key) - if cache_value: - cache_value.token_exp_utc = None - - def auth_start(self, credentials, address, use_callbacks=True): - properties: _OIDCProperties = credentials.mechanism_properties - - # Clear out old items in the cache. - now_utc = datetime.now(timezone.utc) - to_remove = [] - for key, value in self.cache.items(): - if value.cache_exp_utc < now_utc: - to_remove.append(key) - for key in to_remove: - del self.cache[key] + def auth_start_cmd(self, use_callbacks=True): + properties = self.properties # Handle aws provider credentials. if properties.provider_name == "aws": @@ -169,30 +172,17 @@ def auth_start(self, credentials, address, use_callbacks=True): ) return cmd - cache_key = self.get_cache_key(credentials, address) - cache_value = self.cache.get(cache_key) - principal_name = credentials.username + principal_name = self.username - if cache_value is not None: - cache_value.cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=self.cache_timeout_minutes + if self.idp_info is not None: + self.cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=CACHE_TIMEOUT_MINUTES ) - if cache_value is None: - lock = threading.Lock() - cache_exp_utc = datetime.now(timezone.utc) + timedelta( - minutes=self.cache_timeout_minutes - ) - cache_value = _OIDCCache( - lock=lock, - token_result=None, - server_resp=None, - token_exp_utc=None, - cache_exp_utc=cache_exp_utc, - ) - self.cache[cache_key] = cache_value + if self.idp_info is None: + self.cache_exp_utc = _get_cache_exp() - if cache_value.server_resp is None: + if self.idp_info is None: # Send the SASL start with the optional principal name. payload = dict() @@ -209,7 +199,7 @@ def auth_start(self, credentials, address, use_callbacks=True): ) return cmd - token = self.get_current_token(credentials, address, use_callbacks) + token = self.get_current_token(use_callbacks) if not token: return None bin_payload = Binary(bson.encode(dict(jwt=token))) @@ -221,32 +211,31 @@ def auth_start(self, credentials, address, use_callbacks=True): ] ) - def authenticate(self, credentials, sock_info, reauthenticate): - if reauthenticate: - self.invalidate_token(credentials, sock_info.address) - cache = self.cache + def clear(self): + self.idp_info = None + self.idp_resp = None + self.token_exp_utc = None + + def run_command(self, sock_info, cmd): + try: + return sock_info.command("$external", cmd) + except OperationFailure as exc: + self.clear() + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + if "jwt" in bson.decode(cmd["payload"]): + self.reauth_time = datetime.now(timezone.utc) + return self.authenticate(sock_info) + raise + + def authenticate(self, sock_info): ctx = sock_info.auth_ctx cmd = None - cache_key = self.get_cache_key(credentials, sock_info.address) - in_cache = cache_key in cache if ctx and ctx.speculate_succeeded(): resp = ctx.speculative_authenticate else: - cmd = self.auth_start(credentials, sock_info.address) - try: - resp = sock_info.command(credentials.source, cmd) - except Exception: - cache.pop(cache_key, None) - # Allow for one retry on reauthenticate when callbacks are in use - # and there was no cache. - if ( - reauthenticate - and not credentials.mechanism_properties.provider_name - and in_cache - ): - return _authenticate_oidc(credentials, sock_info, False) - raise + cmd = self.auth_start_cmd() + resp = self.run_command(sock_info, cmd) if resp["done"]: return @@ -261,10 +250,10 @@ def authenticate(self, credentials, sock_info, reauthenticate): server_resp[camel_to_snake(key)] = value if "issuer" in server_resp: - cache[cache_key].server_resp = server_resp + self.idp_info = server_resp conversation_id = resp["conversationId"] - token = self.get_current_token(credentials, sock_info.address) + token = self.get_current_token() bin_payload = Binary(bson.encode(dict(jwt=token))) cmd = SON( [ @@ -273,13 +262,11 @@ def authenticate(self, credentials, sock_info, reauthenticate): ("payload", bin_payload), ] ) - response = sock_info.command("$external", cmd) - if not response["done"]: - cache.pop(cache_key, None) + resp = self.run_command(sock_info, cmd) + if not resp["done"]: + self.clear() raise OperationFailure("SASL conversation failed to complete.") - - -_internal = _OIDCMechanism() + return resp class _OIDCContextMixin: @@ -287,7 +274,9 @@ class _OIDCContextMixin: address: Any def speculate_command(self): - cmd = _internal.auth_start(self.credentials, self.address, False) + authenticator = _get_authenticator(self.credentials, self.address) + with authenticator.lock: + cmd = authenticator.auth_start_cmd(False) if cmd is None: return cmd["db"] = self.credentials.source @@ -296,4 +285,15 @@ def speculate_command(self): def _authenticate_oidc(credentials, sock_info, reauthenticate): """Authenticate using MONGODB-OIDC.""" - return _internal.authenticate(credentials, sock_info, reauthenticate) + authenticator = _get_authenticator(credentials, sock_info.address) + # Prevent a race condition on reauthentication. Store the current time + # and compare to reauth time. + prev_time = authenticator.reauth_time + with authenticator.lock: + if reauthenticate: + new_time = authenticator.reauth_time + if prev_time and new_time and new_time <= prev_time: + authenticator.token_exp_utc = None + authenticator.reauth_time = datetime.now(timezone.utc) + + return authenticator.authenticate(sock_info) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 80a1c7e468..1b4a820b61 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -26,11 +26,9 @@ from bson import SON from pymongo import MongoClient -from pymongo.auth_oidc import _internal +from pymongo.auth_oidc import _CACHE as _oidc_cache from pymongo.errors import OperationFailure, PyMongoError -_oidc_cache = _internal.cache - class TestAuthOIDC(unittest.TestCase): uri: str @@ -389,7 +387,8 @@ def refresh_cb(client_info, server_info, creds): client.close() # Ensure that the cache has been cleared. - assert len(_oidc_cache) == 0 + authenticator = list(_oidc_cache.values())[0] + assert authenticator.idp_info is None def test_cache_is_not_used_in_aws_automatic_workflow(self): # Create a new client using the AWS device workflow. @@ -398,7 +397,10 @@ def test_cache_is_not_used_in_aws_automatic_workflow(self): client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() - assert len(_oidc_cache) == 0 + + # Ensure that the cache has been cleared. + authenticator = list(_oidc_cache.values())[0] + assert authenticator.idp_info is None def test_speculative_auth_success(self): # Clear the cache From b0c0919b52d49af45a58e63899f56efc41981362 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 17 Apr 2023 15:51:59 -0500 Subject: [PATCH 106/133] update callback parameters --- pymongo/auth_oidc.py | 5 ++--- pymongo/common.py | 4 ++-- test/auth_aws/test_auth_oidc.py | 40 ++++++++++++++++----------------- test/test_auth_spec.py | 4 ++-- 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 22db35cac5..6c7b1a3655 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -120,11 +120,10 @@ def get_current_token(self, use_callbacks=True): return None if not current_valid_token and request_cb is not None: - client_info = dict(principal_name=principal_name, timeout_seconds=timeout) if self.idp_resp is None or refresh_cb is None: - self.idp_resp = request_cb(client_info, self.idp_info) + self.idp_resp = request_cb(principal_name, self.idp_info, timeout) elif request_cb is not None: - self.idp_resp = refresh_cb(client_info, self.idp_info, self.idp_resp) + self.idp_resp = refresh_cb(principal_name, self.idp_info, self.idp_resp, timeout) cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) self.cache_exp_utc = cache_exp_utc diff --git a/pymongo/common.py b/pymongo/common.py index 5dd39aea5b..7c1d70d6c1 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -450,9 +450,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni elif inspect.isfunction(value): signature = inspect.signature(value) if key == "request_token_callback": - expected_params = 2 - elif key == "refresh_token_callback": expected_params = 3 + elif key == "refresh_token_callback": + expected_params = 4 else: raise ValueError(f"Unrecognized Auth mechanism function {key}") if len(signature.parameters) != expected_params: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 1b4a820b61..ac4a3d4097 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -50,17 +50,17 @@ def create_request_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) - def request_token(client_info, server_info): + def request_token(principal_name, server_info, timeout_seconds): # Validate the principal. - if client_info["principal_name"] is not None: - self.assertIsInstance(client_info["principal_name"], str) + if principal_name is not None: + self.assertIsInstance(principal_name, str) # Validate the info. self.assertIn("issuer", server_info) self.assertIn("client_id", server_info) # Validate the timeout. - self.assertEqual(client_info["timeout_seconds"], 60 * 5) + self.assertEqual(timeout_seconds, 60 * 5) with open(token_file) as fid: token = fid.read() resp = dict(access_token=token) @@ -75,13 +75,13 @@ def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) - def refresh_token(client_info, server_info, creds): + def refresh_token(principal_name, server_info, creds, timeout_seconds): with open(token_file) as fid: token = fid.read() # Validate the principal. - if client_info["principal_name"] is not None: - self.assertIsInstance(client_info["principal_name"], str) + if principal_name is not None: + self.assertIsInstance(principal_name, str) # Validate the info. self.assertIn("issuer", server_info) @@ -91,7 +91,7 @@ def refresh_token(client_info, server_info, creds): self.assertIn("access_token", creds) # Validate the timeout. - self.assertEqual(client_info["timeout_seconds"], 60 * 5) + self.assertEqual(timeout_seconds, 60 * 5) resp = dict(access_token=token) if expires_in_seconds is not None: @@ -208,7 +208,7 @@ def test_valid_callbacks(self): client.close() def test_request_callback_returns_null(self): - def request_token_null(client_info, server_info): + def request_token_null(a, b, c): return None props: Dict = dict(request_token_callback=request_token_null) @@ -220,7 +220,7 @@ def request_token_null(client_info, server_info): def test_refresh_callback_returns_null(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_token_null(client_info, server_info, creds): + def refresh_token_null(a, b, c, d): return None props: Dict = dict( @@ -236,7 +236,7 @@ def refresh_token_null(client_info, server_info, creds): client.close() def test_request_callback_invalid_result(self): - def request_token_invalid(client_info, server_info): + def request_token_invalid(a, b, c): return dict() props: Dict = dict(request_token_callback=request_token_invalid) @@ -245,8 +245,8 @@ def request_token_invalid(client_info, server_info): client.test.test.find_one() client.close() - def request_cb_extra_value(client_info, server_info): - result = self.create_request_cb()(client_info, server_info) + def request_cb_extra_value(principal_name, server_info, timeout_seconds): + result = self.create_request_cb()(principal_name, server_info, timeout_seconds) result["foo"] = "bar" return result @@ -259,7 +259,7 @@ def request_cb_extra_value(client_info, server_info): def test_refresh_callback_missing_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_cb_no_token(client_info, server_info, cred): + def refresh_cb_no_token(a, b, c, d): return dict() props: Dict = dict( @@ -277,8 +277,8 @@ def refresh_cb_no_token(client_info, server_info, cred): def test_refresh_callback_extra_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_cb_extra_value(client_info, server_info, cred): - result = self.create_refresh_cb()(client_info, server_info, cred) + def refresh_cb_extra_value(principal_name, server_info, cred, timeout_seconds): + result = self.create_refresh_cb()(principal_name, server_info, cred, timeout_seconds) result["foo"] = "bar" return result @@ -353,8 +353,8 @@ def test_cache_key_includes_callback(self): client.close() # Create a new client with a different request callback. - def request_token_2(client_info, server_info): - return request_cb(client_info, server_info) + def request_token_2(a, b, c): + return request_cb(a, b, c) props["request_token_callback"] = request_token_2 client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -368,7 +368,7 @@ def test_cache_clears_on_error(self): request_cb = self.create_request_cb() # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. - def refresh_cb(client_info, server_info, creds): + def refresh_cb(a, b, c, d): return dict(access_token="bad") # Add a token to the cache that will expire soon. @@ -407,7 +407,7 @@ def test_speculative_auth_success(self): _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user1") - def request_token(client_info, server_info): + def request_token(a, b, c): with open(token_file) as fid: token = fid.read() return dict(access_token=token, expires_in_seconds=1000) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 3548e85d20..0a0adc4490 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -46,10 +46,10 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["request_token_callback"] = lambda x, y: 1 + props["request_token_callback"] = lambda x, y, z: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["refresh_token_callback"] = lambda a, b, c: 1 # type:ignore[misc] + props["refresh_token_callback"] = lambda a, b, c, d: 1 # type:ignore[misc] del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From 5828220b1709e8fea91472af02bc8fb21033b3d9 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 17 Apr 2023 17:48:56 -0500 Subject: [PATCH 107/133] clear all info on reauth if no refresh callback --- pymongo/auth_oidc.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 6c7b1a3655..365b7132ab 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -226,6 +226,14 @@ def run_command(self, sock_info, cmd): return self.authenticate(sock_info) raise + def handle_reauth(self, prev_time): + new_time = self.reauth_time + if prev_time and new_time and new_time <= prev_time: + self.token_exp_utc = None + if not self.properties.refresh_token_callback: + self.clear() + self.reauth_time = datetime.now(timezone.utc) + def authenticate(self, sock_info): ctx = sock_info.auth_ctx cmd = None @@ -290,9 +298,5 @@ def _authenticate_oidc(credentials, sock_info, reauthenticate): prev_time = authenticator.reauth_time with authenticator.lock: if reauthenticate: - new_time = authenticator.reauth_time - if prev_time and new_time and new_time <= prev_time: - authenticator.token_exp_utc = None - authenticator.reauth_time = datetime.now(timezone.utc) - + authenticator.handle_reauth(prev_time) return authenticator.authenticate(sock_info) From a78ad30a917a3557b08bf89346867ad6a7a35f78 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 17 Apr 2023 18:00:55 -0500 Subject: [PATCH 108/133] add reauth guard --- pymongo/auth_oidc.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 365b7132ab..531661218f 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -92,6 +92,7 @@ class _OIDCAuthenticator: idp_info: Optional[Dict] = field(default=None) idp_resp: Optional[Dict] = field(default=None) reauth_time: Optional[datetime] = field(default=None) + idp_info_time: Optional[datetime] = field(default=None) token_exp_utc: Optional[datetime] = field(default=None) cache_exp_utc: datetime = field(default_factory=_get_cache_exp) lock: threading.Lock = field(default_factory=threading.Lock) @@ -222,7 +223,13 @@ def run_command(self, sock_info, cmd): self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if "jwt" in bson.decode(cmd["payload"]): - self.reauth_time = datetime.now(timezone.utc) + if ( + self.idp_info_time is not None + and self.reauth_time is not None + and self.idp_info_time > self.reauth_time + ): + raise + self.handle_reauth(self.reauth_time) return self.authenticate(sock_info) raise @@ -258,6 +265,7 @@ def authenticate(self, sock_info): if "issuer" in server_resp: self.idp_info = server_resp + self.idp_info_time = datetime.now(timezone.utc) conversation_id = resp["conversationId"] token = self.get_current_token() From ef991f93f764b5fd84153560fdd5a0b895e55f59 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 18 Apr 2023 18:13:28 -0500 Subject: [PATCH 109/133] wip better handling of reauth and locks --- pymongo/auth_oidc.py | 52 ++++++++++++++++++++------------- test/auth_aws/test_auth_oidc.py | 6 ++-- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 531661218f..2d8a609993 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -121,12 +121,24 @@ def get_current_token(self, use_callbacks=True): return None if not current_valid_token and request_cb is not None: - if self.idp_resp is None or refresh_cb is None: - self.idp_resp = request_cb(principal_name, self.idp_info, timeout) - elif request_cb is not None: - self.idp_resp = refresh_cb(principal_name, self.idp_info, self.idp_resp, timeout) - cache_exp_utc = datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) - self.cache_exp_utc = cache_exp_utc + prev_token = self.idp_resp and self.idp_resp["access_token"] + with self.lock: + # See if the token was changed while we were waiting for the + # lock. + new_token = self.idp_resp and self.idp_resp["access_token"] + if new_token != prev_token: + return new_token + + if self.idp_resp is None or refresh_cb is None: + self.idp_resp = request_cb(principal_name, self.idp_info, timeout) + elif request_cb is not None: + self.idp_resp = refresh_cb( + principal_name, self.idp_info, self.idp_resp, timeout + ) + cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=CACHE_TIMEOUT_MINUTES + ) + self.cache_exp_utc = cache_exp_utc token_result = self.idp_resp @@ -229,17 +241,20 @@ def run_command(self, sock_info, cmd): and self.idp_info_time > self.reauth_time ): raise - self.handle_reauth(self.reauth_time) + self.handle_reauth(sock_info) return self.authenticate(sock_info) raise - def handle_reauth(self, prev_time): - new_time = self.reauth_time - if prev_time and new_time and new_time <= prev_time: - self.token_exp_utc = None - if not self.properties.refresh_token_callback: - self.clear() + def handle_reauth(self, sock_info): + prev_token = getattr(sock_info, "oidc_access_token", None) + if prev_token and self.idp_resp and prev_token != self.idp_resp["access_token"]: + # No need to preemptively clear, we've already changed tokens. + return + self.reauth_time = datetime.now(timezone.utc) + self.token_exp_utc = None + if not self.properties.refresh_token_callback: + self.clear() def authenticate(self, sock_info): ctx = sock_info.auth_ctx @@ -269,6 +284,7 @@ def authenticate(self, sock_info): conversation_id = resp["conversationId"] token = self.get_current_token() + sock_info.oidc_access_token = token bin_payload = Binary(bson.encode(dict(jwt=token))) cmd = SON( [ @@ -301,10 +317,6 @@ def speculate_command(self): def _authenticate_oidc(credentials, sock_info, reauthenticate): """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, sock_info.address) - # Prevent a race condition on reauthentication. Store the current time - # and compare to reauth time. - prev_time = authenticator.reauth_time - with authenticator.lock: - if reauthenticate: - authenticator.handle_reauth(prev_time) - return authenticator.authenticate(sock_info) + if reauthenticate: + authenticator.handle_reauth(sock_info) + return authenticator.authenticate(sock_info) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index ac4a3d4097..2ba014eda0 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -160,8 +160,10 @@ def test_allowed_hosts_blocked(self): client.test.test.find_one() client.close() - props: Dict = dict(request_token_callback=request_token, allowed_hosts=["localhost1"]) - client = MongoClient(self.uri_single, authmechanismproperties=props) + props: Dict = dict(request_token_callback=request_token, allowed_hosts=["example.com"]) + client = MongoClient( + self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False + ) with self.assertRaises(PyMongoError): client.test.test.find_one() client.close() From 4cf51104bfb6730651fba130206c41f62fbd833d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 19 Apr 2023 11:12:28 -0500 Subject: [PATCH 110/133] add prose tests for reauth and lock guards --- pymongo/auth_oidc.py | 13 +++--- test/auth_aws/test_auth_oidc.py | 76 ++++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 2d8a609993..8fcf014ec4 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -93,6 +93,7 @@ class _OIDCAuthenticator: idp_resp: Optional[Dict] = field(default=None) reauth_time: Optional[datetime] = field(default=None) idp_info_time: Optional[datetime] = field(default=None) + token_acq_time: Optional[datetime] = field(default=None) token_exp_utc: Optional[datetime] = field(default=None) cache_exp_utc: datetime = field(default_factory=_get_cache_exp) lock: threading.Lock = field(default_factory=threading.Lock) @@ -139,6 +140,7 @@ def get_current_token(self, use_callbacks=True): minutes=CACHE_TIMEOUT_MINUTES ) self.cache_exp_utc = cache_exp_utc + self.token_acq_time = datetime.now(timezone.utc) token_result = self.idp_resp @@ -246,8 +248,8 @@ def run_command(self, sock_info, cmd): raise def handle_reauth(self, sock_info): - prev_token = getattr(sock_info, "oidc_access_token", None) - if prev_token and self.idp_resp and prev_token != self.idp_resp["access_token"]: + prev_time = getattr(sock_info, "oidc_token_time", None) + if prev_time and self.token_acq_time and prev_time != self.token_acq_time: # No need to preemptively clear, we've already changed tokens. return @@ -267,6 +269,8 @@ def authenticate(self, sock_info): resp = self.run_command(sock_info, cmd) if resp["done"]: + if self.token_acq_time is not None: + sock_info.oidc_token_time = self.token_acq_time return # Convert the server response to be more pythonic. @@ -284,7 +288,7 @@ def authenticate(self, sock_info): conversation_id = resp["conversationId"] token = self.get_current_token() - sock_info.oidc_access_token = token + sock_info.oidc_token_time = self.token_acq_time bin_payload = Binary(bson.encode(dict(jwt=token))) cmd = SON( [ @@ -306,8 +310,7 @@ class _OIDCContextMixin: def speculate_command(self): authenticator = _get_authenticator(self.credentials, self.address) - with authenticator.lock: - cmd = authenticator.auth_start_cmd(False) + cmd = authenticator.auth_start_cmd(False) if cmd is None: return cmd["db"] = self.credentials.source diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 2ba014eda0..d4c2ecefd0 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -16,6 +16,8 @@ import os import sys +import threading +import time import unittest from contextlib import contextmanager from typing import Dict @@ -26,7 +28,9 @@ from bson import SON from pymongo import MongoClient +from pymongo.auth import MongoCredential from pymongo.auth_oidc import _CACHE as _oidc_cache +from pymongo.auth_oidc import _get_authenticator, _OIDCProperties from pymongo.errors import OperationFailure, PyMongoError @@ -46,7 +50,7 @@ def setUp(self): _oidc_cache.clear() os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") - def create_request_cb(self, username="test_user1", expires_in_seconds=None): + def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0): token_file = os.path.join(self.token_dir, username) @@ -64,6 +68,9 @@ def request_token(principal_name, server_info, timeout_seconds): with open(token_file) as fid: token = fid.read() resp = dict(access_token=token) + + time.sleep(sleep) + if expires_in_seconds is not None: resp["expires_in_seconds"] = expires_in_seconds self.request_called += 1 @@ -209,6 +216,31 @@ def test_valid_callbacks(self): client.test.test.find_one() client.close() + def test_lock_avoids_extra_callbacks(self): + request_cb = self.create_request_cb(sleep=0.5) + refresh_cb = self.create_refresh_cb() + + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + + def run_test(): + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + t1 = threading.Thread(target=run_test) + t2 = threading.Thread(target=run_test) + t1.start() + t2.start() + t1.join() + t2.join() + + assert self.request_called == 1 + assert self.refresh_called == 2 + def test_request_callback_returns_null(self): def request_token_null(a, b, c): return None @@ -330,7 +362,7 @@ def test_cache_with_no_refresh(self): client = MongoClient(self.uri_single, authMechanismProperties=props) # Ensure that a ``find`` operation adds credentials to the cache. - request_called = 0 + self.request_called = 0 client.test.test.find_one() client.close() assert self.request_called == 1 @@ -564,6 +596,46 @@ def test_reauthenticate_fails_with_no_cache(self): client.close() + def test_late_reauth_avoids_callback(self): + # Step 1: connect with both clients + request_cb = self.create_request_cb(expires_in_seconds=1e6) + refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6) + + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client1 = MongoClient(self.uri_single, authMechanismProperties=props) + client1.test.test.find_one() + client2 = MongoClient(self.uri_single, authMechanismProperties=props) + client2.test.test.find_one() + + assert self.refresh_called == 0 + assert self.request_called == 1 + + # Step 2: cause a find 391 on the first client + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation that succeeds. + client1.test.test.find_one() + + assert self.refresh_called == 1 + assert self.request_called == 1 + + # Step 3: cause a find 391 on the second client + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation that succeeds. + client2.test.test.find_one() + + assert self.refresh_called == 1 + assert self.request_called == 1 + if __name__ == "__main__": unittest.main() From a354ce696243e8ef07524561ff0cbbc71a3d2cd4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 20 Apr 2023 06:35:22 -0500 Subject: [PATCH 111/133] use gen_id instead of datetimes --- pymongo/auth_oidc.py | 27 +++++++++++---------------- test/auth_aws/test_auth_oidc.py | 3 +++ 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 8fcf014ec4..654502bafb 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -91,9 +91,9 @@ class _OIDCAuthenticator: properties: _OIDCProperties idp_info: Optional[Dict] = field(default=None) idp_resp: Optional[Dict] = field(default=None) - reauth_time: Optional[datetime] = field(default=None) - idp_info_time: Optional[datetime] = field(default=None) - token_acq_time: Optional[datetime] = field(default=None) + reauth_gen_id: int = field(default=0) + idp_info_gen_id: int = field(default=0) + token_gen_id: int = field(default=0) token_exp_utc: Optional[datetime] = field(default=None) cache_exp_utc: datetime = field(default_factory=_get_cache_exp) lock: threading.Lock = field(default_factory=threading.Lock) @@ -140,7 +140,7 @@ def get_current_token(self, use_callbacks=True): minutes=CACHE_TIMEOUT_MINUTES ) self.cache_exp_utc = cache_exp_utc - self.token_acq_time = datetime.now(timezone.utc) + self.token_gen_id += 1 token_result = self.idp_resp @@ -237,23 +237,19 @@ def run_command(self, sock_info, cmd): self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if "jwt" in bson.decode(cmd["payload"]): - if ( - self.idp_info_time is not None - and self.reauth_time is not None - and self.idp_info_time > self.reauth_time - ): + if self.idp_info_gen_id > self.reauth_gen_id: raise self.handle_reauth(sock_info) return self.authenticate(sock_info) raise def handle_reauth(self, sock_info): - prev_time = getattr(sock_info, "oidc_token_time", None) - if prev_time and self.token_acq_time and prev_time != self.token_acq_time: + prev_id = getattr(sock_info, "oidc_token_gen_id", None) + if prev_id != self.token_gen_id: # No need to preemptively clear, we've already changed tokens. return - self.reauth_time = datetime.now(timezone.utc) + self.reauth_gen_id = self.idp_info_gen_id self.token_exp_utc = None if not self.properties.refresh_token_callback: self.clear() @@ -269,8 +265,7 @@ def authenticate(self, sock_info): resp = self.run_command(sock_info, cmd) if resp["done"]: - if self.token_acq_time is not None: - sock_info.oidc_token_time = self.token_acq_time + sock_info.oidc_token_gen_id = self.token_gen_id return # Convert the server response to be more pythonic. @@ -284,11 +279,11 @@ def authenticate(self, sock_info): if "issuer" in server_resp: self.idp_info = server_resp - self.idp_info_time = datetime.now(timezone.utc) + self.idp_info_gen_id += 1 conversation_id = resp["conversationId"] token = self.get_current_token() - sock_info.oidc_token_time = self.token_acq_time + sock_info.oidc_token_gen_id = self.token_gen_id bin_payload = Binary(bson.encode(dict(jwt=token))) cmd = SON( [ diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index d4c2ecefd0..0b2da0caed 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -636,6 +636,9 @@ def test_late_reauth_avoids_callback(self): assert self.refresh_called == 1 assert self.request_called == 1 + client1.close() + client2.close() + if __name__ == "__main__": unittest.main() From 66f2276bf034806e925e7318292808ce33966741 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 25 Apr 2023 16:14:31 -0500 Subject: [PATCH 112/133] update callback parameters --- pymongo/auth_oidc.py | 17 +++++++++++----- pymongo/common.py | 4 ++-- test/auth_aws/test_auth_oidc.py | 35 ++++++++++++++++++--------------- test/test_auth_spec.py | 4 ++-- 4 files changed, 35 insertions(+), 25 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 654502bafb..6e72a63da0 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -39,7 +39,7 @@ class _OIDCProperties: TOKEN_BUFFER_MINUTES = 5 CALLBACK_TIMEOUT_SECONDS = 5 * 60 CACHE_TIMEOUT_MINUTES = 60 * 5 - +CALLBACK_VERSION = 0 _CACHE: Dict[str, "_OIDCAuthenticator"] = {} @@ -130,12 +130,19 @@ def get_current_token(self, use_callbacks=True): if new_token != prev_token: return new_token + refresh_token = self.idp_resp and self.idp_resp.get("refresh_token") + refresh_token = refresh_token or "" + context = dict( + timeout_seconds=timeout, + principal_name=principal_name, + version=CALLBACK_VERSION, + refresh_token=refresh_token, + ) + if self.idp_resp is None or refresh_cb is None: - self.idp_resp = request_cb(principal_name, self.idp_info, timeout) + self.idp_resp = request_cb(self.idp_info, context) elif request_cb is not None: - self.idp_resp = refresh_cb( - principal_name, self.idp_info, self.idp_resp, timeout - ) + self.idp_resp = refresh_cb(self.idp_info, context) cache_exp_utc = datetime.now(timezone.utc) + timedelta( minutes=CACHE_TIMEOUT_MINUTES ) diff --git a/pymongo/common.py b/pymongo/common.py index 7c1d70d6c1..c665d5cdc1 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -450,9 +450,9 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni elif inspect.isfunction(value): signature = inspect.signature(value) if key == "request_token_callback": - expected_params = 3 + expected_params = 2 elif key == "refresh_token_callback": - expected_params = 4 + expected_params = 2 else: raise ValueError(f"Unrecognized Auth mechanism function {key}") if len(signature.parameters) != expected_params: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 0b2da0caed..3255615e2b 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -54,8 +54,9 @@ def create_request_cb(self, username="test_user1", expires_in_seconds=None, slee token_file = os.path.join(self.token_dir, username) - def request_token(principal_name, server_info, timeout_seconds): + def request_token(server_info, context): # Validate the principal. + principal_name = context.get("principal_name") if principal_name is not None: self.assertIsInstance(principal_name, str) @@ -64,6 +65,7 @@ def request_token(principal_name, server_info, timeout_seconds): self.assertIn("client_id", server_info) # Validate the timeout. + timeout_seconds = context["timeout_seconds"] self.assertEqual(timeout_seconds, 60 * 5) with open(token_file) as fid: token = fid.read() @@ -82,11 +84,12 @@ def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): token_file = os.path.join(self.token_dir, username) - def refresh_token(principal_name, server_info, creds, timeout_seconds): + def refresh_token(server_info, context): with open(token_file) as fid: token = fid.read() # Validate the principal. + principal_name = context.get("principal_name") if principal_name is not None: self.assertIsInstance(principal_name, str) @@ -95,10 +98,10 @@ def refresh_token(principal_name, server_info, creds, timeout_seconds): self.assertIn("client_id", server_info) # Validate the creds - self.assertIn("access_token", creds) + self.assertIsNotNone(context["refresh_token"]) # Validate the timeout. - self.assertEqual(timeout_seconds, 60 * 5) + self.assertEqual(context["timeout_seconds"], 60 * 5) resp = dict(access_token=token) if expires_in_seconds is not None: @@ -242,7 +245,7 @@ def run_test(): assert self.refresh_called == 2 def test_request_callback_returns_null(self): - def request_token_null(a, b, c): + def request_token_null(a, b): return None props: Dict = dict(request_token_callback=request_token_null) @@ -254,7 +257,7 @@ def request_token_null(a, b, c): def test_refresh_callback_returns_null(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_token_null(a, b, c, d): + def refresh_token_null(a, b): return None props: Dict = dict( @@ -270,7 +273,7 @@ def refresh_token_null(a, b, c, d): client.close() def test_request_callback_invalid_result(self): - def request_token_invalid(a, b, c): + def request_token_invalid(a, b): return dict() props: Dict = dict(request_token_callback=request_token_invalid) @@ -279,8 +282,8 @@ def request_token_invalid(a, b, c): client.test.test.find_one() client.close() - def request_cb_extra_value(principal_name, server_info, timeout_seconds): - result = self.create_request_cb()(principal_name, server_info, timeout_seconds) + def request_cb_extra_value(server_info, context): + result = self.create_request_cb()(server_info, context) result["foo"] = "bar" return result @@ -293,7 +296,7 @@ def request_cb_extra_value(principal_name, server_info, timeout_seconds): def test_refresh_callback_missing_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_cb_no_token(a, b, c, d): + def refresh_cb_no_token(a, b): return dict() props: Dict = dict( @@ -311,8 +314,8 @@ def refresh_cb_no_token(a, b, c, d): def test_refresh_callback_extra_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) - def refresh_cb_extra_value(principal_name, server_info, cred, timeout_seconds): - result = self.create_refresh_cb()(principal_name, server_info, cred, timeout_seconds) + def refresh_cb_extra_value(server_info, context): + result = self.create_refresh_cb()(server_info, context) result["foo"] = "bar" return result @@ -387,8 +390,8 @@ def test_cache_key_includes_callback(self): client.close() # Create a new client with a different request callback. - def request_token_2(a, b, c): - return request_cb(a, b, c) + def request_token_2(a, b): + return request_cb(a, b) props["request_token_callback"] = request_token_2 client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -402,7 +405,7 @@ def test_cache_clears_on_error(self): request_cb = self.create_request_cb() # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. - def refresh_cb(a, b, c, d): + def refresh_cb(a, b): return dict(access_token="bad") # Add a token to the cache that will expire soon. @@ -441,7 +444,7 @@ def test_speculative_auth_success(self): _oidc_cache.clear() token_file = os.path.join(self.token_dir, "test_user1") - def request_token(a, b, c): + def request_token(a, b): with open(token_file) as fid: token = fid.read() return dict(access_token=token, expires_in_seconds=1000) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 0a0adc4490..ffed059e67 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -46,10 +46,10 @@ def run_test(self): if credential: props = credential["mechanism_properties"] or {} if props.get("REQUEST_TOKEN_CALLBACK"): - props["request_token_callback"] = lambda x, y, z: 1 + props["request_token_callback"] = lambda x, y: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["refresh_token_callback"] = lambda a, b, c, d: 1 # type:ignore[misc] + props["refresh_token_callback"] = lambda a, b: 1 # type:ignore[misc] del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From 710add4859cd3fbf1403ffbed27f88e955085121 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 26 Apr 2023 10:15:42 -0500 Subject: [PATCH 113/133] remove principal name from context --- pymongo/auth_oidc.py | 1 - test/auth_aws/test_auth_oidc.py | 10 ---------- 2 files changed, 11 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 6e72a63da0..643a935452 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -134,7 +134,6 @@ def get_current_token(self, use_callbacks=True): refresh_token = refresh_token or "" context = dict( timeout_seconds=timeout, - principal_name=principal_name, version=CALLBACK_VERSION, refresh_token=refresh_token, ) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 3255615e2b..c4816a4154 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -55,11 +55,6 @@ def create_request_cb(self, username="test_user1", expires_in_seconds=None, slee token_file = os.path.join(self.token_dir, username) def request_token(server_info, context): - # Validate the principal. - principal_name = context.get("principal_name") - if principal_name is not None: - self.assertIsInstance(principal_name, str) - # Validate the info. self.assertIn("issuer", server_info) self.assertIn("client_id", server_info) @@ -88,11 +83,6 @@ def refresh_token(server_info, context): with open(token_file) as fid: token = fid.read() - # Validate the principal. - principal_name = context.get("principal_name") - if principal_name is not None: - self.assertIsInstance(principal_name, str) - # Validate the info. self.assertIn("issuer", server_info) self.assertIn("client_id", server_info) From b98aa5ce3008ce2353455ac5d041f626d13aed4d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 28 Apr 2023 09:40:49 -0500 Subject: [PATCH 114/133] update for spec changes --- pymongo/auth_oidc.py | 6 +++--- test/auth_aws/test_auth_oidc.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 643a935452..eb10b37c29 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -22,7 +22,7 @@ import bson from bson.binary import Binary from bson.son import SON -from pymongo.errors import OperationFailure, PyMongoError +from pymongo.errors import ConfigurationError, OperationFailure from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE @@ -67,10 +67,10 @@ def _get_authenticator(credentials, address): for patt in allowed_hosts: if patt == address[0]: found = True - elif patt.startswith("*") and address[0].endswith(patt[1:]): + elif patt.startswith("*.") and address[0].endswith(patt[1:]): found = True if not found: - raise PyMongoError( + raise ConfigurationError( f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" ) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index c4816a4154..503691ee37 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -30,8 +30,7 @@ from pymongo import MongoClient from pymongo.auth import MongoCredential from pymongo.auth_oidc import _CACHE as _oidc_cache -from pymongo.auth_oidc import _get_authenticator, _OIDCProperties -from pymongo.errors import OperationFailure, PyMongoError +from pymongo.errors import ConfigurationError, OperationFailure class TestAuthOIDC(unittest.TestCase): @@ -156,7 +155,7 @@ def test_allowed_hosts_blocked(self): request_token = self.create_request_cb() props: Dict = dict(request_token_callback=request_token, allowed_hosts=[]) client = MongoClient(self.uri_single, authmechanismproperties=props) - with self.assertRaises(PyMongoError): + with self.assertRaises(ConfigurationError): client.test.test.find_one() client.close() @@ -164,7 +163,7 @@ def test_allowed_hosts_blocked(self): client = MongoClient( self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False ) - with self.assertRaises(PyMongoError): + with self.assertRaises(ConfigurationError): client.test.test.find_one() client.close() From eb3ced00407592c71c00350ed5ec817af2b5d719 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 28 Apr 2023 16:09:25 -0500 Subject: [PATCH 115/133] use unittest asserts --- test/auth_aws/test_auth_oidc.py | 51 +++++++++++++++++---------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 503691ee37..ae9578bdc5 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -230,8 +230,8 @@ def run_test(): t1.join() t2.join() - assert self.request_called == 1 - assert self.refresh_called == 2 + self.assertEqual(self.request_called, 1) + self.assertEqual(self.refresh_called, 2) def test_request_callback_returns_null(self): def request_token_null(a, b): @@ -334,7 +334,7 @@ def test_cache_with_refresh(self): client.test.test.find_one() client.close() - assert len(_oidc_cache) == 1 + self.assertEqual(len(_oidc_cache), 1) # Create a new client with the same request callback and a refresh callback. # Ensure that a ``find`` operation results in a call to the refresh callback. @@ -342,8 +342,8 @@ def test_cache_with_refresh(self): client.test.test.find_one() client.close() - assert self.refresh_called == 1 - assert len(_oidc_cache) == 1 + self.assertEqual(self.refresh_called, 1) + self.assertEqual(len(_oidc_cache), 1) def test_cache_with_no_refresh(self): # Create a new client with a request callback callback. @@ -357,16 +357,16 @@ def test_cache_with_no_refresh(self): self.request_called = 0 client.test.test.find_one() client.close() - assert self.request_called == 1 - assert len(_oidc_cache) == 1 + self.assertEqual(self.request_called, 1) + self.assertEqual(len(_oidc_cache), 1) # Create a new client with the same request callback. # Ensure that a ``find`` operation results in a call to the request callback. client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() - assert self.request_called == 2 - assert len(_oidc_cache) == 1 + self.assertEqual(self.request_called, 2) + self.assertEqual(len(_oidc_cache), 1) def test_cache_key_includes_callback(self): request_cb = self.create_request_cb() @@ -388,7 +388,7 @@ def request_token_2(a, b): # Ensure that a ``find`` operation adds a new entry to the cache. client.test.test.find_one() client.close() - assert len(_oidc_cache) == 2 + self.assertEqual(len(_oidc_cache), 2) def test_cache_clears_on_error(self): request_cb = self.create_request_cb() @@ -414,7 +414,7 @@ def refresh_cb(a, b): # Ensure that the cache has been cleared. authenticator = list(_oidc_cache.values())[0] - assert authenticator.idp_info is None + self.assertIsNone(authenticator.idp_info) def test_cache_is_not_used_in_aws_automatic_workflow(self): # Create a new client using the AWS device workflow. @@ -426,7 +426,7 @@ def test_cache_is_not_used_in_aws_automatic_workflow(self): # Ensure that the cache has been cleared. authenticator = list(_oidc_cache.values())[0] - assert authenticator.idp_info is None + self.assertIsNone(authenticator.idp_info) def test_speculative_auth_success(self): # Clear the cache @@ -515,12 +515,15 @@ def test_reauthenticate_succeeds(self): i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl") ] - assert started_events == [ - "find", - "find", - ], started_events - assert succeeded_events == ["find"], succeeded_events - assert failed_events == ["find"], failed_events + self.assertEqual( + started_events, + [ + "find", + "find", + ], + ) + self.assertEqual(succeeded_events, ["find"]) + self.assertEqual(failed_events, ["find"]) # Assert that the refresh callback has been called. self.assertEqual(self.refresh_called, 1) @@ -599,8 +602,8 @@ def test_late_reauth_avoids_callback(self): client2 = MongoClient(self.uri_single, authMechanismProperties=props) client2.test.test.find_one() - assert self.refresh_called == 0 - assert self.request_called == 1 + self.assertEqual(self.refresh_called, 0) + self.assertEqual(self.request_called, 1) # Step 2: cause a find 391 on the first client with self.fail_point( @@ -612,8 +615,8 @@ def test_late_reauth_avoids_callback(self): # Perform a find operation that succeeds. client1.test.test.find_one() - assert self.refresh_called == 1 - assert self.request_called == 1 + self.assertEqual(self.refresh_called, 1) + self.assertEqual(self.request_called, 1) # Step 3: cause a find 391 on the second client with self.fail_point( @@ -625,8 +628,8 @@ def test_late_reauth_avoids_callback(self): # Perform a find operation that succeeds. client2.test.test.find_one() - assert self.refresh_called == 1 - assert self.request_called == 1 + self.assertEqual(self.refresh_called, 1) + self.assertEqual(self.request_called, 1) client1.close() client2.close() From c7d33b2057eb4b92e974a0ee35b6aee139bb4eea Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 2 May 2023 18:58:53 -0500 Subject: [PATCH 116/133] address review --- .evergreen/run-mongodb-oidc-test.sh | 3 +- doc/changelog.rst | 1 - doc/examples/authentication.rst | 108 ---------------------------- 3 files changed, 1 insertion(+), 111 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 97262ac76e..46bb779578 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -50,8 +50,7 @@ set -x # Workaround macOS python 3.9 incompatibility with system virtualenv. if [ "$(uname -s)" = "Darwin" ]; then - # TODO: change back to 3.9 before merging. - VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.10/bin/python3 -m virtualenv" + VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 -m virtualenv" else VIRTUALENV=$(command -v virtualenv) fi diff --git a/doc/changelog.rst b/doc/changelog.rst index 80f2d38a22..3c0419f401 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,7 +4,6 @@ Changelog Changes in Version 4.4 ----------------------- -- Added support for :ref:`OIDC ` authentication with MongoDB Enterprise 7.0+ in Public Preview. - Added support for passing a list containing (key, direction) pairs or keys to :meth:`~pymongo.collection.Collection.create_index`. - pymongocrypt 1.5.0 or later is now required for client side field level diff --git a/doc/examples/authentication.rst b/doc/examples/authentication.rst index 742ce4f6a7..5bd3282146 100644 --- a/doc/examples/authentication.rst +++ b/doc/examples/authentication.rst @@ -384,111 +384,3 @@ would be:: .. _Assume Role: https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html .. _EC2 instance: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2.html .. _environment variables: https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html#configuration-envvars-runtime - - -.. _oidc_sasl: - -MONGODB-OIDC ------------- -.. versionadded:: 4.4 - -The MONGODB-OIDC authentication mechanism is available in MongoDB Enterprise 7.0+ as a Public Preview. - -Callback-based OIDC Support -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -PyMongo supports user-provided callbacks for OIDC, which are are given to the -``MongoClient``. The ``request_token_callback`` is intended to accept -information about the Identity Provider, and return credentials that are used -to authenticate with the server, usually through a browser interaction with -the user. The callback must be of the form:: - - def request_callback(ProviderInfo, timeout_seconds) -> TokenResult: - ... - return dict(access_token=...) - -Where ``ProviderInfo`` is a dictionary of the following form:: - - issuer: - description: >- - URL which describes the Authorization Server. This identifier should be - the iss of provided access tokens, and be viable for RFC8414 - metadata discovery and RFC9207 identification. - type: string - client_id: - description: "Unique client ID for this OIDC client" - type: string - client_ecret: - description: "Secret used when communicating with IDP" - type: string - optional: true - request_scopes: - description: "Additional scopes to request from IDP" - type: array - optional: true - -And ``TokenResult`` is a dictionary of the following form:: - - access_token: - description: "The OIDC access token" - type: string - expires_in_seconds: - description: "The expiration time in seconds from the current time" - type: int - optional: true - refresh_token: - description: "The OIDC refresh token" - type: str - optional: true - -And ``timeout_seconds`` will always be 300 (5 minutes). An example -client would be:: - - >>> from pymongo import MongoClient - >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") - >>> client = MongoClient(uri, request_token_callback=my_callback) - -If the identity provider supports refresh, a refresh callback can also -be provided. If a refresh callback is provided, it will be called -if ``expires_in_seconds`` was given in the request response and is -within 5 minutes, or the server raises a ``ReAuthenticationRequired`` -error during an operation. The refresh callback must of the form:: - - - def request_callback(ProviderInfo, TokenResult, timeout_seconds) -> TokenResult: - ... - return dict(access_token=...) - -Where ``ProviderInfo``, ``timeout_seconds`` and the return value are of -the same form as the request callback, and the ``TokenResult`` parameter -is the result of the request callback, which will contain the ``refresh_token`` -if it was provided. An example using both callbacks would be:: - - >>> from pymongo import MongoClient - >>> uri = "mongodb://localhost/?authMechanism=MONGODB-OIDC") - >>> client = MongoClient(uri, request_token_callback=my_request_callback, - ... refresh_token_callback=my_refresh_callback) - -Note: when multiple identity providers -are configured on the server, a ``username`` must be provided, which is the -Principal Name used on the provider. For example:: - - >>> from pymongo import MongoClient - >>> uri = "mongodb://my_username@localhost/?authMechanism=MONGODB-OIDC") - >>> client = MongoClient(uri, authMechanismProperties=request_token_callback=my_request_callback, - ... refresh_token_callback=my_refresh_callback) - -Allowed Hosts -~~~~~~~~~~~~~ - -By default, only Atlas and localhost MongoDB server hostnames are allowed -to use MONGODB-OIDC. If a different set of host(s) are desired, use -the ``allowed_hosts`` option: - - >>> from pymongo import MongoClient - >>> uri = "mongodb://my_username@localhost/?authMechanism=MONGODB-OIDC") - >>> my_host = "example.com" - >>> client = MongoClient(uri, authMechanismProperties=request_token_callback=my_request_callback, - ... refresh_token_callback=my_refresh_callback, allowed_hosts=[my_host]) - -The hostnames support Unix shell-style wildcards. From d23ef6bfd6b3ee82d5ae29c9238b544b0d3ceb4f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 2 May 2023 19:04:15 -0500 Subject: [PATCH 117/133] address review --- pymongo/auth.py | 17 +++++++++++------ pymongo/auth_oidc.py | 15 +-------------- pymongo/pool.py | 6 ++++++ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index 1afe91077e..c1746d4a27 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -27,7 +27,7 @@ from bson.binary import Binary from bson.son import SON from pymongo.auth_aws import _authenticate_aws -from pymongo.auth_oidc import _authenticate_oidc, _OIDCContextMixin, _OIDCProperties +from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -566,8 +566,14 @@ def speculate_command(self): return cmd -class _OIDCContext(_OIDCContextMixin, _AuthContext): - pass +class _OIDCContext(_AuthContext): + def speculate_command(self): + authenticator = _get_authenticator(self.credentials, self.address) + cmd = authenticator.auth_start_cmd(False) + if cmd is None: + return + cmd["db"] = self.credentials.source + return cmd _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { @@ -583,9 +589,8 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate and sock_info.performed_handshake: - # Existing auth_ctx is stale, remove it. - sock_info.auth_ctx = None + if reauthenticate: + sock_info.handle_reauthenticate() if mechanism == "MONGODB-OIDC": _authenticate_oidc(credentials, sock_info, reauthenticate) else: diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index eb10b37c29..ca45a5ed0a 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""""MONGODB-OIDC Authentication helpers.""" +"""MONGODB-OIDC Authentication helpers.""" import os import threading from dataclasses import dataclass, field @@ -305,19 +305,6 @@ def authenticate(self, sock_info): return resp -class _OIDCContextMixin: - credentials: Any - address: Any - - def speculate_command(self): - authenticator = _get_authenticator(self.credentials, self.address) - cmd = authenticator.auth_start_cmd(False) - if cmd is None: - return - cmd["db"] = self.credentials.source - return cmd - - def _authenticate_oidc(credentials, sock_info, reauthenticate): """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, sock_info.address) diff --git a/pymongo/pool.py b/pymongo/pool.py index d9aa6e7a38..984f8923ed 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -927,6 +927,12 @@ def idle_time_seconds(self): """Seconds since this socket was last checked into its pool.""" return time.monotonic() - self.last_checkin_time + def handle_reauthenticate(self): + """Handle a reauthentication.""" + if self.performed_handshake: + # Existing auth_ctx is stale, remove it. + self.auth_ctx = None + def _raise_connection_failure(self, error): # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if From b431cdc5a91def0f7d64b4a165f632acdffff331 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 2 May 2023 20:08:52 -0500 Subject: [PATCH 118/133] address review --- pymongo/common.py | 6 ------ test/utils.py | 7 ++++++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pymongo/common.py b/pymongo/common.py index 05adf69981..5f60891990 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -171,12 +171,6 @@ def raise_config_error(key: str, dummy: Any) -> NoReturn: } -def camel_to_snake(camel): - # Regex to convert CamelCase to snake_case. - snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() - - def validate_boolean(option: str, value: Any) -> bool: """Validates that 'value' is True or False.""" if isinstance(value, bool): diff --git a/test/utils.py b/test/utils.py index 88130e1d6d..9eb3cce9b8 100644 --- a/test/utils.py +++ b/test/utils.py @@ -36,7 +36,6 @@ from bson.son import SON from pymongo import MongoClient, monitoring, operations, read_preferences from pymongo.collection import ReturnDocument -from pymongo.common import camel_to_snake # noqa from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -672,6 +671,12 @@ def camel_to_upper_camel(camel): return camel[0].upper() + camel[1:] +def camel_to_snake(camel): + # Regex to convert CamelCase to snake_case. + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() + + def camel_to_snake_args(arguments): for arg_name in list(arguments): c2s = camel_to_snake(arg_name) From a089dc0f5530896f2152a39257ca33ef62eaf4fd Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 3 May 2023 10:19:25 -0500 Subject: [PATCH 119/133] address review --- pymongo/auth_oidc.py | 10 +--------- test/auth_aws/test_auth_oidc.py | 4 ++-- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index ca45a5ed0a..e5bf9a7df1 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -274,15 +274,7 @@ def authenticate(self, sock_info): sock_info.oidc_token_gen_id = self.token_gen_id return - # Convert the server response to be more pythonic. - # Avoid circular import - from pymongo.common import camel_to_snake - - orig_server_resp: Dict = bson.decode(resp["payload"]) - server_resp = dict() - for key, value in orig_server_resp.items(): - server_resp[camel_to_snake(key)] = value - + server_resp: Dict = bson.decode(resp["payload"]) if "issuer" in server_resp: self.idp_info = server_resp self.idp_info_gen_id += 1 diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index ae9578bdc5..85b1c7315d 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -56,7 +56,7 @@ def create_request_cb(self, username="test_user1", expires_in_seconds=None, slee def request_token(server_info, context): # Validate the info. self.assertIn("issuer", server_info) - self.assertIn("client_id", server_info) + self.assertIn("clientId", server_info) # Validate the timeout. timeout_seconds = context["timeout_seconds"] @@ -84,7 +84,7 @@ def refresh_token(server_info, context): # Validate the info. self.assertIn("issuer", server_info) - self.assertIn("client_id", server_info) + self.assertIn("clientId", server_info) # Validate the creds self.assertIsNotNone(context["refresh_token"]) From 7c56888e1adca4244ae475826d23c5dddea91adb Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 3 May 2023 15:15:11 -0500 Subject: [PATCH 120/133] try more bytes --- pymongo/compression_support.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index c9632a43d3..f21d717771 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -109,7 +109,7 @@ class SnappyContext(object): @staticmethod def compress(data): - return snappy.compress(data) + return bytes(snappy.compress(data)) class ZlibContext(object): @@ -138,7 +138,7 @@ def decompress(data, compressor_id): # https://github.com/andrix/python-snappy/issues/65 # This only matters when data is a memoryview since # id(bytes(data)) == id(data) when data is a bytes. - return snappy.uncompress(bytes(data)) + return bytes(snappy.uncompress(bytes(data))) elif compressor_id == ZlibContext.compressor_id: return zlib.decompress(data) elif compressor_id == ZstdContext.compressor_id: From 7bb678f2eab12f9a5bf1d6658142da9f20a2d7ab Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 3 May 2023 15:17:26 -0500 Subject: [PATCH 121/133] Revert "try more bytes" This reverts commit 7c56888e1adca4244ae475826d23c5dddea91adb. --- pymongo/compression_support.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index f21d717771..c9632a43d3 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -109,7 +109,7 @@ class SnappyContext(object): @staticmethod def compress(data): - return bytes(snappy.compress(data)) + return snappy.compress(data) class ZlibContext(object): @@ -138,7 +138,7 @@ def decompress(data, compressor_id): # https://github.com/andrix/python-snappy/issues/65 # This only matters when data is a memoryview since # id(bytes(data)) == id(data) when data is a bytes. - return bytes(snappy.uncompress(bytes(data))) + return snappy.uncompress(bytes(data)) elif compressor_id == ZlibContext.compressor_id: return zlib.decompress(data) elif compressor_id == ZstdContext.compressor_id: From bd53fdc59636449189ca4844df42a35bc5b1e3fb Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 3 May 2023 18:14:31 -0500 Subject: [PATCH 122/133] wip better handling of reauth --- pymongo/auth.py | 2 - pymongo/bulk.py | 1 - pymongo/message.py | 9 ++++- pymongo/mongo_client.py | 18 +-------- pymongo/pool.py | 71 ++++++++++++++++++--------------- pymongo/server.py | 13 +++++- test/auth_aws/test_auth_oidc.py | 59 ++++++++++++++++++++++++++- 7 files changed, 119 insertions(+), 54 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index c1746d4a27..cf32262d70 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -589,8 +589,6 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate: - sock_info.handle_reauthenticate() if mechanism == "MONGODB-OIDC": _authenticate_oidc(credentials, sock_info, reauthenticate) else: diff --git a/pymongo/bulk.py b/pymongo/bulk.py index b21b576aa5..715ef7724d 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -305,7 +305,6 @@ def _execute_command( run.op_type, self.collection.codec_options, ) - while run.idx_offset < len(run.ops): # If this is the last possible operation, use the # final write concern. diff --git a/pymongo/message.py b/pymongo/message.py index 9fa64a875a..b4755f9a40 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -54,6 +54,7 @@ ProtocolError, ) from pymongo.hello import HelloCompat +from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern @@ -839,7 +840,13 @@ def _batch_command(self, cmd, docs): def execute(self, cmd, docs, client): request_id, msg, to_send = self._batch_command(cmd, docs) - result = self.write_command(cmd, request_id, msg, to_send) + try: + result = self.write_command(cmd, request_id, msg, to_send) + except OperationFailure as exc: + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + self.sock_info.authenticate(True) + result = self.write_command(cmd, request_id, msg, to_send) + raise client._process_response(result, self.session) return result, to_send diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 6d9ef27583..ca60affdf5 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1397,14 +1397,7 @@ def is_retrying(): assert last_error is not None raise last_error retryable = False - # Handle re-authentication. - try: - return func(session, sock_info, retryable) - except OperationFailure as exc: - if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(reauthenticate=True) - return func(session, sock_info, retryable) - raise + return func(session, sock_info, retryable) except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted @@ -1468,14 +1461,7 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True # not support retryable reads, raise the last error. assert last_error is not None raise last_error - # Handle re-authentication. - try: - return func(session, server, sock_info, read_pref) - except OperationFailure as exc: - if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(reauthenticate=True) - return func(session, server, sock_info, read_pref) - raise + return func(session, server, sock_info, read_pref) except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted diff --git a/pymongo/pool.py b/pymongo/pool.py index 984f8923ed..2a5e0063d8 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -763,32 +763,40 @@ def command( unacknowledged = write_concern and not write_concern.acknowledged if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) + args = ( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + ) + kwargs = dict( + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ) try: - return command( - self, - dbname, - spec, - self.is_mongos, - read_preference, - codec_options, - session, - client, - check, - allowable_errors, - self.address, - listeners, - self.max_bson_size, - read_concern, - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed, - write_concern=write_concern, - ) - except (OperationFailure, NotPrimaryError): + return command(*args, **kwargs) + except OperationFailure as exc: + if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: + self.authenticate(True) + return command(*args, **kwargs) + raise + except NotPrimaryError: raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. except BaseException as error: @@ -864,7 +872,12 @@ def authenticate(self, reauthenticate=False): """ # CMAP spec says to publish the ready event only after authenticating # the connection. - if not self.ready or reauthenticate: + if reauthenticate: + self.ready = False + if self.performed_handshake: + # Existing auth_ctx is stale, remove it. + self.auth_ctx = None + if not self.ready: creds = self.opts._credentials if creds: auth.authenticate(creds, self, reauthenticate=reauthenticate) @@ -927,12 +940,6 @@ def idle_time_seconds(self): """Seconds since this socket was last checked into its pool.""" return time.monotonic() - self.last_checkin_time - def handle_reauthenticate(self): - """Handle a reauthentication.""" - if self.performed_handshake: - # Existing auth_ctx is stale, remove it. - self.auth_ctx = None - def _raise_connection_failure(self, error): # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if diff --git a/pymongo/server.py b/pymongo/server.py index f26f473c32..d3672c8c35 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -18,7 +18,7 @@ from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers import _check_command_response +from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE, _check_command_response from pymongo.message import _convert_exception, _OpMsg from pymongo.response import PinnedResponse, Response @@ -87,6 +87,17 @@ def run_operation(self, sock_info, operation, read_preference, listeners, unpack - `listeners`: Instance of _EventListeners or None. - `unpack_res`: A callable that decodes the wire protocol response. """ + try: + return self._run_operation(sock_info, operation, read_preference, listeners, unpack_res) + except OperationFailure as exc: + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + sock_info.authenticate(True) + return self._run_operation( + sock_info, operation, read_preference, listeners, unpack_res + ) + raise + + def _run_operation(self, sock_info, operation, read_preference, listeners, unpack_res): duration = None publish = listeners.enabled_for_commands if publish: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 85b1c7315d..a9eb04bcc4 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -31,6 +31,7 @@ from pymongo.auth import MongoCredential from pymongo.auth_oidc import _CACHE as _oidc_cache from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.operations import InsertOne class TestAuthOIDC(unittest.TestCase): @@ -529,7 +530,63 @@ def test_reauthenticate_succeeds(self): self.assertEqual(self.refresh_called, 1) client.close() - def test_reauthenticate_retries_and_succees_with_cache(self): + def test_reauthenticate_succeeds_bulk_write(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["insert", "saslStart"], "errorCode": 391}, + } + ): + # Perform a bulk write operation. + client.test.test.bulk_write([InsertOne({})]) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_cursor(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_one({"a": 1}) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}) + + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_retries_and_succeeds_with_cache(self): listener = EventListener() # Create request and refresh callbacks that return valid credentials From b13427e161710159ea5e7cbf7aaf148246e7da04 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 4 May 2023 09:38:01 -0500 Subject: [PATCH 123/133] Revert "wip better handling of reauth" This reverts commit bd53fdc59636449189ca4844df42a35bc5b1e3fb. --- pymongo/auth.py | 2 + pymongo/bulk.py | 1 + pymongo/message.py | 9 +---- pymongo/mongo_client.py | 18 ++++++++- pymongo/pool.py | 71 +++++++++++++++------------------ pymongo/server.py | 13 +----- test/auth_aws/test_auth_oidc.py | 59 +-------------------------- 7 files changed, 54 insertions(+), 119 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index cf32262d70..c1746d4a27 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -589,6 +589,8 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] + if reauthenticate: + sock_info.handle_reauthenticate() if mechanism == "MONGODB-OIDC": _authenticate_oidc(credentials, sock_info, reauthenticate) else: diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 715ef7724d..b21b576aa5 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -305,6 +305,7 @@ def _execute_command( run.op_type, self.collection.codec_options, ) + while run.idx_offset < len(run.ops): # If this is the last possible operation, use the # final write concern. diff --git a/pymongo/message.py b/pymongo/message.py index b4755f9a40..9fa64a875a 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -54,7 +54,6 @@ ProtocolError, ) from pymongo.hello import HelloCompat -from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern @@ -840,13 +839,7 @@ def _batch_command(self, cmd, docs): def execute(self, cmd, docs, client): request_id, msg, to_send = self._batch_command(cmd, docs) - try: - result = self.write_command(cmd, request_id, msg, to_send) - except OperationFailure as exc: - if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - self.sock_info.authenticate(True) - result = self.write_command(cmd, request_id, msg, to_send) - raise + result = self.write_command(cmd, request_id, msg, to_send) client._process_response(result, self.session) return result, to_send diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ca60affdf5..6d9ef27583 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1397,7 +1397,14 @@ def is_retrying(): assert last_error is not None raise last_error retryable = False - return func(session, sock_info, retryable) + # Handle re-authentication. + try: + return func(session, sock_info, retryable) + except OperationFailure as exc: + if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: + sock_info.authenticate(reauthenticate=True) + return func(session, sock_info, retryable) + raise except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted @@ -1461,7 +1468,14 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True # not support retryable reads, raise the last error. assert last_error is not None raise last_error - return func(session, server, sock_info, read_pref) + # Handle re-authentication. + try: + return func(session, server, sock_info, read_pref) + except OperationFailure as exc: + if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: + sock_info.authenticate(reauthenticate=True) + return func(session, server, sock_info, read_pref) + raise except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted diff --git a/pymongo/pool.py b/pymongo/pool.py index 2a5e0063d8..984f8923ed 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -763,40 +763,32 @@ def command( unacknowledged = write_concern and not write_concern.acknowledged if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) - args = ( - self, - dbname, - spec, - self.is_mongos, - read_preference, - codec_options, - session, - client, - check, - allowable_errors, - self.address, - listeners, - self.max_bson_size, - read_concern, - ) - kwargs = dict( - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed, - write_concern=write_concern, - ) try: - return command(*args, **kwargs) - except OperationFailure as exc: - if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - self.authenticate(True) - return command(*args, **kwargs) - raise - except NotPrimaryError: + return command( + self, + dbname, + spec, + self.is_mongos, + read_preference, + codec_options, + session, + client, + check, + allowable_errors, + self.address, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + write_concern=write_concern, + ) + except (OperationFailure, NotPrimaryError): raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. except BaseException as error: @@ -872,12 +864,7 @@ def authenticate(self, reauthenticate=False): """ # CMAP spec says to publish the ready event only after authenticating # the connection. - if reauthenticate: - self.ready = False - if self.performed_handshake: - # Existing auth_ctx is stale, remove it. - self.auth_ctx = None - if not self.ready: + if not self.ready or reauthenticate: creds = self.opts._credentials if creds: auth.authenticate(creds, self, reauthenticate=reauthenticate) @@ -940,6 +927,12 @@ def idle_time_seconds(self): """Seconds since this socket was last checked into its pool.""" return time.monotonic() - self.last_checkin_time + def handle_reauthenticate(self): + """Handle a reauthentication.""" + if self.performed_handshake: + # Existing auth_ctx is stale, remove it. + self.auth_ctx = None + def _raise_connection_failure(self, error): # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if diff --git a/pymongo/server.py b/pymongo/server.py index d3672c8c35..f26f473c32 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -18,7 +18,7 @@ from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE, _check_command_response +from pymongo.helpers import _check_command_response from pymongo.message import _convert_exception, _OpMsg from pymongo.response import PinnedResponse, Response @@ -87,17 +87,6 @@ def run_operation(self, sock_info, operation, read_preference, listeners, unpack - `listeners`: Instance of _EventListeners or None. - `unpack_res`: A callable that decodes the wire protocol response. """ - try: - return self._run_operation(sock_info, operation, read_preference, listeners, unpack_res) - except OperationFailure as exc: - if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(True) - return self._run_operation( - sock_info, operation, read_preference, listeners, unpack_res - ) - raise - - def _run_operation(self, sock_info, operation, read_preference, listeners, unpack_res): duration = None publish = listeners.enabled_for_commands if publish: diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index a9eb04bcc4..85b1c7315d 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -31,7 +31,6 @@ from pymongo.auth import MongoCredential from pymongo.auth_oidc import _CACHE as _oidc_cache from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.operations import InsertOne class TestAuthOIDC(unittest.TestCase): @@ -530,63 +529,7 @@ def test_reauthenticate_succeeds(self): self.assertEqual(self.refresh_called, 1) client.close() - def test_reauthenticate_succeeds_bulk_write(self): - request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - - # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) - client = MongoClient(self.uri_single, authmechanismproperties=props) - - # Perform a find operation. - client.test.test.find_one() - - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) - - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["insert", "saslStart"], "errorCode": 391}, - } - ): - # Perform a bulk write operation. - client.test.test.bulk_write([InsertOne({})]) - - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) - client.close() - - def test_reauthenticate_succeeds_cursor(self): - request_cb = self.create_request_cb() - refresh_cb = self.create_refresh_cb() - - # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) - client = MongoClient(self.uri_single, authmechanismproperties=props) - - # Perform an insert operation. - client.test.test.insert_one({"a": 1}) - - # Assert that the refresh callback has not been called. - self.assertEqual(self.refresh_called, 0) - - with self.fail_point( - { - "mode": {"times": 2}, - "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, - } - ): - # Perform a find operation. - cursor = client.test.test.find({"a": 1}) - - self.assertGreaterEqual(len(list(cursor)), 1) - - # Assert that the refresh callback has been called. - self.assertEqual(self.refresh_called, 1) - client.close() - - def test_reauthenticate_retries_and_succeeds_with_cache(self): + def test_reauthenticate_retries_and_succees_with_cache(self): listener = EventListener() # Create request and refresh callbacks that return valid credentials From 8d24dd99c3adfb045527b9b7d8becd87b13bdefa Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 4 May 2023 19:50:34 -0500 Subject: [PATCH 124/133] clean up reauth handling --- pymongo/auth.py | 2 - pymongo/auth_oidc.py | 30 +++---- pymongo/helpers.py | 32 +++++++ pymongo/message.py | 2 + pymongo/mongo_client.py | 18 +--- pymongo/pool.py | 17 ++-- pymongo/server.py | 3 +- test/auth_aws/test_auth_oidc.py | 154 +++++++++++++++++++++++++++++++- 8 files changed, 209 insertions(+), 49 deletions(-) diff --git a/pymongo/auth.py b/pymongo/auth.py index c1746d4a27..cf32262d70 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -589,8 +589,6 @@ def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - if reauthenticate: - sock_info.handle_reauthenticate() if mechanism == "MONGODB-OIDC": _authenticate_oidc(credentials, sock_info, reauthenticate) else: diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index e5bf9a7df1..64b77435d0 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -100,7 +100,6 @@ class _OIDCAuthenticator: def get_current_token(self, use_callbacks=True): properties = self.properties - principal_name = self.username request_cb = properties.request_token_callback refresh_cb = properties.refresh_token_callback @@ -238,29 +237,26 @@ def clear(self): def run_command(self, sock_info, cmd): try: - return sock_info.command("$external", cmd) + return sock_info.command("$external", cmd, no_reauth=True) except OperationFailure as exc: self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if "jwt" in bson.decode(cmd["payload"]): if self.idp_info_gen_id > self.reauth_gen_id: raise - self.handle_reauth(sock_info) - return self.authenticate(sock_info) + return self.authenticate(sock_info, reauthenticate=True) raise - def handle_reauth(self, sock_info): - prev_id = getattr(sock_info, "oidc_token_gen_id", None) - if prev_id != self.token_gen_id: - # No need to preemptively clear, we've already changed tokens. - return - - self.reauth_gen_id = self.idp_info_gen_id - self.token_exp_utc = None - if not self.properties.refresh_token_callback: - self.clear() + def authenticate(self, sock_info, reauthenticate=False): + if reauthenticate: + prev_id = getattr(sock_info, "oidc_token_gen_id", None) + # Check if we've already changed tokens. + if prev_id == self.token_gen_id: + self.reauth_gen_id = self.idp_info_gen_id + self.token_exp_utc = None + if not self.properties.refresh_token_callback: + self.clear() - def authenticate(self, sock_info): ctx = sock_info.auth_ctx cmd = None @@ -300,6 +296,4 @@ def authenticate(self, sock_info): def _authenticate_oidc(credentials, sock_info, reauthenticate): """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, sock_info.address) - if reauthenticate: - authenticator.handle_reauth(sock_info) - return authenticator.authenticate(sock_info) + return authenticator.authenticate(sock_info, reauthenticate=reauthenticate) diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 28978dc036..1a753c66f4 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -270,3 +270,35 @@ def _handle_exception(): pass finally: del einfo + + +def _handle_reauth(func): + def inner(*args, **kwargs): + no_reauth = kwargs.pop("no_reauth", False) + from pymongo.pool import SocketInfo + + try: + return func(*args, **kwargs) + except OperationFailure as exc: + if no_reauth: + raise + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + # Look for an argument that either is a SocketInfo + # or has a socket_info attribute, so we can trigger + # a reauth. + sock_info = None + for arg in args: + if isinstance(arg, SocketInfo): + sock_info = arg + break + if hasattr(arg, "sock_info"): + sock_info = arg.sock_info + break + if sock_info: + sock_info.authenticate(reauthenticate=True) + else: + raise + return func(*args, **kwargs) + raise + + return inner diff --git a/pymongo/message.py b/pymongo/message.py index 9fa64a875a..127ea8c012 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -54,6 +54,7 @@ ProtocolError, ) from pymongo.hello import HelloCompat +from pymongo.helpers import _handle_reauth from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern @@ -909,6 +910,7 @@ def unack_write(self, cmd, request_id, msg, max_doc_size, docs): self.start_time = datetime.datetime.now() return result + @_handle_reauth def write_command(self, cmd, request_id, msg, docs): """A proxy for SocketInfo.write_command that handles event publishing.""" if self.publish: diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 6d9ef27583..ca60affdf5 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1397,14 +1397,7 @@ def is_retrying(): assert last_error is not None raise last_error retryable = False - # Handle re-authentication. - try: - return func(session, sock_info, retryable) - except OperationFailure as exc: - if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(reauthenticate=True) - return func(session, sock_info, retryable) - raise + return func(session, sock_info, retryable) except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted @@ -1468,14 +1461,7 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True # not support retryable reads, raise the last error. assert last_error is not None raise last_error - # Handle re-authentication. - try: - return func(session, server, sock_info, read_pref) - except OperationFailure as exc: - if exc.code == helpers._REAUTHENTICATION_REQUIRED_CODE: - sock_info.authenticate(reauthenticate=True) - return func(session, server, sock_info, read_pref) - raise + return func(session, server, sock_info, read_pref) except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted diff --git a/pymongo/pool.py b/pymongo/pool.py index 984f8923ed..095d1692fe 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -56,6 +56,7 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat +from pymongo.helpers import _handle_reauth from pymongo.lock import _create_lock from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason from pymongo.network import command, receive_message @@ -704,6 +705,7 @@ def _next_reply(self): helpers._check_command_response(response_doc, self.max_wire_version) return response_doc + @_handle_reauth def command( self, dbname, @@ -788,7 +790,7 @@ def command( exhaust_allowed=exhaust_allowed, write_concern=write_concern, ) - except (OperationFailure, NotPrimaryError): + except (OperationFailure, NotPrimaryError) as exc: raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. except BaseException as error: @@ -864,7 +866,12 @@ def authenticate(self, reauthenticate=False): """ # CMAP spec says to publish the ready event only after authenticating # the connection. - if not self.ready or reauthenticate: + if reauthenticate: + if self.performed_handshake: + # Existing auth_ctx is stale, remove it. + self.auth_ctx = None + self.ready = False + if not self.ready: creds = self.opts._credentials if creds: auth.authenticate(creds, self, reauthenticate=reauthenticate) @@ -927,12 +934,6 @@ def idle_time_seconds(self): """Seconds since this socket was last checked into its pool.""" return time.monotonic() - self.last_checkin_time - def handle_reauthenticate(self): - """Handle a reauthentication.""" - if self.performed_handshake: - # Existing auth_ctx is stale, remove it. - self.auth_ctx = None - def _raise_connection_failure(self, error): # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if diff --git a/pymongo/server.py b/pymongo/server.py index f26f473c32..16c905abb7 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -18,7 +18,7 @@ from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers import _check_command_response +from pymongo.helpers import _check_command_response, _handle_reauth from pymongo.message import _convert_exception, _OpMsg from pymongo.response import PinnedResponse, Response @@ -73,6 +73,7 @@ def request_check(self): """Check the server's state soon.""" self._monitor.request_check() + @_handle_reauth def run_operation(self, sock_info, operation, read_preference, listeners, unpack_res): """Run a _Query or _GetMore operation and return a Response object. diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 85b1c7315d..d9ff3b6b23 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -28,9 +28,10 @@ from bson import SON from pymongo import MongoClient -from pymongo.auth import MongoCredential from pymongo.auth_oidc import _CACHE as _oidc_cache +from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.operations import InsertOne class TestAuthOIDC(unittest.TestCase): @@ -496,8 +497,8 @@ def test_reauthenticate_succeeds(self): with self.fail_point( { - "mode": {"times": 2}, - "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, } ): # Perform a find operation. @@ -529,7 +530,152 @@ def test_reauthenticate_succeeds(self): self.assertEqual(self.refresh_called, 1) client.close() - def test_reauthenticate_retries_and_succees_with_cache(self): + def test_reauthenticate_succeeds_bulk_write(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["insert"], "errorCode": 391}, + } + ): + # Perform a bulk write operation. + client.test.test.bulk_write([InsertOne({})]) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_bulk_read(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Perform a bulk write operation. + client.test.test.bulk_write([InsertOne({})]) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a bulk read operation. + cursor = client.test.test.find_raw_batches({}) + list(cursor) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_cursor(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_one({"a": 1}) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}) + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_get_more(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_many([{"a": 1}, {"a": 1}]) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST) + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_command(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + + print("start of test") + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_one({"a": 1}) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["count"], "errorCode": 391}, + } + ): + # Perform a count operation. + cursor = client.test.command(dict(count="test")) + + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_retries_and_succeeds_with_cache(self): listener = EventListener() # Create request and refresh callbacks that return valid credentials From c73be945ae030c1c523ebe47c1f97c178befce72 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 4 May 2023 19:53:09 -0500 Subject: [PATCH 125/133] lint --- pymongo/auth_oidc.py | 2 +- pymongo/common.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 64b77435d0..0be7a49f97 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -241,7 +241,7 @@ def run_command(self, sock_info, cmd): except OperationFailure as exc: self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - if "jwt" in bson.decode(cmd["payload"]): + if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined] if self.idp_info_gen_id > self.reauth_gen_id: raise return self.authenticate(sock_info, reauthenticate=True) diff --git a/pymongo/common.py b/pymongo/common.py index 79bf0e4304..63850fb565 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -433,7 +433,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni props: Dict[str, Any] = {} if not isinstance(value, str): if not isinstance(value, dict): - raise ValueError("Auth mechansim properties must be given as a string or a dictionary") + raise ValueError("Auth mechanism properties must be given as a string or a dictionary") for key, value in value.items(): if isinstance(value, str): props[key] = value @@ -455,7 +455,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni props[key] = value else: raise ValueError( - "Auth mechanism properity values must be strings or callback functions" + "Auth mechanism property values must be strings or callback functions" ) return props From 1924f3eacb2b8e828e7a81bba27a54115ac6ddda Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 4 May 2023 20:12:08 -0500 Subject: [PATCH 126/133] lint --- pymongo/auth_oidc.py | 2 +- pymongo/common.py | 1 - pymongo/pool.py | 2 +- test/test_auth_spec.py | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 0be7a49f97..530b1bb068 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -17,7 +17,7 @@ import threading from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import bson from bson.binary import Binary diff --git a/pymongo/common.py b/pymongo/common.py index 63850fb565..f5e4d6369c 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -17,7 +17,6 @@ import datetime import inspect -import re import warnings from collections import OrderedDict, abc from typing import ( diff --git a/pymongo/pool.py b/pymongo/pool.py index 095d1692fe..89ba6459eb 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -790,7 +790,7 @@ def command( exhaust_allowed=exhaust_allowed, write_concern=write_concern, ) - except (OperationFailure, NotPrimaryError) as exc: + except (OperationFailure, NotPrimaryError): raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. except BaseException as error: diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index ffed059e67..78f4d21929 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -49,7 +49,7 @@ def run_test(self): props["request_token_callback"] = lambda x, y: 1 del props["REQUEST_TOKEN_CALLBACK"] if props.get("REFRESH_TOKEN_CALLBACK"): - props["refresh_token_callback"] = lambda a, b: 1 # type:ignore[misc] + props["refresh_token_callback"] = lambda a, b: 1 del props["REFRESH_TOKEN_CALLBACK"] client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials From 770e4f608b501afb4134034aa9e2928fd515eb16 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 8 May 2023 11:55:53 -0500 Subject: [PATCH 127/133] always use for source --- .evergreen/run-mongodb-oidc-test.sh | 2 +- pymongo/auth.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 46bb779578..748a168b15 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -23,7 +23,7 @@ shopt -s expand_aliases # needed for `urlencode` alias MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" -MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" +MONGODB_URI_MULTIPLE="mongodb+srv://cluster0.oztdp.mongodb-dev.net/?authMechanism=MONGODB-OIDC" if [ -z "${OIDC_TOKEN_DIR}" ]; then echo "Must specify OIDC_TOKEN_DIR" diff --git a/pymongo/auth.py b/pymongo/auth.py index cf32262d70..4bc31ee97b 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -140,8 +140,6 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): # user can be None for temporary link-local EC2 credentials. return MongoCredential(mech, "$external", user, passwd, aws_props, None) elif mech == "MONGODB-OIDC": - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for MONGODB-OIDC") properties = extra.get("authmechanismproperties", {}) request_token_callback = properties.get("request_token_callback") refresh_token_callback = properties.get("refresh_token_callback", None) From 828e22a159303e9441f95872300115c9f21cf297 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 8 May 2023 11:56:17 -0500 Subject: [PATCH 128/133] undo change to uri --- .evergreen/run-mongodb-oidc-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 748a168b15..46bb779578 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -23,7 +23,7 @@ shopt -s expand_aliases # needed for `urlencode` alias MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" -MONGODB_URI_MULTIPLE="mongodb+srv://cluster0.oztdp.mongodb-dev.net/?authMechanism=MONGODB-OIDC" +MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" if [ -z "${OIDC_TOKEN_DIR}" ]; then echo "Must specify OIDC_TOKEN_DIR" From c719618fbe7779e8ded1ef55c68c01cd58073d12 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 8 May 2023 14:02:26 -0500 Subject: [PATCH 129/133] Revert "undo change to uri" This reverts commit 828e22a159303e9441f95872300115c9f21cf297. --- .evergreen/run-mongodb-oidc-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 46bb779578..748a168b15 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -23,7 +23,7 @@ shopt -s expand_aliases # needed for `urlencode` alias MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" -MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" +MONGODB_URI_MULTIPLE="mongodb+srv://cluster0.oztdp.mongodb-dev.net/?authMechanism=MONGODB-OIDC" if [ -z "${OIDC_TOKEN_DIR}" ]; then echo "Must specify OIDC_TOKEN_DIR" From 1739968512bafe645f8aff9ec6e447647037c3c9 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 10 May 2023 16:45:36 -0500 Subject: [PATCH 130/133] address review --- .../unified/reauthenticate_with_retry.json | 204 +++++++++--------- .../unified/reauthenticate_without_retry.json | 204 +++++++++--------- test/auth_aws/test_auth_oidc.py | 35 ++- test/utils.py | 8 +- 4 files changed, 230 insertions(+), 221 deletions(-) diff --git a/test/auth/unified/reauthenticate_with_retry.json b/test/auth/unified/reauthenticate_with_retry.json index 036e56a432..ef110562ed 100644 --- a/test/auth/unified/reauthenticate_with_retry.json +++ b/test/auth/unified/reauthenticate_with_retry.json @@ -1,21 +1,21 @@ { - "description":"reauthenticate_with_retry", - "schemaVersion":"1.12", - "runOnRequirements":[ + "description": "reauthenticate_with_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ { - "minServerVersion":"6.3", - "auth":true + "minServerVersion": "6.3", + "auth": true } ], - "createEntities":[ + "createEntities": [ { - "client":{ - "id":"client0", - "uriOptions":{ - "retryReads":true, - "retryWrites":true + "client": { + "id": "client0", + "uriOptions": { + "retryReads": true, + "retryWrites": true }, - "observeEvents":[ + "observeEvents": [ "commandStartedEvent", "commandSucceededEvent", "commandFailedEvent" @@ -23,97 +23,87 @@ } }, { - "database":{ - "id":"database0", - "client":"client0", - "databaseName":"db" + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" } }, { - "collection":{ - "id":"collection0", - "database":"database0", - "collectionName":"collName" + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" } } ], - "initialData":[ + "initialData": [ { - "collectionName":"collName", - "databaseName":"db", - "documents":[ - - ] + "collectionName": "collName", + "databaseName": "db", + "documents": [] } ], - "tests":[ + "tests": [ { - "description":"Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", - "operations":[ + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations": [ { - "name":"failPoint", - "object":"testRunner", - "arguments":{ - "client":"client0", - "failPoint":{ - "configureFailPoint":"failCommand", - "mode":{ - "times":1 + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 }, - "data":{ - "failCommands":[ + "data": { + "failCommands": [ "find" ], - "errorCode":391 + "errorCode": 391 } } } }, { - "name":"find", - "arguments":{ - "filter":{ - - } + "name": "find", + "arguments": { + "filter": {} }, - "object":"collection0", - "expectResult":[ - - ] + "object": "collection0", + "expectResult": [] } ], - "expectEvents":[ + "expectEvents": [ { - "client":"client0", - "events":[ + "client": "client0", + "events": [ { - "commandStartedEvent":{ - "command":{ - "find":"collName", - "filter":{ - - } + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} } } }, { - "commandFailedEvent":{ - "commandName":"find" + "commandFailedEvent": { + "commandName": "find" } }, { - "commandStartedEvent":{ - "command":{ - "find":"collName", - "filter":{ - - } + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} } } }, { - "commandSucceededEvent":{ - "commandName":"find" + "commandSucceededEvent": { + "commandName": "find" } } ] @@ -121,76 +111,76 @@ ] }, { - "description":"Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", - "operations":[ + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ { - "name":"failPoint", - "object":"testRunner", - "arguments":{ - "client":"client0", - "failPoint":{ - "configureFailPoint":"failCommand", - "mode":{ - "times":1 + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 }, - "data":{ - "failCommands":[ + "data": { + "failCommands": [ "insert" ], - "errorCode":391 + "errorCode": 391 } } } }, { - "name":"insertOne", - "object":"collection0", - "arguments":{ - "document":{ - "_id":1, - "x":1 + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 } } } ], - "expectEvents":[ + "expectEvents": [ { - "client":"client0", - "events":[ + "client": "client0", + "events": [ { - "commandStartedEvent":{ - "command":{ - "insert":"collName", - "documents":[ + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ { - "_id":1, - "x":1 + "_id": 1, + "x": 1 } ] } } }, { - "commandFailedEvent":{ - "commandName":"insert" + "commandFailedEvent": { + "commandName": "insert" } }, { - "commandStartedEvent":{ - "command":{ - "insert":"collName", - "documents":[ + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ { - "_id":1, - "x":1 + "_id": 1, + "x": 1 } ] } } }, { - "commandSucceededEvent":{ - "commandName":"insert" + "commandSucceededEvent": { + "commandName": "insert" } } ] @@ -198,4 +188,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/test/auth/unified/reauthenticate_without_retry.json b/test/auth/unified/reauthenticate_without_retry.json index 0c0b47d77a..819542274f 100644 --- a/test/auth/unified/reauthenticate_without_retry.json +++ b/test/auth/unified/reauthenticate_without_retry.json @@ -1,21 +1,21 @@ { - "description":"reauthenticate_without_retry", - "schemaVersion":"1.12", - "runOnRequirements":[ + "description": "reauthenticate_without_retry", + "schemaVersion": "1.13", + "runOnRequirements": [ { - "minServerVersion":"6.3", - "auth":true + "minServerVersion": "6.3", + "auth": true } ], - "createEntities":[ + "createEntities": [ { - "client":{ - "id":"client0", - "uriOptions":{ - "retryReads":false, - "retryWrites":false + "client": { + "id": "client0", + "uriOptions": { + "retryReads": false, + "retryWrites": false }, - "observeEvents":[ + "observeEvents": [ "commandStartedEvent", "commandSucceededEvent", "commandFailedEvent" @@ -23,97 +23,87 @@ } }, { - "database":{ - "id":"database0", - "client":"client0", - "databaseName":"db" + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" } }, { - "collection":{ - "id":"collection0", - "database":"database0", - "collectionName":"collName" + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" } } ], - "initialData":[ + "initialData": [ { - "collectionName":"collName", - "databaseName":"db", - "documents":[ - - ] + "collectionName": "collName", + "databaseName": "db", + "documents": [] } ], - "tests":[ + "tests": [ { - "description":"Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", - "operations":[ + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations": [ { - "name":"failPoint", - "object":"testRunner", - "arguments":{ - "client":"client0", - "failPoint":{ - "configureFailPoint":"failCommand", - "mode":{ - "times":1 + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 }, - "data":{ - "failCommands":[ + "data": { + "failCommands": [ "find" ], - "errorCode":391 + "errorCode": 391 } } } }, { - "name":"find", - "arguments":{ - "filter":{ - - } + "name": "find", + "arguments": { + "filter": {} }, - "object":"collection0", - "expectResult":[ - - ] + "object": "collection0", + "expectResult": [] } ], - "expectEvents":[ + "expectEvents": [ { - "client":"client0", - "events":[ + "client": "client0", + "events": [ { - "commandStartedEvent":{ - "command":{ - "find":"collName", - "filter":{ - - } + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} } } }, { - "commandFailedEvent":{ - "commandName":"find" + "commandFailedEvent": { + "commandName": "find" } }, { - "commandStartedEvent":{ - "command":{ - "find":"collName", - "filter":{ - - } + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} } } }, { - "commandSucceededEvent":{ - "commandName":"find" + "commandSucceededEvent": { + "commandName": "find" } } ] @@ -121,76 +111,76 @@ ] }, { - "description":"Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", - "operations":[ + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations": [ { - "name":"failPoint", - "object":"testRunner", - "arguments":{ - "client":"client0", - "failPoint":{ - "configureFailPoint":"failCommand", - "mode":{ - "times":1 + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 }, - "data":{ - "failCommands":[ + "data": { + "failCommands": [ "insert" ], - "errorCode":391 + "errorCode": 391 } } } }, { - "name":"insertOne", - "object":"collection0", - "arguments":{ - "document":{ - "_id":1, - "x":1 + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 } } } ], - "expectEvents":[ + "expectEvents": [ { - "client":"client0", - "events":[ + "client": "client0", + "events": [ { - "commandStartedEvent":{ - "command":{ - "insert":"collName", - "documents":[ + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ { - "_id":1, - "x":1 + "_id": 1, + "x": 1 } ] } } }, { - "commandFailedEvent":{ - "commandName":"insert" + "commandFailedEvent": { + "commandName": "insert" } }, { - "commandStartedEvent":{ - "command":{ - "insert":"collName", - "documents":[ + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ { - "_id":1, - "x":1 + "_id": 1, + "x": 1 } ] } } }, { - "commandSucceededEvent":{ - "commandName":"insert" + "commandSucceededEvent": { + "commandName": "insert" } } ] @@ -198,4 +188,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index d9ff3b6b23..f81b3af0e7 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -1,4 +1,4 @@ -# Copyright 2020-present MongoDB, Inc. +# Copyright 2023-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ sys.path[0:0] = [""] -from test.utils import EventListener +from test.utils import EventListener, client_context from bson import SON from pymongo import MongoClient @@ -633,7 +633,36 @@ def test_reauthenticate_succeeds_get_more(self): with self.fail_point( { "mode": {"times": 1}, - "data": {"failCommands": ["find"], "errorCode": 391}, + "data": {"failCommands": ["getMore"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}, batch_size=1) + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + @client_context.require_no_mongos + def test_reauthenticate_succeeds_get_more_exhaust(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_many([{"a": 1}, {"a": 1}]) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["getMore"], "errorCode": 391}, } ): # Perform a find operation. diff --git a/test/utils.py b/test/utils.py index aa7b360817..b39375925c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -663,16 +663,16 @@ def get_command_line(client): return command_line -def camel_to_upper_camel(camel): - return camel[0].upper() + camel[1:] - - def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() +def camel_to_upper_camel(camel): + return camel[0].upper() + camel[1:] + + def camel_to_snake_args(arguments): for arg_name in list(arguments): c2s = camel_to_snake(arg_name) From cbadc70a71132f694ee4ea801a37c5de6f44bc38 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 10 May 2023 16:54:25 -0500 Subject: [PATCH 131/133] fix uri --- .evergreen/run-mongodb-oidc-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh index 748a168b15..46bb779578 100755 --- a/.evergreen/run-mongodb-oidc-test.sh +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -23,7 +23,7 @@ shopt -s expand_aliases # needed for `urlencode` alias MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" -MONGODB_URI_MULTIPLE="mongodb+srv://cluster0.oztdp.mongodb-dev.net/?authMechanism=MONGODB-OIDC" +MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" if [ -z "${OIDC_TOKEN_DIR}" ]; then echo "Must specify OIDC_TOKEN_DIR" From 109a2040d5332b8ca12e9230fd9648fd2dcbc918 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 11 May 2023 07:10:28 -0500 Subject: [PATCH 132/133] fix handling of no_mongos --- test/auth_aws/test_auth_oidc.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index f81b3af0e7..470e4581c2 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -24,13 +24,14 @@ sys.path[0:0] = [""] -from test.utils import EventListener, client_context +from test.utils import EventListener from bson import SON from pymongo import MongoClient from pymongo.auth_oidc import _CACHE as _oidc_cache from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.hello import HelloCompat from pymongo.operations import InsertOne @@ -644,8 +645,14 @@ def test_reauthenticate_succeeds_get_more(self): self.assertEqual(self.refresh_called, 1) client.close() - @client_context.require_no_mongos def test_reauthenticate_succeeds_get_more_exhaust(self): + # Ensure no mongos + props = dict(PROVIDER_NAME="aws") + client = MongoClient(self.uri_single, authmechanismproperties=props) + hello = client.admin.command(HelloCompat.LEGACY_CMD) + if hello.get("msg") != "isdbgrid": + raise unittest.SkipTest("Must not be a mongos") + request_cb = self.create_request_cb() refresh_cb = self.create_refresh_cb() From 7baf5a7b6319832f6738124f4f0bb91eb0d3bd0e Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 11 May 2023 10:04:47 -0500 Subject: [PATCH 133/133] update schema version --- test/auth/unified/reauthenticate_without_retry.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/auth/unified/reauthenticate_without_retry.json b/test/auth/unified/reauthenticate_without_retry.json index 819542274f..6fded47634 100644 --- a/test/auth/unified/reauthenticate_without_retry.json +++ b/test/auth/unified/reauthenticate_without_retry.json @@ -1,6 +1,6 @@ { "description": "reauthenticate_without_retry", - "schemaVersion": "1.13", + "schemaVersion": "1.12", "runOnRequirements": [ { "minServerVersion": "6.3",