diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 8398244071..3f06fc1a03 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -749,6 +749,68 @@ functions: fi PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-aws-test.sh + "bootstrap oidc": + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + - command: shell.exec + type: test + params: + working_dir: "src" + shell: bash + script: | + ${PREPARE_SHELL} + if [ "${skip_EC2_auth_test}" = "true" ]; then + echo "This platform does not support the oidc auth test, skipping..." + exit 0 + fi + + cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc + export AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} + export AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} + export AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN} + export OIDC_TOKEN_DIR=/tmp/tokens + + . ./activate-authoidcvenv.sh + python oidc_write_orchestration.py + python oidc_get_tokens.py + + "run oidc auth test with aws credentials": + - command: shell.exec + type: test + params: + working_dir: "src" + shell: bash + script: | + ${PREPARE_SHELL} + if [ "${skip_EC2_auth_test}" = "true" ]; then + echo "This platform does not support the oidc auth test, skipping..." + exit 0 + fi + cd ${DRIVERS_TOOLS}/.evergreen/auth_oidc + mongosh setup_oidc.js + - command: shell.exec + type: test + params: + working_dir: "src" + silent: true + script: | + # DO NOT ECHO WITH XTRACE (which PREPARE_SHELL does) + cat <<'EOF' > "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" + export OIDC_TOKEN_DIR=/tmp/tokens + EOF + - command: shell.exec + type: test + params: + working_dir: "src" + script: | + ${PREPARE_SHELL} + if [ "${skip_web_identity_auth_test}" = "true" ]; then + echo "This platform does not support the oidc auth test, skipping..." + exit 0 + fi + PYTHON_BINARY=${PYTHON_BINARY} ASSERT_NO_URI_CREDS=true .evergreen/run-mongodb-oidc-test.sh + "run aws auth test with aws credentials as environment variables": - command: shell.exec type: test @@ -2034,6 +2096,19 @@ tasks: - func: "run aws auth test with aws web identity credentials" - func: "run aws ECS auth test" + - name: "oidc-auth-test-latest" + commands: + - func: "bootstrap oidc" + - func: "bootstrap mongo-orchestration" + vars: + AUTH: "auth" + ORCHESTRATION_FILE: "auth-oidc.json" + TOPOLOGY: "replica_set" + VERSION: "latest" + - func: "run oidc auth test with aws credentials" + vars: + AWS_WEB_IDENTITY_TOKEN_FILE: /tmp/tokens/test1 + - name: load-balancer-test commands: - func: "bootstrap mongo-orchestration" @@ -3103,6 +3178,14 @@ buildvariants: # macOS MongoDB servers do not staple OCSP responses and only support RSA. - name: ".ocsp-rsa !.ocsp-staple" +- matrix_name: "oidc-auth-test" + matrix_spec: + platform: [ ubuntu-20.04 ] + python-version: ["3.9"] + display_name: "MONGODB-OIDC Auth ${platform} ${python-version}" + tasks: + - name: "oidc-auth-test-latest" + - matrix_name: "aws-auth-test" matrix_spec: platform: [ubuntu-20.04] diff --git a/.evergreen/resync-specs.sh b/.evergreen/resync-specs.sh index 489ff28b3a..817a2d96bc 100755 --- a/.evergreen/resync-specs.sh +++ b/.evergreen/resync-specs.sh @@ -70,6 +70,9 @@ for spec in "$@" do # Match the spec dir name, the python test dir name, and/or common abbreviations. case "$spec" in + auth) + cpjson auth/tests/ auth + ;; atlas-data-lake-testing|data_lake) cpjson atlas-data-lake-testing/tests/ data_lake ;; diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh new file mode 100755 index 0000000000..46bb779578 --- /dev/null +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# MONGODB_URI Set the URI, including an optional username/password to use +# to connect to the server via MONGODB-OIDC authentication +# mechanism. +# PYTHON_BINARY The Python version to use. + +echo "Running MONGODB-OIDC authentication tests" +# ensure no secrets are printed in log files +set +x + +# load the script +shopt -s expand_aliases # needed for `urlencode` alias +[ -s "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" ] && source "${PROJECT_DIRECTORY}/prepare_mongodb_oidc.sh" + +MONGODB_URI=${MONGODB_URI:-"mongodb://localhost"} +MONGODB_URI_SINGLE="${MONGODB_URI}/?authMechanism=MONGODB-OIDC" +MONGODB_URI_MULTIPLE="${MONGODB_URI}:27018/?authMechanism=MONGODB-OIDC&directConnection=true" + +if [ -z "${OIDC_TOKEN_DIR}" ]; then + echo "Must specify OIDC_TOKEN_DIR" + exit 1 +fi + +export MONGODB_URI_SINGLE="$MONGODB_URI_SINGLE" +export MONGODB_URI_MULTIPLE="$MONGODB_URI_MULTIPLE" +export MONGODB_URI="$MONGODB_URI" + +echo $MONGODB_URI_SINGLE +echo $MONGODB_URI_MULTIPLE +echo $MONGODB_URI + +if [ "$ASSERT_NO_URI_CREDS" = "true" ]; then + if echo "$MONGODB_URI" | grep -q "@"; then + echo "MONGODB_URI unexpectedly contains user credentials!"; + exit 1 + fi +fi + +# show test output +set -x + +# Workaround macOS python 3.9 incompatibility with system virtualenv. +if [ "$(uname -s)" = "Darwin" ]; then + VIRTUALENV="/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 -m virtualenv" +else + VIRTUALENV=$(command -v virtualenv) +fi + +authtest () { + if [ "Windows_NT" = "$OS" ]; then + PYTHON=$(cygpath -m $PYTHON) + fi + + echo "Running MONGODB-OIDC authentication tests with $PYTHON" + $PYTHON --version + + $VIRTUALENV -p $PYTHON --never-download venvoidc + if [ "Windows_NT" = "$OS" ]; then + . venvoidc/Scripts/activate + else + . venvoidc/bin/activate + fi + python -m pip install -U pip setuptools + python -m pip install '.[aws]' + python test/auth_aws/test_auth_oidc.py -v + deactivate + rm -rf venvoidc +} + +PYTHON=${PYTHON_BINARY:-} +if [ -z "$PYTHON" ]; then + echo "Cannot test without specifying PYTHON_BINARY" + exit 1 +fi + +authtest diff --git a/pymongo/auth.py b/pymongo/auth.py index 3d259335b0..4bc31ee97b 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -27,6 +27,7 @@ from bson.binary import Binary from bson.son import SON from pymongo.auth_aws import _authenticate_aws +from pymongo.auth_oidc import _authenticate_oidc, _get_authenticator, _OIDCProperties from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep @@ -48,6 +49,7 @@ [ "GSSAPI", "MONGODB-CR", + "MONGODB-OIDC", "MONGODB-X509", "MONGODB-AWS", "PLAIN", @@ -101,7 +103,7 @@ def __hash__(self): def _build_credentials_tuple(mech, source, user, passwd, extra, database): """Build and return a mechanism specific credentials tuple.""" - if mech not in ("MONGODB-X509", "MONGODB-AWS") and user is None: + if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: raise ConfigurationError("%s requires a username." % (mech,)) if mech == "GSSAPI": if source is not None and source != "$external": @@ -137,6 +139,32 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database): aws_props = _AWSProperties(aws_session_token=aws_session_token) # user can be None for temporary link-local EC2 credentials. return MongoCredential(mech, "$external", user, passwd, aws_props, None) + elif mech == "MONGODB-OIDC": + properties = extra.get("authmechanismproperties", {}) + request_token_callback = properties.get("request_token_callback") + refresh_token_callback = properties.get("refresh_token_callback", None) + provider_name = properties.get("PROVIDER_NAME", "") + default_allowed = [ + "*.mongodb.net", + "*.mongodb-dev.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", + ] + allowed_hosts = properties.get("allowed_hosts", default_allowed) + if not request_token_callback and provider_name != "aws": + raise ConfigurationError( + "authentication with MONGODB-OIDC requires providing an request_token_callback or a provider_name of 'aws'" + ) + oidc_props = _OIDCProperties( + request_token_callback=request_token_callback, + refresh_token_callback=refresh_token_callback, + provider_name=provider_name, + allowed_hosts=allowed_hosts, + ) + return MongoCredential(mech, "$external", user, passwd, oidc_props, None) + elif mech == "PLAIN": source_database = source or database or "$external" return MongoCredential(mech, source_database, user, passwd, None, None) @@ -439,7 +467,7 @@ def _authenticate_x509(credentials, sock_info): # MONGODB-X509 is done after the speculative auth step. return - cmd = _X509Context(credentials).speculate_command() + cmd = _X509Context(credentials, sock_info.address).speculate_command() sock_info.command("$external", cmd) @@ -482,6 +510,7 @@ def _authenticate_default(credentials, sock_info): "MONGODB-CR": _authenticate_mongo_cr, "MONGODB-X509": _authenticate_x509, "MONGODB-AWS": _authenticate_aws, + "MONGODB-OIDC": _authenticate_oidc, "PLAIN": _authenticate_plain, "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"), "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"), @@ -490,15 +519,16 @@ def _authenticate_default(credentials, sock_info): class _AuthContext(object): - def __init__(self, credentials): + def __init__(self, credentials, address): self.credentials = credentials self.speculative_authenticate = None + self.address = address @staticmethod - def from_credentials(creds): + def from_credentials(creds, address): spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism) if spec_cls: - return spec_cls(creds) + return spec_cls(creds, address) return None def speculate_command(self): @@ -512,8 +542,8 @@ def speculate_succeeded(self): class _ScramContext(_AuthContext): - def __init__(self, credentials, mechanism): - super(_ScramContext, self).__init__(credentials) + def __init__(self, credentials, address, mechanism): + super(_ScramContext, self).__init__(credentials, address) self.scram_data = None self.mechanism = mechanism @@ -534,16 +564,30 @@ def speculate_command(self): return cmd +class _OIDCContext(_AuthContext): + def speculate_command(self): + authenticator = _get_authenticator(self.credentials, self.address) + cmd = authenticator.auth_start_cmd(False) + if cmd is None: + return + cmd["db"] = self.credentials.source + return cmd + + _SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = { "MONGODB-X509": _X509Context, "SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"), "SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), + "MONGODB-OIDC": _OIDCContext, "DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), } -def authenticate(credentials, sock_info): +def authenticate(credentials, sock_info, reauthenticate=False): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] - auth_func(credentials, sock_info) + if mechanism == "MONGODB-OIDC": + _authenticate_oidc(credentials, sock_info, reauthenticate) + else: + auth_func(credentials, sock_info) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py new file mode 100644 index 0000000000..530b1bb068 --- /dev/null +++ b/pymongo/auth_oidc.py @@ -0,0 +1,299 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MONGODB-OIDC Authentication helpers.""" +import os +import threading +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Callable, Dict, List, Optional + +import bson +from bson.binary import Binary +from bson.son import SON +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.helpers import _REAUTHENTICATION_REQUIRED_CODE + + +@dataclass +class _OIDCProperties: + request_token_callback: Optional[Callable[..., Dict]] + refresh_token_callback: Optional[Callable[..., Dict]] + provider_name: Optional[str] + allowed_hosts: List[str] + + +"""Mechanism properties for MONGODB-OIDC authentication.""" + +TOKEN_BUFFER_MINUTES = 5 +CALLBACK_TIMEOUT_SECONDS = 5 * 60 +CACHE_TIMEOUT_MINUTES = 60 * 5 +CALLBACK_VERSION = 0 + +_CACHE: Dict[str, "_OIDCAuthenticator"] = {} + + +def _get_authenticator(credentials, address): + # Clear out old items in the cache. + now_utc = datetime.now(timezone.utc) + to_remove = [] + for key, value in _CACHE.items(): + if value.cache_exp_utc is not None and value.cache_exp_utc < now_utc: + to_remove.append(key) + for key in to_remove: + del _CACHE[key] + + # Extract values. + principal_name = credentials.username + properties = credentials.mechanism_properties + request_cb = properties.request_token_callback + refresh_cb = properties.refresh_token_callback + + # Validate that the address is allowed. + if not properties.provider_name: + found = False + allowed_hosts = properties.allowed_hosts + for patt in allowed_hosts: + if patt == address[0]: + found = True + elif patt.startswith("*.") and address[0].endswith(patt[1:]): + found = True + if not found: + raise ConfigurationError( + f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}" + ) + + # Get or create the cache item. + cache_key = f"{principal_name}{address[0]}{address[1]}{id(request_cb)}{id(refresh_cb)}" + _CACHE.setdefault(cache_key, _OIDCAuthenticator(username=principal_name, properties=properties)) + + return _CACHE[cache_key] + + +def _get_cache_exp(): + return datetime.now(timezone.utc) + timedelta(minutes=CACHE_TIMEOUT_MINUTES) + + +@dataclass +class _OIDCAuthenticator: + username: str + properties: _OIDCProperties + idp_info: Optional[Dict] = field(default=None) + idp_resp: Optional[Dict] = field(default=None) + reauth_gen_id: int = field(default=0) + idp_info_gen_id: int = field(default=0) + token_gen_id: int = field(default=0) + token_exp_utc: Optional[datetime] = field(default=None) + cache_exp_utc: datetime = field(default_factory=_get_cache_exp) + lock: threading.Lock = field(default_factory=threading.Lock) + + def get_current_token(self, use_callbacks=True): + properties = self.properties + + request_cb = properties.request_token_callback + refresh_cb = properties.refresh_token_callback + if not use_callbacks: + request_cb = None + refresh_cb = None + + current_valid_token = False + if self.token_exp_utc is not None: + now_utc = datetime.now(timezone.utc) + exp_utc = self.token_exp_utc + buffer_seconds = TOKEN_BUFFER_MINUTES * 60 + if (exp_utc - now_utc).total_seconds() >= buffer_seconds: + current_valid_token = True + + timeout = CALLBACK_TIMEOUT_SECONDS + + if not use_callbacks and not current_valid_token: + return None + + if not current_valid_token and request_cb is not None: + prev_token = self.idp_resp and self.idp_resp["access_token"] + with self.lock: + # See if the token was changed while we were waiting for the + # lock. + new_token = self.idp_resp and self.idp_resp["access_token"] + if new_token != prev_token: + return new_token + + refresh_token = self.idp_resp and self.idp_resp.get("refresh_token") + refresh_token = refresh_token or "" + context = dict( + timeout_seconds=timeout, + version=CALLBACK_VERSION, + refresh_token=refresh_token, + ) + + if self.idp_resp is None or refresh_cb is None: + self.idp_resp = request_cb(self.idp_info, context) + elif request_cb is not None: + self.idp_resp = refresh_cb(self.idp_info, context) + cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=CACHE_TIMEOUT_MINUTES + ) + self.cache_exp_utc = cache_exp_utc + self.token_gen_id += 1 + + token_result = self.idp_resp + + # Validate callback return value. + if not isinstance(token_result, dict): + raise ValueError("OIDC callback returned invalid result") + + if "access_token" not in token_result: + raise ValueError("OIDC callback did not return an access_token") + + expected = ["access_token", "expires_in_seconds", "refesh_token"] + for key in token_result: + if key not in expected: + raise ValueError(f'Unexpected field in callback result "{key}"') + + token = token_result["access_token"] + + if "expires_in_seconds" in token_result: + expires_in = int(token_result["expires_in_seconds"]) + buffer_seconds = TOKEN_BUFFER_MINUTES * 60 + if expires_in >= buffer_seconds: + now_utc = datetime.now(timezone.utc) + exp_utc = now_utc + timedelta(seconds=expires_in) + self.token_exp_utc = exp_utc + + return token + + def auth_start_cmd(self, use_callbacks=True): + properties = self.properties + + # Handle aws provider credentials. + if properties.provider_name == "aws": + aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] + with open(aws_identity_file) as fid: + token = fid.read().strip() + payload = dict(jwt=token) + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ] + ) + return cmd + + principal_name = self.username + + if self.idp_info is not None: + self.cache_exp_utc = datetime.now(timezone.utc) + timedelta( + minutes=CACHE_TIMEOUT_MINUTES + ) + + if self.idp_info is None: + self.cache_exp_utc = _get_cache_exp() + + if self.idp_info is None: + # Send the SASL start with the optional principal name. + payload = dict() + + if principal_name: + payload["n"] = principal_name + + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", Binary(bson.encode(payload))), + ("autoAuthorize", 1), + ] + ) + return cmd + + token = self.get_current_token(use_callbacks) + if not token: + return None + bin_payload = Binary(bson.encode(dict(jwt=token))) + return SON( + [ + ("saslStart", 1), + ("mechanism", "MONGODB-OIDC"), + ("payload", bin_payload), + ] + ) + + def clear(self): + self.idp_info = None + self.idp_resp = None + self.token_exp_utc = None + + def run_command(self, sock_info, cmd): + try: + return sock_info.command("$external", cmd, no_reauth=True) + except OperationFailure as exc: + self.clear() + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined] + if self.idp_info_gen_id > self.reauth_gen_id: + raise + return self.authenticate(sock_info, reauthenticate=True) + raise + + def authenticate(self, sock_info, reauthenticate=False): + if reauthenticate: + prev_id = getattr(sock_info, "oidc_token_gen_id", None) + # Check if we've already changed tokens. + if prev_id == self.token_gen_id: + self.reauth_gen_id = self.idp_info_gen_id + self.token_exp_utc = None + if not self.properties.refresh_token_callback: + self.clear() + + ctx = sock_info.auth_ctx + cmd = None + + if ctx and ctx.speculate_succeeded(): + resp = ctx.speculative_authenticate + else: + cmd = self.auth_start_cmd() + resp = self.run_command(sock_info, cmd) + + if resp["done"]: + sock_info.oidc_token_gen_id = self.token_gen_id + return + + server_resp: Dict = bson.decode(resp["payload"]) + if "issuer" in server_resp: + self.idp_info = server_resp + self.idp_info_gen_id += 1 + + conversation_id = resp["conversationId"] + token = self.get_current_token() + sock_info.oidc_token_gen_id = self.token_gen_id + bin_payload = Binary(bson.encode(dict(jwt=token))) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", conversation_id), + ("payload", bin_payload), + ] + ) + resp = self.run_command(sock_info, cmd) + if not resp["done"]: + self.clear() + raise OperationFailure("SASL conversation failed to complete.") + return resp + + +def _authenticate_oidc(credentials, sock_info, reauthenticate): + """Authenticate using MONGODB-OIDC.""" + authenticator = _get_authenticator(credentials, sock_info.address) + return authenticator.authenticate(sock_info, reauthenticate=reauthenticate) diff --git a/pymongo/common.py b/pymongo/common.py index 4b8aeb020c..4e39c8e514 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -16,6 +16,7 @@ """Functions and classes common to multiple pymongo modules.""" import datetime +import inspect import warnings from collections import OrderedDict, abc from typing import ( @@ -416,14 +417,48 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] _MECHANISM_PROPS = frozenset( - ["SERVICE_NAME", "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN"] + [ + "SERVICE_NAME", + "CANONICALIZE_HOST_NAME", + "SERVICE_REALM", + "AWS_SESSION_TOKEN", + "PROVIDER_NAME", + ] ) def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Union[bool, str]]: """Validate authMechanismProperties.""" - value = validate_string(option, value) props: Dict[str, Any] = {} + if not isinstance(value, str): + if not isinstance(value, dict): + raise ValueError("Auth mechanism properties must be given as a string or a dictionary") + for key, value in value.items(): + if isinstance(value, str): + props[key] = value + elif isinstance(value, bool): + props[key] = str(value).lower() + elif key in ["allowed_hosts"] and isinstance(value, list): + props[key] = value + elif inspect.isfunction(value): + signature = inspect.signature(value) + if key == "request_token_callback": + expected_params = 2 + elif key == "refresh_token_callback": + expected_params = 2 + else: + raise ValueError(f"Unrecognized Auth mechanism function {key}") + if len(signature.parameters) != expected_params: + msg = f"{key} must accept {expected_params} parameters" + raise ValueError(msg) + props[key] = value + else: + raise ValueError( + "Auth mechanism property values must be strings or callback functions" + ) + return props + + value = validate_string(option, value) for opt in value.split(","): try: key, val = opt.split(":") @@ -715,6 +750,7 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo "password": validate_string_or_none, "server_selector": validate_is_callable_or_none, "auto_encryption_opts": validate_auto_encryption_opts_or_none, + "authoidcallowedhosts": validate_list, } # Dictionary where keys are any URI option name, and values are the diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 31325c8af2..1a753c66f4 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -68,6 +68,9 @@ ] ) +# Server code raised when re-authentication is required +_REAUTHENTICATION_REQUIRED_CODE = 391 + def _gen_index_name(keys): """Generate an index name from the set of fields it is over.""" @@ -267,3 +270,35 @@ def _handle_exception(): pass finally: del einfo + + +def _handle_reauth(func): + def inner(*args, **kwargs): + no_reauth = kwargs.pop("no_reauth", False) + from pymongo.pool import SocketInfo + + try: + return func(*args, **kwargs) + except OperationFailure as exc: + if no_reauth: + raise + if exc.code == _REAUTHENTICATION_REQUIRED_CODE: + # Look for an argument that either is a SocketInfo + # or has a socket_info attribute, so we can trigger + # a reauth. + sock_info = None + for arg in args: + if isinstance(arg, SocketInfo): + sock_info = arg + break + if hasattr(arg, "sock_info"): + sock_info = arg.sock_info + break + if sock_info: + sock_info.authenticate(reauthenticate=True) + else: + raise + return func(*args, **kwargs) + raise + + return inner diff --git a/pymongo/message.py b/pymongo/message.py index f7a173ca8a..3510d210a5 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -54,6 +54,7 @@ ProtocolError, ) from pymongo.hello import HelloCompat +from pymongo.helpers import _handle_reauth from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern @@ -909,6 +910,7 @@ def unack_write(self, cmd, request_id, msg, max_doc_size, docs): self.start_time = datetime.datetime.now() return result + @_handle_reauth def write_command(self, cmd, request_id, msg, docs): """A proxy for SocketInfo.write_command that handles event publishing.""" if self.publish: diff --git a/pymongo/pool.py b/pymongo/pool.py index 42e6a642a4..6ba1554231 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -57,6 +57,7 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat +from pymongo.helpers import _handle_reauth from pymongo.lock import _create_lock from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason from pymongo.network import command, receive_message @@ -756,7 +757,7 @@ def _hello(self, cluster_time, topology_version, heartbeat_frequency): if creds: if creds.mechanism == "DEFAULT" and creds.username: cmd["saslSupportedMechs"] = creds.source + "." + creds.username - auth_ctx = auth._AuthContext.from_credentials(creds) + auth_ctx = auth._AuthContext.from_credentials(creds, self.address) if auth_ctx: cmd["speculativeAuthenticate"] = auth_ctx.speculate_command() else: @@ -813,6 +814,7 @@ def _next_reply(self): helpers._check_command_response(response_doc, self.max_wire_version) return response_doc + @_handle_reauth def command( self, dbname, @@ -966,17 +968,22 @@ def write_command(self, request_id, msg, codec_options): helpers._check_command_response(result, self.max_wire_version) return result - def authenticate(self): + def authenticate(self, reauthenticate=False): """Authenticate to the server if needed. Can raise ConnectionFailure or OperationFailure. """ # CMAP spec says to publish the ready event only after authenticating # the connection. + if reauthenticate: + if self.performed_handshake: + # Existing auth_ctx is stale, remove it. + self.auth_ctx = None + self.ready = False if not self.ready: creds = self.opts._credentials if creds: - auth.authenticate(creds, self) + auth.authenticate(creds, self, reauthenticate=reauthenticate) self.ready = True if self.enabled_for_cmap: self.listeners.publish_connection_ready(self.address, self.id) diff --git a/pymongo/server.py b/pymongo/server.py index f26f473c32..16c905abb7 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -18,7 +18,7 @@ from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers import _check_command_response +from pymongo.helpers import _check_command_response, _handle_reauth from pymongo.message import _convert_exception, _OpMsg from pymongo.response import PinnedResponse, Response @@ -73,6 +73,7 @@ def request_check(self): """Check the server's state soon.""" self._monitor.request_check() + @_handle_reauth def run_operation(self, sock_info, operation, read_preference, listeners, unpack_res): """Run a _Query or _GetMore operation and return a Response object. diff --git a/test/auth/connection-string.json b/test/auth/legacy/connection-string.json similarity index 76% rename from test/auth/connection-string.json rename to test/auth/legacy/connection-string.json index 2a37ae8df4..ca979010af 100644 --- a/test/auth/connection-string.json +++ b/test/auth/legacy/connection-string.json @@ -444,6 +444,133 @@ "AWS_SESSION_TOKEN": "token!@#$%^&*()_+" } } + }, + { + "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest", "oidcRefresh"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true, + "REFRESH_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": "principalName", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "valid": false, + "credential": null } ] -} +} \ No newline at end of file diff --git a/test/auth/unified/reauthenticate_with_retry.json b/test/auth/unified/reauthenticate_with_retry.json new file mode 100644 index 0000000000..ef110562ed --- /dev/null +++ b/test/auth/unified/reauthenticate_with_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_with_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": true, + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} diff --git a/test/auth/unified/reauthenticate_without_retry.json b/test/auth/unified/reauthenticate_without_retry.json new file mode 100644 index 0000000000..6fded47634 --- /dev/null +++ b/test/auth/unified/reauthenticate_without_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_without_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py new file mode 100644 index 0000000000..470e4581c2 --- /dev/null +++ b/test/auth_aws/test_auth_oidc.py @@ -0,0 +1,821 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test MONGODB-OIDC Authentication.""" + +import os +import sys +import threading +import time +import unittest +from contextlib import contextmanager +from typing import Dict + +sys.path[0:0] = [""] + +from test.utils import EventListener + +from bson import SON +from pymongo import MongoClient +from pymongo.auth_oidc import _CACHE as _oidc_cache +from pymongo.cursor import CursorType +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.hello import HelloCompat +from pymongo.operations import InsertOne + + +class TestAuthOIDC(unittest.TestCase): + uri: str + + @classmethod + def setUpClass(cls): + cls.uri_single = os.environ["MONGODB_URI_SINGLE"] + cls.uri_multiple = os.environ["MONGODB_URI_MULTIPLE"] + cls.uri_admin = os.environ["MONGODB_URI"] + cls.token_dir = os.environ["OIDC_TOKEN_DIR"] + + def setUp(self): + self.request_called = 0 + self.refresh_called = 0 + _oidc_cache.clear() + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user1") + + def create_request_cb(self, username="test_user1", expires_in_seconds=None, sleep=0): + + token_file = os.path.join(self.token_dir, username) + + def request_token(server_info, context): + # Validate the info. + self.assertIn("issuer", server_info) + self.assertIn("clientId", server_info) + + # Validate the timeout. + timeout_seconds = context["timeout_seconds"] + self.assertEqual(timeout_seconds, 60 * 5) + with open(token_file) as fid: + token = fid.read() + resp = dict(access_token=token) + + time.sleep(sleep) + + if expires_in_seconds is not None: + resp["expires_in_seconds"] = expires_in_seconds + self.request_called += 1 + return resp + + return request_token + + def create_refresh_cb(self, username="test_user1", expires_in_seconds=None): + + token_file = os.path.join(self.token_dir, username) + + def refresh_token(server_info, context): + with open(token_file) as fid: + token = fid.read() + + # Validate the info. + self.assertIn("issuer", server_info) + self.assertIn("clientId", server_info) + + # Validate the creds + self.assertIsNotNone(context["refresh_token"]) + + # Validate the timeout. + self.assertEqual(context["timeout_seconds"], 60 * 5) + + resp = dict(access_token=token) + if expires_in_seconds is not None: + resp["expires_in_seconds"] = expires_in_seconds + self.refresh_called += 1 + return resp + + return refresh_token + + @contextmanager + def fail_point(self, command_args): + cmd_on = SON([("configureFailPoint", "failCommand")]) + cmd_on.update(command_args) + client = MongoClient(self.uri_admin) + client.admin.command(cmd_on) + try: + yield + finally: + client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + + def test_connect_callbacks_single_implicit_username(self): + request_token = self.create_request_cb() + props: Dict = dict(request_token_callback=request_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_connect_callbacks_single_explicit_username(self): + request_token = self.create_request_cb() + props: Dict = dict(request_token_callback=request_token) + client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_connect_callbacks_multiple_principal_user1(self): + request_token = self.create_request_cb() + props: Dict = dict(request_token_callback=request_token) + client = MongoClient( + self.uri_multiple, username="test_user1", authmechanismproperties=props + ) + client.test.test.find_one() + client.close() + + def test_connect_callbacks_multiple_principal_user2(self): + request_token = self.create_request_cb("test_user2") + props: Dict = dict(request_token_callback=request_token) + client = MongoClient( + self.uri_multiple, username="test_user2", authmechanismproperties=props + ) + client.test.test.find_one() + client.close() + + def test_connect_callbacks_multiple_no_username(self): + request_token = self.create_request_cb() + props: Dict = dict(request_token_callback=request_token) + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + with self.assertRaises(OperationFailure): + client.test.test.find_one() + client.close() + + def test_allowed_hosts_blocked(self): + request_token = self.create_request_cb() + props: Dict = dict(request_token_callback=request_token, allowed_hosts=[]) + client = MongoClient(self.uri_single, authmechanismproperties=props) + with self.assertRaises(ConfigurationError): + client.test.test.find_one() + client.close() + + props: Dict = dict(request_token_callback=request_token, allowed_hosts=["example.com"]) + client = MongoClient( + self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False + ) + with self.assertRaises(ConfigurationError): + client.test.test.find_one() + client.close() + + def test_connect_aws_single_principal(self): + props = dict(PROVIDER_NAME="aws") + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_connect_aws_multiple_principal_user1(self): + props = dict(PROVIDER_NAME="aws") + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_connect_aws_multiple_principal_user2(self): + os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") + props = dict(PROVIDER_NAME="aws") + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_connect_aws_allowed_hosts_ignored(self): + props = dict(PROVIDER_NAME="aws", allowed_hosts=[]) + client = MongoClient(self.uri_multiple, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_valid_callbacks(self): + request_cb = self.create_request_cb(expires_in_seconds=60) + refresh_cb = self.create_refresh_cb() + + props: Dict = dict( + request_token_callback=request_cb, + refresh_token_callback=refresh_cb, + ) + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + def test_lock_avoids_extra_callbacks(self): + request_cb = self.create_request_cb(sleep=0.5) + refresh_cb = self.create_refresh_cb() + + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + + def run_test(): + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + t1 = threading.Thread(target=run_test) + t2 = threading.Thread(target=run_test) + t1.start() + t2.start() + t1.join() + t2.join() + + self.assertEqual(self.request_called, 1) + self.assertEqual(self.refresh_called, 2) + + def test_request_callback_returns_null(self): + def request_token_null(a, b): + return None + + props: Dict = dict(request_token_callback=request_token_null) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_refresh_callback_returns_null(self): + request_cb = self.create_request_cb(expires_in_seconds=60) + + def refresh_token_null(a, b): + return None + + props: Dict = dict( + request_token_callback=request_cb, refresh_token_callback=refresh_token_null + ) + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_request_callback_invalid_result(self): + def request_token_invalid(a, b): + return dict() + + props: Dict = dict(request_token_callback=request_token_invalid) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def request_cb_extra_value(server_info, context): + result = self.create_request_cb()(server_info, context) + result["foo"] = "bar" + return result + + props: Dict = dict(request_token_callback=request_cb_extra_value) + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_refresh_callback_missing_data(self): + request_cb = self.create_request_cb(expires_in_seconds=60) + + def refresh_cb_no_token(a, b): + return dict() + + props: Dict = dict( + request_token_callback=request_cb, refresh_token_callback=refresh_cb_no_token + ) + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_refresh_callback_extra_data(self): + request_cb = self.create_request_cb(expires_in_seconds=60) + + def refresh_cb_extra_value(server_info, context): + result = self.create_refresh_cb()(server_info, context) + result["foo"] = "bar" + return result + + props: Dict = dict( + request_token_callback=request_cb, refresh_token_callback=refresh_cb_extra_value + ) + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + client = MongoClient(self.uri_single, authMechanismProperties=props) + with self.assertRaises(ValueError): + client.test.test.find_one() + client.close() + + def test_cache_with_refresh(self): + # Create a new client with a request callback and a refresh callback. Both callbacks will read the contents of the ``AWS_WEB_IDENTITY_TOKEN_FILE`` location to obtain a valid access token. + + # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. + request_cb = self.create_request_cb(expires_in_seconds=60) + refresh_cb = self.create_refresh_cb() + + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + + # Ensure that a ``find`` operation adds credentials to the cache. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + self.assertEqual(len(_oidc_cache), 1) + + # Create a new client with the same request callback and a refresh callback. + # Ensure that a ``find`` operation results in a call to the refresh callback. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + self.assertEqual(self.refresh_called, 1) + self.assertEqual(len(_oidc_cache), 1) + + def test_cache_with_no_refresh(self): + # Create a new client with a request callback callback. + # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. + request_cb = self.create_request_cb() + + props = dict(request_token_callback=request_cb) + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # Ensure that a ``find`` operation adds credentials to the cache. + self.request_called = 0 + client.test.test.find_one() + client.close() + self.assertEqual(self.request_called, 1) + self.assertEqual(len(_oidc_cache), 1) + + # Create a new client with the same request callback. + # Ensure that a ``find`` operation results in a call to the request callback. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + self.assertEqual(self.request_called, 2) + self.assertEqual(len(_oidc_cache), 1) + + def test_cache_key_includes_callback(self): + request_cb = self.create_request_cb() + + props: Dict = dict(request_token_callback=request_cb) + + # Ensure that a ``find`` operation adds a new entry to the cache. + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + # Create a new client with a different request callback. + def request_token_2(a, b): + return request_cb(a, b) + + props["request_token_callback"] = request_token_2 + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # Ensure that a ``find`` operation adds a new entry to the cache. + client.test.test.find_one() + client.close() + self.assertEqual(len(_oidc_cache), 2) + + def test_cache_clears_on_error(self): + request_cb = self.create_request_cb() + + # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. + def refresh_cb(a, b): + return dict(access_token="bad") + + # Add a token to the cache that will expire soon. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authMechanismProperties=props) + client.test.test.find_one() + client.close() + + # Create a new client with the same callbacks. + client = MongoClient(self.uri_single, authMechanismProperties=props) + + # Ensure that another ``find`` operation results in an error. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + + client.close() + + # Ensure that the cache has been cleared. + authenticator = list(_oidc_cache.values())[0] + self.assertIsNone(authenticator.idp_info) + + def test_cache_is_not_used_in_aws_automatic_workflow(self): + # Create a new client using the AWS device workflow. + # Ensure that a ``find`` operation does not add credentials to the cache. + props = dict(PROVIDER_NAME="aws") + client = MongoClient(self.uri_single, authmechanismproperties=props) + client.test.test.find_one() + client.close() + + # Ensure that the cache has been cleared. + authenticator = list(_oidc_cache.values())[0] + self.assertIsNone(authenticator.idp_info) + + def test_speculative_auth_success(self): + # Clear the cache + _oidc_cache.clear() + token_file = os.path.join(self.token_dir, "test_user1") + + def request_token(a, b): + with open(token_file) as fid: + token = fid.read() + return dict(access_token=token, expires_in_seconds=1000) + + # Create a client with a request callback that returns a valid token + # that will not expire soon. + props: Dict = dict(request_token_callback=request_token) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Close the client. + client.close() + + # Create a new client. + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Set a fail point for saslStart commands. + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["saslStart"], "errorCode": 18}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + # Close the client. + client.close() + + def test_reauthenticate_succeeds(self): + listener = EventListener() + + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + listener.reset() + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation. + client.test.test.find_one() + + started_events = [ + i.command_name for i in listener.started_events if not i.command_name.startswith("sasl") + ] + succeeded_events = [ + i.command_name + for i in listener.succeeded_events + if not i.command_name.startswith("sasl") + ] + failed_events = [ + i.command_name for i in listener.failed_events if not i.command_name.startswith("sasl") + ] + + self.assertEqual( + started_events, + [ + "find", + "find", + ], + ) + self.assertEqual(succeeded_events, ["find"]) + self.assertEqual(failed_events, ["find"]) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_bulk_write(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["insert"], "errorCode": 391}, + } + ): + # Perform a bulk write operation. + client.test.test.bulk_write([InsertOne({})]) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_bulk_read(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform a find operation. + client.test.test.find_one() + + # Perform a bulk write operation. + client.test.test.bulk_write([InsertOne({})]) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a bulk read operation. + cursor = client.test.test.find_raw_batches({}) + list(cursor) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_cursor(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_one({"a": 1}) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}) + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_get_more(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_many([{"a": 1}, {"a": 1}]) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["getMore"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}, batch_size=1) + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_get_more_exhaust(self): + # Ensure no mongos + props = dict(PROVIDER_NAME="aws") + client = MongoClient(self.uri_single, authmechanismproperties=props) + hello = client.admin.command(HelloCompat.LEGACY_CMD) + if hello.get("msg") != "isdbgrid": + raise unittest.SkipTest("Must not be a mongos") + + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_many([{"a": 1}, {"a": 1}]) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["getMore"], "errorCode": 391}, + } + ): + # Perform a find operation. + cursor = client.test.test.find({"a": 1}, batch_size=1, cursor_type=CursorType.EXHAUST) + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_succeeds_command(self): + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + + print("start of test") + client = MongoClient(self.uri_single, authmechanismproperties=props) + + # Perform an insert operation. + client.test.test.insert_one({"a": 1}) + + # Assert that the refresh callback has not been called. + self.assertEqual(self.refresh_called, 0) + + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["count"], "errorCode": 391}, + } + ): + # Perform a count operation. + cursor = client.test.command(dict(count="test")) + + self.assertGreaterEqual(len(list(cursor)), 1) + + # Assert that the refresh callback has been called. + self.assertEqual(self.refresh_called, 1) + client.close() + + def test_reauthenticate_retries_and_succeeds_with_cache(self): + listener = EventListener() + + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) + + # Perform a find operation. + client.test.test.find_one() + + # Set a fail point for ``saslStart`` commands of the form + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } + ): + # Perform a find operation that succeeds. + client.test.test.find_one() + + # Close the client. + client.close() + + def test_reauthenticate_fails_with_no_cache(self): + listener = EventListener() + + # Create request and refresh callbacks that return valid credentials + # that will not expire soon. + request_cb = self.create_request_cb() + refresh_cb = self.create_refresh_cb() + + # Create a client with the callbacks. + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client = MongoClient( + self.uri_single, event_listeners=[listener], authmechanismproperties=props + ) + + # Perform a find operation. + client.test.test.find_one() + + # Clear the cache. + _oidc_cache.clear() + + with self.fail_point( + { + "mode": {"times": 2}, + "data": {"failCommands": ["find", "saslStart"], "errorCode": 391}, + } + ): + # Perform a find operation that fails. + with self.assertRaises(OperationFailure): + client.test.test.find_one() + + client.close() + + def test_late_reauth_avoids_callback(self): + # Step 1: connect with both clients + request_cb = self.create_request_cb(expires_in_seconds=1e6) + refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6) + + props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + client1 = MongoClient(self.uri_single, authMechanismProperties=props) + client1.test.test.find_one() + client2 = MongoClient(self.uri_single, authMechanismProperties=props) + client2.test.test.find_one() + + self.assertEqual(self.refresh_called, 0) + self.assertEqual(self.request_called, 1) + + # Step 2: cause a find 391 on the first client + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation that succeeds. + client1.test.test.find_one() + + self.assertEqual(self.refresh_called, 1) + self.assertEqual(self.request_called, 1) + + # Step 3: cause a find 391 on the second client + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Perform a find operation that succeeds. + client2.test.test.find_one() + + self.assertEqual(self.refresh_called, 1) + self.assertEqual(self.request_called, 1) + + client1.close() + client2.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 9f2fa374ac..78f4d21929 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -22,6 +22,7 @@ sys.path[0:0] = [""] from test import unittest +from test.unified_format import generate_test_classes from pymongo import MongoClient @@ -41,7 +42,16 @@ def run_test(self): if not valid: self.assertRaises(Exception, MongoClient, uri, connect=False) else: - client = MongoClient(uri, connect=False) + props = {} + if credential: + props = credential["mechanism_properties"] or {} + if props.get("REQUEST_TOKEN_CALLBACK"): + props["request_token_callback"] = lambda x, y: 1 + del props["REQUEST_TOKEN_CALLBACK"] + if props.get("REFRESH_TOKEN_CALLBACK"): + props["refresh_token_callback"] = lambda a, b: 1 + del props["REFRESH_TOKEN_CALLBACK"] + client = MongoClient(uri, connect=False, authmechanismproperties=props) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) @@ -70,6 +80,16 @@ def run_test(self): self.assertEqual( actual.aws_session_token, expected["AWS_SESSION_TOKEN"] ) + elif "PROVIDER_NAME" in expected: + self.assertEqual(actual.provider_name, expected["PROVIDER_NAME"]) + elif "request_token_callback" in expected: + self.assertEqual( + actual.request_token_callback, expected["request_token_callback"] + ) + elif "refresh_token_callback" in expected: + self.assertEqual( + actual.refresh_token_callback, expected["refresh_token_callback"] + ) else: self.fail("Unhandled property: %s" % (key,)) else: @@ -82,7 +102,7 @@ def run_test(self): def create_tests(): - for filename in glob.glob(os.path.join(_TEST_PATH, "*.json")): + for filename in glob.glob(os.path.join(_TEST_PATH, "legacy", "*.json")): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with open(filename) as auth_tests: test_cases = json.load(auth_tests)["tests"] @@ -97,5 +117,12 @@ def create_tests(): create_tests() +globals().update( + generate_test_classes( + os.path.join(_TEST_PATH, "unified"), + module=__name__, + ) +) + if __name__ == "__main__": unittest.main()