diff --git a/src/aap_eda/api/event_stream_authentication.py b/src/aap_eda/api/event_stream_authentication.py index 1c511e1b8..8c84f8e59 100644 --- a/src/aap_eda/api/event_stream_authentication.py +++ b/src/aap_eda/api/event_stream_authentication.py @@ -32,6 +32,7 @@ from rest_framework.exceptions import AuthenticationFailed from aap_eda.core.enums import SignatureEncodingType +from aap_eda.core.utils.credentials import validate_x509_subject_match logger = logging.getLogger(__name__) DEFAULT_TIMEOUT = 30 @@ -110,11 +111,29 @@ class MTLSAuthentication(EventStreamAuthentication): def authenticate(self, _body=None): """Handle mTLS authentication.""" - if self.subject and self.subject != self.value: - message = f"Subject Name mismatch : {self.value}" + if self.subject and not self.validate_subject( + self.subject, self.value + ): + message = f"Subject: {self.value} does not match {self.subject}" logger.warning(message) raise AuthenticationFailed(message) + def validate_subject(self, expected: str, actual: str) -> bool: + """Validate that actual subject matches expected subject pattern. + + Uses shared X.509 standard-compliant DN parsing for attribute-level + matching. Supports wildcards and is order-independent per X.509 + standards. + + Args: + expected: Official subject pattern (may contain * wildcards) + actual: Input subject from user + + Returns: + bool: True if actual matches expected pattern, False otherwise + """ + return validate_x509_subject_match(expected, actual) + @dataclass class BasicAuthentication(EventStreamAuthentication): diff --git a/src/aap_eda/api/exceptions.py b/src/aap_eda/api/exceptions.py index 82f4fbccb..1cdf03550 100644 --- a/src/aap_eda/api/exceptions.py +++ b/src/aap_eda/api/exceptions.py @@ -35,6 +35,8 @@ "Conflict", "Unprocessable", "PermissionDenied", + "GatewayAPIError", + "MissingCredentialsError", "api_fallback_handler", ) @@ -125,3 +127,20 @@ class ExternalSMSError(APIException): default_detail = ( "External SMS Error: not able to fetch secrets from external SMS" ) + + +class GatewayAPIError(APIException): + status_code = status.HTTP_502_BAD_GATEWAY + default_code = "gateway_api_error" + default_detail = _( + "Gateway API Error: Unable to communicate with the Gateway service" + ) + + +class MissingCredentialsError(APIException): + status_code = status.HTTP_400_BAD_REQUEST + default_code = "missing_credentials" + default_detail = _( + "Missing Credentials: Required credentials are not available " + "for Gateway operations" + ) diff --git a/src/aap_eda/api/views/eda_credential.py b/src/aap_eda/api/views/eda_credential.py index 6fd8f7df6..7052473c6 100644 --- a/src/aap_eda/api/views/eda_credential.py +++ b/src/aap_eda/api/views/eda_credential.py @@ -204,6 +204,7 @@ def partial_update(self, request, pk): setattr(eda_credential, key, value) with transaction.atomic(): + eda_credential._request = request eda_credential.save() check_related_permissions( request.user, diff --git a/src/aap_eda/api/views/event_stream.py b/src/aap_eda/api/views/event_stream.py index 993f6488b..d43ce92ec 100644 --- a/src/aap_eda/api/views/event_stream.py +++ b/src/aap_eda/api/views/event_stream.py @@ -30,8 +30,13 @@ from aap_eda.api import exceptions as api_exc, filters, serializers from aap_eda.core import models -from aap_eda.core.enums import ResourceType +from aap_eda.core.enums import EventStreamAuthType, ResourceType +from aap_eda.core.exceptions import ( + GatewayAPIError as CoreGatewayAPIError, + MissingCredentials as CoreMissingCredentials, +) from aap_eda.core.utils import logging_utils +from aap_eda.services.sync_certs import SyncCertificates logger = logging.getLogger(__name__) @@ -99,7 +104,13 @@ def retrieve(self, request, *args, **kwargs): responses={ status.HTTP_204_NO_CONTENT: OpenApiResponse( None, description="Delete successful." - ) + ), + status.HTTP_400_BAD_REQUEST: OpenApiResponse( + description="Missing credentials for certificate deletion." + ), + status.HTTP_502_BAD_GATEWAY: OpenApiResponse( + description="Gateway API error during certificate deletion." + ), }, ) def destroy(self, request, *args, **kwargs): @@ -110,6 +121,7 @@ def destroy(self, request, *args, **kwargs): f"Event stream '{event_stream.name}' is being referenced by " f"{ref_count} activation(s) and cannot be deleted" ) + self._sync_certificates(event_stream, "destroy") self.perform_destroy(event_stream) logger.info( @@ -160,7 +172,10 @@ def list(self, request, *args, **kwargs): description="Return the new event stream.", ), status.HTTP_400_BAD_REQUEST: OpenApiResponse( - description="Invalid data to create event stream." + description="Invalid data or missing credentials." + ), + status.HTTP_502_BAD_GATEWAY: OpenApiResponse( + description="Gateway API error during certificate sync." ), }, ) @@ -182,6 +197,7 @@ def create(self, request, *args, **kwargs): RoleDefinition.objects.give_creator_permissions( request.user, serializer.instance ) + self._sync_certificates(response, "create") logger.info( logging_utils.generate_simple_audit_log( @@ -206,12 +222,17 @@ def create(self, request, *args, **kwargs): description="Update successful, return the new event stream.", ), status.HTTP_400_BAD_REQUEST: OpenApiResponse( - description="Unable to update event stream." + description="Update failed or missing credentials." + ), + status.HTTP_502_BAD_GATEWAY: OpenApiResponse( + description="Gateway API error during certificate sync." ), }, ) def partial_update(self, request, *args, **kwargs): event_stream = self.get_object() + new_eda_credential_id = request.data.get("eda_credential_id") + old_data = model_to_dict(event_stream) context = {"request": request} serializer = serializers.EventStreamInSerializer( @@ -233,6 +254,13 @@ def partial_update(self, request, *args, **kwargs): setattr(event_stream, key, value) with transaction.atomic(): + # Check if we need to destroy old certificates before saving + if ( + new_eda_credential_id + and event_stream.eda_credential.id != new_eda_credential_id + ): + self._sync_certificates(event_stream, "destroy") + event_stream.save() check_related_permissions( request.user, @@ -241,6 +269,9 @@ def partial_update(self, request, *args, **kwargs): model_to_dict(event_stream), ) + if new_eda_credential_id: + self._sync_certificates(event_stream, "update") + logger.info( logging_utils.generate_simple_audit_log( "Update", @@ -307,3 +338,35 @@ def activations(self, request, id): ) ) return self.get_paginated_response(serializer.data) + + def _sync_certificates( + self, + event_stream: models.EventStream, + action: str, + ) -> None: + if ( + event_stream.eda_credential.credential_type.kind + == EventStreamAuthType.MTLS + ): + try: + obj = SyncCertificates(event_stream.eda_credential.id) + if action == "destroy": + obj.delete(event_stream.id) + else: + obj.update() + except CoreGatewayAPIError as ex: + logger.error("Could not %s certificates: %s", action, str(ex)) + raise api_exc.GatewayAPIError( + detail=f"Gateway API error during certificate {action}: " + f"{str(ex)}" + ) + except CoreMissingCredentials as ex: + logger.error( + "Missing credentials for certificate %s: %s", + action, + str(ex), + ) + raise api_exc.MissingCredentialsError( + detail=f"Missing credentials for certificate {action}: " + f"{str(ex)}" + ) diff --git a/src/aap_eda/core/exceptions.py b/src/aap_eda/core/exceptions.py index 4599f4c4e..0306cd004 100644 --- a/src/aap_eda/core/exceptions.py +++ b/src/aap_eda/core/exceptions.py @@ -55,3 +55,11 @@ class CredentialPluginError(Exception): class UnknownPluginTypeError(Exception): pass + + +class GatewayAPIError(Exception): + pass + + +class MissingCredentials(Exception): + pass diff --git a/src/aap_eda/core/management/commands/create_initial_data.py b/src/aap_eda/core/management/commands/create_initial_data.py index fe3ca8c97..8bbd8df76 100644 --- a/src/aap_eda/core/management/commands/create_initial_data.py +++ b/src/aap_eda/core/management/commands/create_initial_data.py @@ -48,7 +48,7 @@ AUTH_TYPE_LABEL = "Event Stream Authentication Type" SIGNATURE_ENCODING_LABEL = "Signature Encoding" HTTP_HEADER_LABEL = "HTTP Header Key" -DEPRECATED_CREDENTIAL_KINDS = ["mtls"] +DEPRECATED_CREDENTIAL_KINDS = [] LABEL_PATH_TO_AUTH = "Path to Auth" LABEL_CLIENT_CERTIFICATE = "Client Certificate" LABEL_CLIENT_SECRET = "Client Secret" @@ -1762,6 +1762,55 @@ "required": ["app_or_client_id", "install_id", "private_rsa_key"], } +EVENT_STREAM_MTLS_INPUTS = { + "fields": [ + { + "id": "auth_type", + "label": AUTH_TYPE_LABEL, + "type": "string", + "default": "mtls", + "hidden": True, + }, + { + "id": "certificate", + "label": "Certificate", + "type": "string", + "multiline": True, + "format": "pem_certificate", + "help_text": ( + "The Certificate collection in PEM format. You can have " + "multiple certificates in this field separated by " + "-----BEGIN CERTIFICATE----- " + "and ending in -----END CERTIFICATE-----" + "If a certificate is provided it will be transferred " + "to the Gateway, otherwise its assumed that the Gateway " + "already has the CA certificates in place to validate " + "the incoming client certificate." + ), + }, + { + "id": "subject", + "label": "Certificate Subject", + "type": "string", + "help_text": ( + "The Subject from Certificate compliant with RFC 2253." + "This is optional and can be used to check the subject " + "defined in the certificate. It can contains regular " + "expression to match indivisual attributes in the subject " + "name. E.g., CN=[agent1,agent2].example.com,ST=[NJ|NY]" + ), + }, + { + "id": "http_header_key", + "label": HTTP_HEADER_LABEL, + "type": "string", + "default": "Subject", + "hidden": True, + }, + ], + "required": ["auth_type", "http_header_key"], +} + CREDENTIAL_TYPES = [ { "name": enums.DefaultCredentialType.SOURCE_CONTROL, @@ -2045,6 +2094,21 @@ "injectors": {}, "managed": True, }, + { + "name": enums.EventStreamCredentialType.MTLS, + "namespace": "event_stream", + "kind": "mtls", + "inputs": EVENT_STREAM_MTLS_INPUTS, + "injectors": {}, + "managed": True, + "description": ( + "Credential for Event Streams that use mutual TLS. " + "If CA Certificates are defined in the UI it will " + "be transferred to the Gateway proxy for validation " + "of incoming requests. We can optionally validate the " + "Subject defined in the inbound Certificate." + ), + }, ] diff --git a/src/aap_eda/core/utils/credentials.py b/src/aap_eda/core/utils/credentials.py index 6a6f0a8b7..757968641 100644 --- a/src/aap_eda/core/utils/credentials.py +++ b/src/aap_eda/core/utils/credentials.py @@ -15,11 +15,16 @@ import re import tempfile import typing +from datetime import datetime, timezone +from typing import Any, Optional import gnupg import jinja2 import validators import yaml +from cryptography import x509 +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.x509.oid import NameOID from django.core.exceptions import ValidationError from django.forms import model_to_dict from django.utils.translation import gettext_lazy as _ @@ -212,6 +217,16 @@ def validate_inputs( if bool(result): errors[display_field] = result + # Special validation for mTLS certificate and subject + if credential_type.name == enums.EventStreamCredentialType.MTLS: + # Also validate subject format if provided + if field == "subject" and user_input: + subject_errors = _validate_certificate_subject_format( + user_input + ) + if bool(subject_errors): + errors["inputs.subject"] = subject_errors + if data.get("type") == "boolean": if user_input and not isinstance(user_input, bool): errors[display_field] = ["Must be a boolean"] @@ -457,6 +472,9 @@ def _validate_format( elif data_type == "ssh_private_key": return _validate_ssh_key(schema, data, inputs) + elif data_type == "pem_certificate": + return _validate_pem_certificate(data) + return errors @@ -557,7 +575,7 @@ def _validate_file_template_key(key: str, key_names: list[str]) -> None: ) -def check_reserved_keys_in_extra_vars(data: dict[str, any]) -> None: +def check_reserved_keys_in_extra_vars(data: dict[str, Any]) -> None: for key in data.keys(): if key in RESERVED_EXTRA_VAR_KEYS: raise ValidationError( @@ -572,6 +590,40 @@ def check_reserved_keys_in_extra_vars(data: dict[str, any]) -> None: ) +def _validate_pem_certificate(cert_data: str) -> list[str]: + """Validate PEM certificate format and content. + + Args: + cert_data: Certificate data in PEM format + + Returns: + List of error messages, empty if valid + """ + errors = [] + + if not cert_data or not cert_data.strip(): + errors.append("Certificate data cannot be empty") + return errors + + certs = [] + try: + # Try to load as single certificate first + certs = x509.load_pem_x509_certificates(cert_data.encode("utf-8")) + except (ValueError, UnsupportedAlgorithm) as e: + errors.append(f"Invalid PEM certificate format: {str(e)}") + + # Validate certificate is not expired + for cert in certs: + now = datetime.now(timezone.utc) + + if now > cert.not_valid_after_utc: + errors.append( + f"Certificate has expired: {cert.not_valid_after_utc}" + ) + + return errors + + def build_copy_post_data( eda_credential: "models.EdaCredential", new_cred_name: str ) -> dict: @@ -600,7 +652,7 @@ def build_copy_post_data( return post_data -def _add_file_template_keys(context: dict, files: dict): +def _add_file_template_keys(context: dict, files: dict) -> None: for key in files.keys(): parts = key.split(".") # case key == "template" @@ -644,3 +696,242 @@ def get_resolved_secrets(obj: "models.EdaCredential") -> dict: for key, value in external_secrets.items(): result[key] = value return result + + +# X.509 Subject Matching Functions + + +def _validate_certificate_subject_format(subject_string: str) -> list[str]: + """Validate certificate subject format using X.509 standard parsing. + + Args: + subject_string: Subject DN string (example: "CN=example.com,O=Test") + + Returns: + List of error messages, empty if valid + """ + if not subject_string or not subject_string.strip(): + return [] # Empty subject is allowed + + errors = [] + + try: + # Normalize whitespace using shared logic + normalized_subject = _normalize_x509_dn_whitespace(subject_string) + + # Parse using X.509 standard library + parsed_name = x509.Name.from_rfc4514_string(normalized_subject) + + # Additional validations using parsed X.509 Name + errors.extend(_validate_x509_name_constraints(parsed_name)) + + except ValueError as e: + errors.append(f"Invalid X.509 DN format: {str(e) or subject_string}") + except Exception as e: + errors.append(f"Failed to parse subject DN: {str(e)}") + + return errors + + +def _validate_x509_name_constraints(parsed_name: Any) -> list[str]: + """Validate X.509 Name object constraints. + + Args: + parsed_name: Parsed X.509 Name object + + Returns: + List of error messages, empty if valid + """ + errors = [] + + # Check for duplicate attributes (except DC which can be repeated) + attr_counts = {} + for attr in parsed_name: + oid_name = ( + attr.oid._name if hasattr(attr.oid, "_name") else str(attr.oid) + ) + + # Allow multiple DC (Domain Component) attributes + if attr.oid == NameOID.DOMAIN_COMPONENT: + continue + + attr_counts[oid_name] = attr_counts.get(oid_name, 0) + 1 + + duplicates = [name for name, count in attr_counts.items() if count > 1] + if duplicates: + errors.append(f"Duplicate attributes not allowed: {duplicates}") + + # Validate country code length + for attr in parsed_name: + if attr.oid == NameOID.COUNTRY_NAME and len(attr.value) != 2: + errors.append(f"Country code must be 2 characters: '{attr.value}'") + + return errors + + +def validate_x509_subject_match(expected: str, actual: str) -> bool: + """Validate that actual X.509 subject matches expected subject pattern. + + Uses X.509 standard-compliant DN parsing for attribute-level matching. + Supports regex patterns and is order-independent per X.509 standards. + + Args: + expected: Expected subject pattern (may contain regex patterns) + actual: Actual subject from certificate + + Returns: + bool: True if actual matches expected pattern, False otherwise + """ + if not expected or not actual: + return False + + # Normalize whitespace in DN strings + expected = _normalize_x509_dn_whitespace(expected) + actual = _normalize_x509_dn_whitespace(actual) + + # Parse actual DN into X.509 Name object + try: + actual_name = x509.Name.from_rfc4514_string(actual) + except ValueError as e: + LOGGER.error(f"Invalid actual DN format: '{actual}': {e}") + return False + + # Parse expected DN manually to handle regex patterns + expected_attrs = _parse_dn_with_regex_patterns(expected) + if not expected_attrs: + LOGGER.error(f"Failed to parse expected DN: '{expected}'") + return False + + # Check each expected attribute against actual certificate + for attr_name, pattern_value in expected_attrs: + # Convert attribute name to OID + attr_oid = _get_oid_from_name(attr_name) + if not attr_oid: + LOGGER.error(f"Unknown attribute name: '{attr_name}'") + return False + + actual_attrs = actual_name.get_attributes_for_oid(attr_oid) + if not actual_attrs: + LOGGER.error( + f"Required attribute {attr_name} not found in actual subject" + ) + return False + + # Check if any actual attribute matches the expected pattern + if not _match_regex_pattern_against_attrs(pattern_value, actual_attrs): + return False + + return True + + +def _normalize_x509_dn_whitespace(dn: str) -> str: + """Normalize whitespace in DN string components. + + Args: + dn: Distinguished Name string + + Returns: + str: DN with normalized whitespace + """ + if not dn: + return dn + + # Split by comma and normalize each component + components = [] + for component in dn.split(","): + component = component.strip() + if "=" in component: + attr, value = component.split("=", 1) + attr = attr.strip() + value = value.strip() + components.append(f"{attr}={value}") + else: + components.append(component) + + return ",".join(components) + + +def _parse_dn_with_regex_patterns(dn: str) -> list[tuple[str, str]]: + """Parse DN string that may contain regex patterns in attribute values. + + Args: + dn: Distinguished Name string with potential regex patterns + + Returns: + List of (attribute_name, pattern_value) tuples + """ + if not dn: + return [] + + attrs = [] + # Split by comma, but be careful of commas inside quoted values + components = dn.split(",") + + for component in components: + component = component.strip() + if "=" not in component: + continue + + attr_name, value = component.split("=", 1) + attr_name = attr_name.strip().upper() + value = value.strip() + + attrs.append((attr_name, value)) + + return attrs + + +def _get_oid_from_name(attr_name: str) -> Optional[Any]: + """Convert attribute name to OID. + + Args: + attr_name: Attribute name (e.g., 'CN', 'O', 'OU') + + Returns: + OID object or None if unknown + """ + name_to_oid = { + "CN": NameOID.COMMON_NAME, + "COMMONNAME": NameOID.COMMON_NAME, + "O": NameOID.ORGANIZATION_NAME, + "ORGANIZATIONNAME": NameOID.ORGANIZATION_NAME, + "OU": NameOID.ORGANIZATIONAL_UNIT_NAME, + "ORGANIZATIONALUNITNAME": NameOID.ORGANIZATIONAL_UNIT_NAME, + "C": NameOID.COUNTRY_NAME, + "COUNTRYNAME": NameOID.COUNTRY_NAME, + "L": NameOID.LOCALITY_NAME, + "LOCALITYNAME": NameOID.LOCALITY_NAME, + "ST": NameOID.STATE_OR_PROVINCE_NAME, + "STATEORPROVINCENAME": NameOID.STATE_OR_PROVINCE_NAME, + "STREET": NameOID.STREET_ADDRESS, + "STREETADDRESS": NameOID.STREET_ADDRESS, + "DC": NameOID.DOMAIN_COMPONENT, + "DOMAINCOMPONENT": NameOID.DOMAIN_COMPONENT, + "EMAIL": NameOID.EMAIL_ADDRESS, + "EMAILADDRESS": NameOID.EMAIL_ADDRESS, + } + + return name_to_oid.get(attr_name.upper()) + + +def _match_regex_pattern_against_attrs( + pattern: str, actual_attrs: list[Any] +) -> bool: + """Match regex pattern against list of actual attributes. + + Args: + pattern: Regex pattern string + actual_attrs: List of actual X.509 attributes + + Returns: + bool: True if pattern matches any actual attribute + """ + try: + regex_pattern = f"^{pattern}$" + return any( + re.match(regex_pattern, attr.value, re.IGNORECASE) + for attr in actual_attrs + ) + except re.error as e: + LOGGER.error(f"Invalid regex pattern '{pattern}': {e}") + return False diff --git a/src/aap_eda/services/sync_certs.py b/src/aap_eda/services/sync_certs.py new file mode 100644 index 000000000..3f0ed1bc8 --- /dev/null +++ b/src/aap_eda/services/sync_certs.py @@ -0,0 +1,306 @@ +# Copyright 2025 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Synchronize Certificates with Gateway.""" +import hashlib +import logging +from typing import Any, Dict, Optional +from urllib.parse import urljoin + +import requests +import yaml +from ansible_base.resource_registry import resource_server +from django.conf import settings +from django.db.models.signals import post_save +from django.dispatch import receiver +from rest_framework import status + +from aap_eda.core import enums, models +from aap_eda.core.exceptions import GatewayAPIError, MissingCredentials + +LOGGER = logging.getLogger(__name__) +SLUG = "api/gateway/v1/ca_certificates/" +DEFAULT_TIMEOUT = 30 +SERVICE_TOKEN_HEADER = "X-ANSIBLE-SERVICE-AUTH" + + +class SyncCertificates: + """This class synchronizes the certificates with Gateway.""" + + def __init__(self, eda_credential_id: int) -> None: + self.eda_credential_id: int = eda_credential_id + self.gateway_url: str = settings.RESOURCE_SERVER["URL"] + self.gateway_ssl_verify: bool = settings.RESOURCE_SERVER.get( + "VALIDATE_HTTPS", True + ) + + self.eda_credential: models.EdaCredential = ( + models.EdaCredential.objects.get(id=self.eda_credential_id) + ) + + def update(self) -> None: + """Handle creating and updating the certificate in Gateway.""" + inputs = self._get_credential_inputs() + existing_object = self._fetch_from_gateway() + + # Handle certificate deletion case + if self._should_delete_certificate(existing_object, inputs): + return self.delete(None) + + # Handle no certificate case + if not self._has_certificate(inputs): + return + + # Handle certificate update case + if self._certificate_needs_update(existing_object, inputs): + self._update_certificate_in_gateway(existing_object, inputs) + else: + LOGGER.debug("No changes detected") + + def _get_credential_inputs(self) -> Dict[str, Any]: + """Get and parse credential inputs.""" + return yaml.safe_load(self.eda_credential.inputs.get_secret_value()) + + def _should_delete_certificate( + self, existing_object: Dict[str, Any], inputs: Dict[str, Any] + ) -> bool: + """Check if certificate should be deleted from Gateway.""" + return existing_object and not inputs.get("certificate") + + def _has_certificate(self, inputs: Dict[str, Any]) -> bool: + """Check if inputs contain a certificate.""" + return bool(inputs.get("certificate")) + + def _certificate_needs_update( + self, existing_object: Dict[str, Any], inputs: Dict[str, Any] + ) -> bool: + """Check if certificate needs to be updated in Gateway.""" + certificate = inputs["certificate"] + current_sha256 = hashlib.sha256( + certificate.encode("utf-8") + ).hexdigest() + existing_sha256 = existing_object.get("sha256", "") + return existing_sha256 != current_sha256 + + def _update_certificate_in_gateway( + self, existing_object: Dict[str, Any], inputs: Dict[str, Any] + ) -> None: + """Update or create certificate in Gateway.""" + certificate = inputs["certificate"] + sha256 = hashlib.sha256(certificate.encode("utf-8")).hexdigest() + + data = { + "name": self.eda_credential.name, + "pem_data": certificate, + "sha256": sha256, + "related_id_reference": self._get_remote_id(), + } + + if existing_object: + response = self._patch_certificate(existing_object["id"], data) + else: + response = self._post_certificate(data) + + self._handle_certificate_response(response) + + def _patch_certificate( + self, cert_id: str, data: Dict[str, str] + ) -> requests.Response: + """Update existing certificate in Gateway.""" + slug = f"{SLUG}/{cert_id}/" + url = urljoin(self.gateway_url, slug) + return self._make_request("patch", url, data) + + def _post_certificate(self, data: Dict[str, str]) -> requests.Response: + """Create new certificate in Gateway.""" + url = urljoin(self.gateway_url, SLUG) + return self._make_request("post", url, data) + + def _make_request( + self, method: str, url: str, data: Dict[str, str] + ) -> requests.Response: + """Make HTTP request with proper error handling.""" + headers = self._prep_headers() + + try: + response = getattr(requests, method)( + url, + json=data, + headers=headers, + verify=self.gateway_ssl_verify, + timeout=DEFAULT_TIMEOUT, + ) + return response + except requests.exceptions.ConnectionError as ex: + LOGGER.error( + "Connection error while updating certificate: %s", str(ex) + ) + raise GatewayAPIError(f"Connection error: {str(ex)}") + except requests.exceptions.Timeout as ex: + LOGGER.error("Timeout while updating certificate: %s", str(ex)) + raise GatewayAPIError(f"Request timeout: {str(ex)}") + except requests.exceptions.RequestException as ex: + LOGGER.error( + "Request error while updating certificate: %s", str(ex) + ) + raise GatewayAPIError(f"Request error: {str(ex)}") + + def _handle_certificate_response( + self, response: requests.Response + ) -> None: + """Handle response from certificate update/create operation.""" + if response.status_code in [ + status.HTTP_200_OK, + status.HTTP_201_CREATED, + ]: + LOGGER.debug("Certificate updated") + elif response.status_code == status.HTTP_400_BAD_REQUEST: + LOGGER.error("Update failed") + raise GatewayAPIError(response.text) + else: + LOGGER.error("Couldn't update certificate") + raise GatewayAPIError(response.text) + + def delete(self, event_stream_id: Optional[int]) -> None: + """Delete the Certificate from Gateway.""" + existing_object: Dict[str, Any] = self._fetch_from_gateway() + if not existing_object: + return + + objects = models.EventStream.objects.filter( + eda_credential_id=self.eda_credential_id + ) + + if not event_stream_id or ( + len(objects) == 1 and event_stream_id == objects[0].id + ): + self._delete_from_gateway(existing_object) + + def _delete_from_gateway(self, existing_object: Dict[str, Any]) -> None: + slug: str = f"{SLUG}/{existing_object['id']}/" + url: str = urljoin(self.gateway_url, slug) + headers: Dict[str, str] = self._prep_headers() + + try: + response: requests.Response = requests.delete( + url, + headers=headers, + verify=self.gateway_ssl_verify, + timeout=DEFAULT_TIMEOUT, + ) + except requests.exceptions.ConnectionError as ex: + LOGGER.error( + "Connection error while deleting certificate: %s", str(ex) + ) + raise GatewayAPIError(f"Connection error: {str(ex)}") + except requests.exceptions.Timeout as ex: + LOGGER.error("Timeout while deleting certificate: %s", str(ex)) + raise GatewayAPIError(f"Request timeout: {str(ex)}") + except requests.exceptions.RequestException as ex: + LOGGER.error( + "Request error while deleting certificate: %s", str(ex) + ) + raise GatewayAPIError(f"Request error: {str(ex)}") + + if response.status_code in [ + status.HTTP_200_OK, + status.HTTP_204_NO_CONTENT, + ]: + LOGGER.debug("Certificate object deleted") + elif response.status_code == status.HTTP_404_NOT_FOUND: + LOGGER.warning("Certificate object missing during delete") + else: + LOGGER.error( + "Could not delete certificate object in gateway. " + f"Error code: {response.status_code}" + ) + LOGGER.error(f"Error message: {response.text}") + raise GatewayAPIError(response.text) + + def _fetch_from_gateway(self) -> Dict[str, Any]: + slug: str = f"{SLUG}/?related_id_reference={self._get_remote_id()}" + url: str = urljoin(self.gateway_url, slug) + headers: Dict[str, str] = self._prep_headers() + + try: + response: requests.Response = requests.get( + url, + headers=headers, + verify=self.gateway_ssl_verify, + timeout=DEFAULT_TIMEOUT, + ) + except requests.exceptions.ConnectionError as ex: + LOGGER.error( + "Connection error while fetching certificate: %s", str(ex) + ) + raise GatewayAPIError(f"Connection error: {str(ex)}") + except requests.exceptions.Timeout as ex: + LOGGER.error("Timeout while fetching certificate: %s", str(ex)) + raise GatewayAPIError(f"Request timeout: {str(ex)}") + except requests.exceptions.RequestException as ex: + LOGGER.error( + "Request error while fetching certificate: %s", str(ex) + ) + raise GatewayAPIError(f"Request error: {str(ex)}") + + if response.status_code == status.HTTP_200_OK: + LOGGER.debug("Certificate object exists in gateway") + data: Dict[str, Any] = response.json() + if data["count"] > 0: + return data["results"][0] + else: + return {} + if response.status_code == status.HTTP_404_NOT_FOUND: + LOGGER.debug("Certificate object does not exist in gateway") + return {} + + LOGGER.error( + "Error fetching certificate object. " + f"Error code: {response.status_code}" + ) + LOGGER.error(f"Error message: {response.text}") + raise GatewayAPIError(response.text) + + def _get_remote_id(self) -> str: + return f"eda_{self.eda_credential_id}" + + def _prep_headers(self) -> Dict[str, str]: + token: Optional[str] = resource_server.get_service_token() + if token: + return {SERVICE_TOKEN_HEADER: token} + + LOGGER.error("Cannot connect to gateway service token") + raise MissingCredentials + + +@receiver(post_save, sender=models.EdaCredential) +def gw_handler( + sender: Any, instance: models.EdaCredential, **kwargs: Any +) -> None: + """Handle updates to EdaCredential object and force a certificate sync.""" + if ( + instance.credential_type is not None + and instance.credential_type.name + == enums.EventStreamCredentialType.MTLS + and hasattr(instance, "_request") + ): + try: + objects = models.EventStream.objects.filter( + eda_credential_id=instance.id + ) + if len(objects) > 0: + SyncCertificates(instance.id).update() + except (GatewayAPIError, MissingCredentials) as ex: + LOGGER.error( + "Couldn't trigger gateway certificate updates %s", str(ex) + ) diff --git a/src/aap_eda/utils/logging.py b/src/aap_eda/utils/logging.py index ce39fa5e3..14d166012 100644 --- a/src/aap_eda/utils/logging.py +++ b/src/aap_eda/utils/logging.py @@ -72,6 +72,7 @@ def log(self, *args, **kwargs): "ACTIVATION_DB_HOST", "SAFE_PLUGINS_FOR_PORT_FORWARD", "EVENT_STREAM_BASE_URL", + "EVENT_STREAM_MTLS_BASE_URL", ] LOGGING_PACKAGE_VERSIONS = [ diff --git a/tests/integration/api/test_eda_credential.py b/tests/integration/api/test_eda_credential.py index ada84a436..a52ae1d32 100644 --- a/tests/integration/api/test_eda_credential.py +++ b/tests/integration/api/test_eda_credential.py @@ -106,6 +106,12 @@ -----END OPENSSH PRIVATE KEY----- """ +# For testing, we'll use simple certificate content and mock validation +VALID_PEM_CERTIFICATE = "" # Empty certificate is allowed per help text +INVALID_PEM_CERTIFICATE = "" +MULTIPLE_PEM_CERTIFICATES = "" +CERTIFICATE_SUBJECT_RFC2253 = "CN=Test Subject,O=Test Org,C=US" + EXTERNAL_CREDENTIAL_INPUT = { "fields": [ { @@ -1646,3 +1652,407 @@ def test_credential_input_sources( assert obj["count"] == count if count > 0: assert obj["results"][0]["input_field_name"] == "password" + + +# mTLS Credential Tests + + +@pytest.mark.parametrize( + ("inputs", "status_code", "error_field", "error_message"), + [ + ( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "http_header_key": "Subject", + }, + status.HTTP_201_CREATED, + None, + None, + ), + ( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "subject": CERTIFICATE_SUBJECT_RFC2253, + "http_header_key": "Subject", + }, + status.HTTP_201_CREATED, + None, + None, + ), + ( + { + "certificate": VALID_PEM_CERTIFICATE, + "http_header_key": "Subject", + }, + status.HTTP_201_CREATED, # auth_type has default "mtls" + None, + None, + ), + ( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + }, + status.HTTP_201_CREATED, # http_header_key has default "Subject" + None, + None, + ), + ( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "http_header_key": "Subject", + "invalid_key": "bad", + }, + status.HTTP_400_BAD_REQUEST, + "inputs", + "Input keys {'invalid_key'} are not defined in the schema", + ), + ], +) +@pytest.mark.django_db +def test_create_mtls_eda_credential( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, + inputs, + status_code, + error_field, + error_message, +): + """Test creation of mTLS credentials with various inputs.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + data_in = { + "name": "mtls-credential", + "inputs": inputs, + "credential_type_id": mtls_type.id, + "organization_id": default_organization.id, + } + + # Mock certificate validation to avoid complex certificate parsing + with patch( + "aap_eda.core.utils.credentials._validate_pem_certificate", + return_value=[], + ), patch( + "aap_eda.core.utils.credentials._validate_certificate_subject_format", + return_value=[], + ): + response = admin_client.post( + f"{api_url_v1}/eda-credentials/", data=data_in + ) + + assert response.status_code == status_code + + if status_code == status.HTTP_201_CREATED: + assert response.data["name"] == "mtls-credential" + assert response.data["managed"] is False + # auth_type should always be "mtls" (either provided or default) + assert response.data["inputs"]["auth_type"] == "mtls" + # http_header_key should always be "Subject" + assert response.data["inputs"]["http_header_key"] == "Subject" + + # Certificate should be encrypted in response if present + if "certificate" in inputs and inputs["certificate"]: + assert response.data["inputs"]["certificate"] == "$encrypted$" + + # Subject should be present if provided + if "subject" in inputs: + assert response.data["inputs"]["subject"] == inputs["subject"] + elif error_field and error_message: + if error_field == "inputs": + assert error_message in response.data["inputs"][0] + elif error_field.startswith("inputs."): + # Handle nested input field errors like inputs.auth_type + field_name = error_field.split(".", 1)[ + 1 + ] # Get field after "inputs." + assert error_message in response.data["inputs." + field_name][0] + else: + assert error_message in response.data[error_field] + + +@pytest.mark.parametrize( + ("certificate", "status_code"), + [ + (VALID_PEM_CERTIFICATE, status.HTTP_201_CREATED), + ("", status.HTTP_201_CREATED), # Empty certificate is allowed + ], +) +@pytest.mark.django_db +def test_create_mtls_credential_with_various_certificates( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, + certificate, + status_code, +): + """Test mTLS credential creation with various certificate formats.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + data_in = { + "name": "mtls-credential-cert-test", + "inputs": { + "auth_type": "mtls", + "certificate": certificate, + "http_header_key": "Subject", + }, + "credential_type_id": mtls_type.id, + "organization_id": default_organization.id, + } + + response = admin_client.post( + f"{api_url_v1}/eda-credentials/", data=data_in + ) + + assert response.status_code == status_code + + +@pytest.mark.django_db +def test_retrieve_mtls_eda_credential( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test retrieving an mTLS credential shows encrypted certificate.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + obj = models.EdaCredential.objects.create( + name="mtls_credential", + inputs=inputs_to_store( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "subject": CERTIFICATE_SUBJECT_RFC2253, + "http_header_key": "Subject", + } + ), + managed=False, + credential_type_id=mtls_type.id, + organization=default_organization, + ) + + response = admin_client.get(f"{api_url_v1}/eda-credentials/{obj.id}/") + + assert response.status_code == status.HTTP_200_OK + assert response.data["name"] == "mtls_credential" + assert response.data["inputs"] == { + "auth_type": "mtls", + "certificate": "", # Empty certificate is allowed per help text + "subject": CERTIFICATE_SUBJECT_RFC2253, + "http_header_key": "Subject", + } + assert response.data["managed"] is False + + +@pytest.mark.parametrize( + "update_data,expected_inputs,expected_db_inputs,test_description", + [ + ( + { + "certificate": "", + "subject": CERTIFICATE_SUBJECT_RFC2253, + }, + { + "auth_type": "mtls", + "certificate": "", + "subject": CERTIFICATE_SUBJECT_RFC2253, + "http_header_key": "Subject", + }, + { + "certificate": "", + "subject": CERTIFICATE_SUBJECT_RFC2253, + }, + "updating certificate with subject", + ), + ( + { + "certificate": "", + }, + None, # No specific assertion for inputs in this case + { + "certificate": "", + }, + "clearing certificate", + ), + ], +) +@pytest.mark.django_db +def test_update_mtls_credential_certificate( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, + update_data, + expected_inputs, + expected_db_inputs, + test_description, +): + """Test updating mTLS credential certificate.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + obj = models.EdaCredential.objects.create( + name="mtls-credential", + inputs=inputs_to_store( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "http_header_key": "Subject", + } + ), + credential_type_id=mtls_type.id, + organization=default_organization, + ) + + data = {"inputs": update_data} + + response = admin_client.patch( + f"{api_url_v1}/eda-credentials/{obj.id}/", data=data + ) + + assert response.status_code == status.HTTP_200_OK + + # Check response inputs if expected_inputs is provided + if expected_inputs: + result = response.data + assert result["inputs"] == expected_inputs + + # Verify certificate was actually updated in database + obj.refresh_from_db() + obj_inputs = yaml.safe_load(obj.inputs.get_secret_value()) + for key, expected_value in expected_db_inputs.items(): + assert obj_inputs[key] == expected_value + + +@pytest.mark.django_db +def test_copy_mtls_credential_success( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test copying an mTLS credential.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + original_cred = models.EdaCredential.objects.create( + name="original-mtls-credential", + inputs=inputs_to_store( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "subject": CERTIFICATE_SUBJECT_RFC2253, + "http_header_key": "Subject", + } + ), + credential_type_id=mtls_type.id, + organization=default_organization, + ) + + data = {"name": "copied-mtls-credential"} + response = admin_client.post( + f"{api_url_v1}/eda-credentials/{original_cred.id}/copy/", + data=data, + ) + + assert response.status_code == status.HTTP_201_CREATED + new_credential = response.data + assert new_credential["id"] != original_cred.id + assert new_credential["name"] == data["name"] + assert new_credential["inputs"]["auth_type"] == "mtls" + assert new_credential["inputs"]["certificate"] == "" # Empty certificate + assert new_credential["inputs"]["subject"] == CERTIFICATE_SUBJECT_RFC2253 + assert new_credential["inputs"]["http_header_key"] == "Subject" + + +@pytest.mark.django_db +def test_delete_mtls_credential_with_event_stream_reference( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test deletion of mTLS credential referenced by event stream.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + # Create mTLS credential + mtls_cred = models.EdaCredential.objects.create( + name="mtls-credential", + inputs=inputs_to_store( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "http_header_key": "Subject", + } + ), + credential_type_id=mtls_type.id, + organization=default_organization, + ) + + # Create event stream that references the credential + models.EventStream.objects.create( + name="test-event-stream", + event_stream_type="mtls", + eda_credential=mtls_cred, + organization=default_organization, + ) + + # Try to delete the credential - should fail due to reference + response = admin_client.delete( + f"{api_url_v1}/eda-credentials/{mtls_cred.id}/" + ) + + assert response.status_code == status.HTTP_409_CONFLICT + assert ( + f"Credential {mtls_cred.name} is being referenced by some " + "event streams and cannot be deleted" + ) in response.data["detail"] + + +@pytest.mark.django_db +def test_update_mtls_credential_triggers_sync( + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test that updating mTLS credential triggers certificate sync.""" + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + # Create mTLS credential + obj = models.EdaCredential.objects.create( + name="mtls-credential", + inputs=inputs_to_store( + { + "auth_type": "mtls", + "certificate": VALID_PEM_CERTIFICATE, + "http_header_key": "Subject", + } + ), + credential_type_id=mtls_type.id, + organization=default_organization, + ) + + # Update the credential certificate + data = { + "inputs": { + "certificate": MULTIPLE_PEM_CERTIFICATES, + } + } + + response = admin_client.patch( + f"{api_url_v1}/eda-credentials/{obj.id}/", data=data + ) + + assert response.status_code == status.HTTP_200_OK diff --git a/tests/integration/api/test_event_stream.py b/tests/integration/api/test_event_stream.py index 88420897a..7da871f25 100644 --- a/tests/integration/api/test_event_stream.py +++ b/tests/integration/api/test_event_stream.py @@ -15,6 +15,7 @@ import secrets import uuid from typing import List +from unittest.mock import patch import pytest from django.conf import settings @@ -24,6 +25,10 @@ from rest_framework.test import APIClient from aap_eda.core import enums, models +from aap_eda.core.exceptions import ( + GatewayAPIError as CoreGatewayAPIError, + MissingCredentials as CoreMissingCredentials, +) from tests.integration.constants import api_url_v1 @@ -541,3 +546,464 @@ def get_default_test_org() -> models.Organization: name=settings.DEFAULT_ORGANIZATION_NAME, description="The default organization", )[0] + + +# Gateway Exception Integration Tests + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.update") +def test_create_event_stream_gateway_api_error( + mock_sync_update, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test event stream creation with Gateway API error.""" + # Setup mTLS credential + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential = models.EdaCredential.objects.create( + name="mtls-credential", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Mock Gateway API error + mock_sync_update.side_effect = CoreGatewayAPIError( + "Gateway connection timeout" + ) + + data_in = { + "name": "test-stream", + "eda_credential_id": credential.id, + "organization_id": default_organization.id, + } + + with override_settings( + EVENT_STREAM_BASE_URL="https://www.example.com/", + EVENT_STREAM_MTLS_BASE_URL="https://www.example.com/", + ): + response = admin_client.post( + f"{api_url_v1}/event-streams/", data=data_in + ) + + # Verify proper error response + assert response.status_code == status.HTTP_502_BAD_GATEWAY + assert ( + "Gateway API error during certificate create" + in response.data["detail"] + ) + assert "Gateway connection timeout" in response.data["detail"] + + # Verify event stream was not created + assert models.EventStream.objects.filter(name="test-stream").count() == 0 + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.update") +def test_create_event_stream_missing_credentials_error( + mock_sync_update, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test event stream creation with missing credentials error.""" + # Setup mTLS credential + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential = models.EdaCredential.objects.create( + name="mtls-credential", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Mock missing credentials error + mock_sync_update.side_effect = CoreMissingCredentials( + "Required credentials not found" + ) + + data_in = { + "name": "test-stream", + "eda_credential_id": credential.id, + "organization_id": default_organization.id, + } + + with override_settings( + EVENT_STREAM_BASE_URL="https://www.example.com/", + EVENT_STREAM_MTLS_BASE_URL="https://www.example.com/", + ): + response = admin_client.post( + f"{api_url_v1}/event-streams/", data=data_in + ) + + # Verify proper error response + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + "Missing credentials for certificate create" in response.data["detail"] + ) + assert "Required credentials not found" in response.data["detail"] + + # Verify event stream was not created + assert models.EventStream.objects.filter(name="test-stream").count() == 0 + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.delete") +@patch("aap_eda.services.sync_certs.SyncCertificates.update") +def test_update_event_stream_gateway_api_error( + mock_sync_update, + mock_sync_delete, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test event stream update with Gateway API error.""" + # Setup mTLS credentials + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential1 = models.EdaCredential.objects.create( + name="mtls-credential-1", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + credential2 = models.EdaCredential.objects.create( + name="mtls-credential-2", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Create event stream first + event_stream = models.EventStream.objects.create( + name="test-stream", + eda_credential=credential1, + organization=default_organization, + ) + + # Mock Gateway API error on update + mock_sync_update.side_effect = CoreGatewayAPIError( + "Gateway API unavailable" + ) + + # Change credential to trigger certificate sync + update_data = {"eda_credential_id": credential2.id} + + with override_settings( + EVENT_STREAM_BASE_URL="https://www.example.com/", + EVENT_STREAM_MTLS_BASE_URL="https://www.example.com/", + ): + response = admin_client.patch( + f"{api_url_v1}/event-streams/{event_stream.id}/", + data=update_data, + ) + + # Verify proper error response + assert response.status_code == status.HTTP_502_BAD_GATEWAY + assert ( + "Gateway API error during certificate update" + in response.data["detail"] + ) + assert "Gateway API unavailable" in response.data["detail"] + + # Verify event stream credential was not updated + event_stream.refresh_from_db() + assert event_stream.eda_credential.id == credential1.id + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.delete") +@patch("aap_eda.services.sync_certs.SyncCertificates.update") +def test_update_event_stream_missing_credentials_error( + mock_sync_update, + mock_sync_delete, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test event stream update with missing credentials error.""" + # Setup mTLS credentials + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential1 = models.EdaCredential.objects.create( + name="mtls-credential-1", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + credential2 = models.EdaCredential.objects.create( + name="mtls-credential-2", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Create event stream first + event_stream = models.EventStream.objects.create( + name="test-stream", + eda_credential=credential1, + organization=default_organization, + ) + + # Mock missing credentials error on update + mock_sync_update.side_effect = CoreMissingCredentials( + "Credentials expired" + ) + + # Change credential to trigger certificate sync + update_data = {"eda_credential_id": credential2.id} + + with override_settings( + EVENT_STREAM_BASE_URL="https://www.example.com/", + EVENT_STREAM_MTLS_BASE_URL="https://www.example.com/", + ): + response = admin_client.patch( + f"{api_url_v1}/event-streams/{event_stream.id}/", + data=update_data, + ) + + # Verify proper error response + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + "Missing credentials for certificate update" in response.data["detail"] + ) + assert "Credentials expired" in response.data["detail"] + + # Verify event stream credential was not updated + event_stream.refresh_from_db() + assert event_stream.eda_credential.id == credential1.id + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.delete") +def test_delete_event_stream_gateway_api_error( + mock_sync_delete, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test event stream deletion with Gateway API error.""" + # Setup mTLS credential + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential = models.EdaCredential.objects.create( + name="mtls-credential", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Create event stream + event_stream = models.EventStream.objects.create( + name="test-stream", + eda_credential=credential, + organization=default_organization, + ) + + # Mock Gateway API error on delete + mock_sync_delete.side_effect = CoreGatewayAPIError("Gateway service down") + + response = admin_client.delete( + f"{api_url_v1}/event-streams/{event_stream.id}/" + ) + + # Verify proper error response + assert response.status_code == status.HTTP_502_BAD_GATEWAY + assert ( + "Gateway API error during certificate destroy" + in response.data["detail"] + ) + assert "Gateway service down" in response.data["detail"] + + # Verify event stream still exists + assert models.EventStream.objects.filter(id=event_stream.id).exists() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.delete") +def test_delete_event_stream_missing_credentials_error( + mock_sync_delete, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test event stream deletion with missing credentials error.""" + # Setup mTLS credential + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential = models.EdaCredential.objects.create( + name="mtls-credential", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Create event stream + event_stream = models.EventStream.objects.create( + name="test-stream", + eda_credential=credential, + organization=default_organization, + ) + + # Mock missing credentials error on delete + mock_sync_delete.side_effect = CoreMissingCredentials( + "Authentication failed" + ) + + response = admin_client.delete( + f"{api_url_v1}/event-streams/{event_stream.id}/" + ) + + # Verify proper error response + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ( + "Missing credentials for certificate destroy" + in response.data["detail"] + ) + assert "Authentication failed" in response.data["detail"] + + # Verify event stream still exists + assert models.EventStream.objects.filter(id=event_stream.id).exists() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.update") +def test_create_non_mtls_event_stream_no_exception_handling( + mock_sync_update, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test that non-mTLS event streams don't trigger certificate sync.""" + # Setup HMAC credential (non-mTLS event stream type) + hmac_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.HMAC + ) + credential = models.EdaCredential.objects.create( + name="hmac-credential", + inputs={ + "auth_type": "hmac", + "secret": "shared-secret", + "http_header_key": "Authorization", + }, + credential_type=hmac_type, + organization=default_organization, + ) + + data_in = { + "name": "test-stream", + "eda_credential_id": credential.id, + "organization_id": default_organization.id, + } + + with override_settings( + EVENT_STREAM_BASE_URL="https://www.example.com/", + EVENT_STREAM_MTLS_BASE_URL="https://www.example.com/", + ): + response = admin_client.post( + f"{api_url_v1}/event-streams/", data=data_in + ) + + # Verify successful creation + assert response.status_code == status.HTTP_201_CREATED + assert response.data["name"] == "test-stream" + + # Verify sync was not called for non-mTLS credential + mock_sync_update.assert_not_called() + + # Verify event stream was created + assert models.EventStream.objects.filter(name="test-stream").exists() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.SyncCertificates.update") +def test_create_event_stream_sync_success( + mock_sync_update, + admin_client: APIClient, + default_organization: models.Organization, + preseed_credential_types, +): + """Test successful event stream creation with mTLS certificate sync.""" + # Setup mTLS credential + mtls_type = models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + credential = models.EdaCredential.objects.create( + name="mtls-credential", + inputs={ + "auth_type": "mtls", + "certificate": "", + "http_header_key": "Subject", + }, + credential_type=mtls_type, + organization=default_organization, + ) + + # Mock successful sync + mock_sync_update.return_value = None + + data_in = { + "name": "test-stream", + "eda_credential_id": credential.id, + "organization_id": default_organization.id, + } + + with override_settings( + EVENT_STREAM_BASE_URL="https://www.example.com/", + EVENT_STREAM_MTLS_BASE_URL="https://www.example.com/", + ): + response = admin_client.post( + f"{api_url_v1}/event-streams/", data=data_in + ) + + # Verify successful creation + assert response.status_code == status.HTTP_201_CREATED + assert response.data["name"] == "test-stream" + + # Verify sync was called + mock_sync_update.assert_called_once() + + # Verify event stream was created + event_stream = models.EventStream.objects.get(name="test-stream") + assert event_stream.eda_credential == credential diff --git a/tests/integration/services/test_sync_certs.py b/tests/integration/services/test_sync_certs.py new file mode 100644 index 000000000..fe68c8a7b --- /dev/null +++ b/tests/integration/services/test_sync_certs.py @@ -0,0 +1,941 @@ +"""Unit tests for certificate synchronization service.""" +import hashlib +from unittest.mock import Mock, patch + +import pytest +import requests +from rest_framework import status + +from aap_eda.core import enums, models +from aap_eda.core.exceptions import GatewayAPIError, MissingCredentials +from aap_eda.core.utils.credentials import inputs_to_store +from aap_eda.services.sync_certs import SyncCertificates, gw_handler + + +@pytest.fixture +def mock_settings(): + """Mock Django settings for testing.""" + with patch("aap_eda.services.sync_certs.settings") as mock: + mock.RESOURCE_SERVER = { + "URL": "https://gateway.example.com", + "VALIDATE_HTTPS": True, + } + yield mock + + +@pytest.fixture +def mtls_credential_type(preseed_credential_types): + """Get the real mTLS credential type.""" + return models.CredentialType.objects.get( + name=enums.EventStreamCredentialType.MTLS + ) + + +@pytest.fixture +def default_mtls_credential( + default_organization: models.Organization, + mtls_credential_type: models.CredentialType, +): + """Create a real EDA credential with mTLS type.""" + cert_data = ( + "-----BEGIN CERTIFICATE-----\n" + "MIICert...\n" + "-----END CERTIFICATE-----" + ) + return models.EdaCredential.objects.create( + name="test-credential", + description="Test mTLS Credential", + credential_type=mtls_credential_type, + inputs=inputs_to_store({"certificate": cert_data}), + organization=default_organization, + ) + + +@pytest.fixture +def empty_mtls_credential( + default_organization: models.Organization, + mtls_credential_type: models.CredentialType, +): + """Create a real EDA credential with empty certificate.""" + return models.EdaCredential.objects.create( + name="test-credential", + description="Test mTLS Credential", + credential_type=mtls_credential_type, + inputs=inputs_to_store({"certificate": ""}), + organization=default_organization, + ) + + +@pytest.fixture +def mock_service_token(): + """Mock the resource server service token.""" + with patch("aap_eda.services.sync_certs.resource_server") as mock: + mock.get_service_token.return_value = "mock-token" + yield mock + + +# SyncCertificates tests + + +@pytest.mark.django_db +def test_sync_certificates_init(mock_settings, default_mtls_credential): + """Test SyncCertificates initialization.""" + sync = SyncCertificates(default_mtls_credential.id) + + assert sync.eda_credential_id == default_mtls_credential.id + assert sync.gateway_url == "https://gateway.example.com" + assert sync.gateway_ssl_verify is True + assert sync.eda_credential == default_mtls_credential + + +@pytest.mark.django_db +def test_sync_certificates_init_ssl_verify_false(default_mtls_credential): + """Test SyncCertificates init with SSL verification disabled.""" + with patch("aap_eda.services.sync_certs.settings") as mock_settings: + mock_settings.RESOURCE_SERVER = { + "URL": "https://gateway.example.com", + "VALIDATE_HTTPS": False, + } + + sync = SyncCertificates(default_mtls_credential.id) + assert sync.gateway_ssl_verify is False + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.post") +def test_update_creates_new_certificate( + mock_post, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test creating a new certificate in Gateway.""" + mock_post.return_value.status_code = status.HTTP_201_CREATED + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object(sync, "_fetch_from_gateway", return_value={}): + sync.update() + + # Verify the POST request was made with correct data + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + + assert ( + "https://gateway.example.com/api/gateway/v1/ca_certificates/" + in args[0] + ) + assert kwargs["json"]["name"] == "test-credential" + assert "-----BEGIN CERTIFICATE-----" in kwargs["json"]["pem_data"] + assert "sha256" in kwargs["json"] + assert ( + kwargs["json"]["related_id_reference"] + == f"eda_{default_mtls_credential.id}" + ) + assert kwargs["headers"]["X-ANSIBLE-SERVICE-AUTH"] == "mock-token" + assert kwargs["verify"] is True + assert kwargs["timeout"] == 30 + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.patch") +def test_update_modifies_existing_certificate( + mock_patch, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test updating an existing certificate in Gateway.""" + mock_patch.return_value.status_code = status.HTTP_200_OK + + existing_object = {"id": 123, "sha256": "different-hash"} + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + sync.update() + + mock_patch.assert_called_once() + args, _ = mock_patch.call_args + assert "/123/" in args[0] + + +@pytest.mark.django_db +def test_update_no_changes_detected( + mock_settings, default_mtls_credential, mock_service_token +): + """Test no update when certificate hasn't changed.""" + # Calculate the expected SHA256 + cert_data = ( + "-----BEGIN CERTIFICATE-----\n" + "MIICert...\n" + "-----END CERTIFICATE-----" + ) + expected_sha256 = hashlib.sha256(cert_data.encode("utf-8")).hexdigest() + + existing_object = {"id": 123, "sha256": expected_sha256} + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + with patch("aap_eda.services.sync_certs.requests") as mock_requests: + sync.update() + + # No HTTP requests should be made + mock_requests.post.assert_not_called() + mock_requests.patch.assert_not_called() + + +@pytest.mark.django_db +def test_update_deletes_when_certificate_removed( + mock_settings, empty_mtls_credential +): + """Test deleting certificate when user removes it.""" + existing_object = {"id": 123} + + sync = SyncCertificates(empty_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + with patch.object(sync, "delete") as mock_delete: + sync.update() + mock_delete.assert_called_once() + + +@pytest.mark.django_db +def test_update_no_action_when_no_certificate( + mock_settings, empty_mtls_credential +): + """Test no action when no certificate is provided.""" + sync = SyncCertificates(empty_mtls_credential.id) + + with patch.object(sync, "_fetch_from_gateway", return_value={}): + with patch("aap_eda.services.sync_certs.requests") as mock_requests: + sync.update() + + mock_requests.post.assert_not_called() + mock_requests.patch.assert_not_called() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.post") +def test_update_handles_bad_request_error( + mock_post, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test handling of 400 Bad Request error.""" + mock_post.return_value.status_code = status.HTTP_400_BAD_REQUEST + mock_post.return_value.text = "Invalid certificate data" + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object(sync, "_fetch_from_gateway", return_value={}): + with pytest.raises(GatewayAPIError, match="Invalid certificate data"): + sync.update() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.post") +def test_update_handles_other_errors( + mock_post, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test handling of other HTTP errors.""" + mock_post.return_value.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + mock_post.return_value.text = "Internal server error" + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object(sync, "_fetch_from_gateway", return_value={}): + with pytest.raises(GatewayAPIError, match="Internal server error"): + sync.update() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.requests.delete") +def test_delete_removes_certificate( + mock_delete, + mock_filter, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test deleting certificate from Gateway.""" + mock_delete.return_value.status_code = status.HTTP_204_NO_CONTENT + + existing_object = {"id": 123} + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + sync.delete(None) + + mock_delete.assert_called_once() + args, _ = mock_delete.call_args + assert "/123/" in args[0] + + +@pytest.mark.django_db +def test_delete_no_existing_object(mock_settings, default_mtls_credential): + """Test delete when no object exists in Gateway.""" + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object(sync, "_fetch_from_gateway", return_value={}): + with patch("aap_eda.services.sync_certs.requests") as mock_requests: + sync.delete(None) + mock_requests.delete.assert_not_called() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +def test_delete_with_event_stream_id_single_match( + mock_filter, mock_settings, default_mtls_credential +): + """Test delete when event_stream_id matches single object.""" + # Mock single EventStream object + event_stream = Mock() + event_stream.id = 456 + mock_filter.return_value = [event_stream] + + existing_object = {"id": 123} + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + with patch.object(sync, "_delete_from_gateway") as mock_delete_gateway: + sync.delete(456) # Same ID as the event stream + mock_delete_gateway.assert_called_once_with(existing_object) + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +def test_delete_with_event_stream_id_no_match( + mock_filter, mock_settings, default_mtls_credential +): + """Test delete when event_stream_id doesn't match.""" + # Mock single EventStream object with different ID + event_stream = Mock() + event_stream.id = 456 + mock_filter.return_value = [event_stream] + + existing_object = {"id": 123} + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + with patch.object(sync, "_delete_from_gateway") as mock_delete_gateway: + sync.delete(999) # Different ID + mock_delete_gateway.assert_not_called() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.delete") +def test_delete_from_gateway_success( + mock_delete, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test successful deletion from Gateway.""" + mock_delete.return_value.status_code = status.HTTP_204_NO_CONTENT + + sync = SyncCertificates(default_mtls_credential.id) + existing_object = {"id": 123} + + sync._delete_from_gateway(existing_object) + + mock_delete.assert_called_once() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.delete") +def test_delete_from_gateway_not_found( + mock_delete, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test deletion when object not found (warning case).""" + mock_delete.return_value.status_code = status.HTTP_404_NOT_FOUND + + sync = SyncCertificates(default_mtls_credential.id) + existing_object = {"id": 123} + + # Should not raise exception, just log warning + sync._delete_from_gateway(existing_object) + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.delete") +def test_delete_from_gateway_error( + mock_delete, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test deletion error handling.""" + mock_delete.return_value.status_code = ( + status.HTTP_500_INTERNAL_SERVER_ERROR + ) + mock_delete.return_value.text = "Server error" + + sync = SyncCertificates(default_mtls_credential.id) + existing_object = {"id": 123} + + with pytest.raises(GatewayAPIError, match="Server error"): + sync._delete_from_gateway(existing_object) + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.get") +def test_fetch_from_gateway_exists( + mock_get_request, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test fetching existing certificate from Gateway.""" + response_data = { + "count": 1, + "results": [{"id": 123, "sha256": "abc123"}], + } + mock_get_request.return_value.status_code = status.HTTP_200_OK + mock_get_request.return_value.json.return_value = response_data + + sync = SyncCertificates(default_mtls_credential.id) + result = sync._fetch_from_gateway() + + assert result == {"id": 123, "sha256": "abc123"} + mock_get_request.assert_called_once() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.get") +def test_fetch_from_gateway_not_found( + mock_get_request, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test fetching when certificate doesn't exist.""" + mock_get_request.return_value.status_code = status.HTTP_404_NOT_FOUND + + sync = SyncCertificates(default_mtls_credential.id) + result = sync._fetch_from_gateway() + + assert result == {} + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.get") +def test_fetch_from_gateway_empty_results( + mock_get_request, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test fetching when no results returned.""" + response_data = {"count": 0, "results": []} + mock_get_request.return_value.status_code = status.HTTP_200_OK + mock_get_request.return_value.json.return_value = response_data + + sync = SyncCertificates(default_mtls_credential.id) + result = sync._fetch_from_gateway() + + assert result == {} + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.get") +def test_fetch_from_gateway_error( + mock_get_request, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test error handling in fetch_from_gateway.""" + mock_get_request.return_value.status_code = ( + status.HTTP_500_INTERNAL_SERVER_ERROR + ) + mock_get_request.return_value.text = "Server error" + + sync = SyncCertificates(default_mtls_credential.id) + + with pytest.raises(GatewayAPIError, match="Server error"): + sync._fetch_from_gateway() + + +@pytest.mark.django_db +def test_get_remote_id(mock_settings, default_mtls_credential): + """Test remote ID generation.""" + sync = SyncCertificates(default_mtls_credential.id) + remote_id = sync._get_remote_id() + + assert remote_id == f"eda_{default_mtls_credential.id}" + + +@pytest.mark.django_db +def test_prep_headers_with_token( + mock_settings, default_mtls_credential, mock_service_token +): + """Test header preparation with valid token.""" + sync = SyncCertificates(default_mtls_credential.id) + headers = sync._prep_headers() + + assert headers == {"X-ANSIBLE-SERVICE-AUTH": "mock-token"} + + +@pytest.mark.django_db +def test_prep_headers_no_token(mock_settings, default_mtls_credential): + """Test header preparation when no token available.""" + with patch("aap_eda.services.sync_certs.resource_server") as mock_rs: + mock_rs.get_service_token.return_value = None + + sync = SyncCertificates(default_mtls_credential.id) + + with pytest.raises(MissingCredentials): + sync._prep_headers() + + +# Signal handler tests + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.SyncCertificates") +def test_gw_handler_triggers_sync( + mock_sync_class, mock_filter, mtls_credential_type +): + """Test signal handler triggers certificate sync.""" + # Create mock instance + instance = Mock() + instance.id = 1 + instance.credential_type = mtls_credential_type + instance._request = Mock() # Simulate having _request attribute + + # Mock EventStream objects + event_stream = Mock() + mock_filter.return_value = [event_stream] + + # Mock SyncCertificates instance + mock_sync_instance = Mock() + mock_sync_class.return_value = mock_sync_instance + + # Call the signal handler + gw_handler(models.EdaCredential, instance) + + # Verify sync was triggered + mock_sync_class.assert_called_once_with(1) + mock_sync_instance.update.assert_called_once() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.SyncCertificates") +def test_gw_handler_no_event_streams( + mock_sync_class, mock_filter, mtls_credential_type +): + """Test signal handler when no EventStreams exist.""" + instance = Mock() + instance.id = 1 + instance.credential_type = mtls_credential_type + instance._request = Mock() + + # No EventStreams + mock_filter.return_value = [] + + gw_handler(models.EdaCredential, instance) + + # Should not trigger sync + mock_sync_class.assert_not_called() + + +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.SyncCertificates") +def test_gw_handler_wrong_credential_type(mock_sync_class, mock_filter): + """Test signal handler with wrong credential type.""" + instance = Mock() + instance.id = 1 + + # Wrong credential type + wrong_type = Mock() + wrong_type.name = "wrong_type" + instance.credential_type = wrong_type + instance._request = Mock() + + gw_handler(models.EdaCredential, instance) + + # Should not trigger sync + mock_sync_class.assert_not_called() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.SyncCertificates") +def test_gw_handler_no_request_attribute( + mock_sync_class, mock_filter, mtls_credential_type +): + """Test signal handler when _request attribute is missing.""" + instance = Mock() + instance.id = 1 + instance.credential_type = mtls_credential_type + # No _request attribute + if hasattr(instance, "_request"): + delattr(instance, "_request") + + gw_handler(models.EdaCredential, instance) + + # Should not trigger sync + mock_sync_class.assert_not_called() + + +def test_gw_handler_no_credential_type(): + """Test signal handler when credential_type is None.""" + instance = Mock() + instance.id = 1 + instance.credential_type = None + + with patch( + "aap_eda.services.sync_certs.SyncCertificates" + ) as mock_sync_class: + gw_handler(models.EdaCredential, instance) + mock_sync_class.assert_not_called() + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.SyncCertificates") +def test_gw_handler_handles_gateway_api_error( + mock_sync_class, mock_filter, mtls_credential_type +): + """Test signal handler handles GatewayAPIError gracefully.""" + instance = Mock() + instance.id = 1 + instance.credential_type = mtls_credential_type + instance._request = Mock() + + # Mock EventStream objects + event_stream = Mock() + mock_filter.return_value = [event_stream] + + # Mock SyncCertificates to raise GatewayAPIError + mock_sync_instance = Mock() + mock_sync_instance.update.side_effect = GatewayAPIError("API Error") + mock_sync_class.return_value = mock_sync_instance + + # Should not raise exception + gw_handler(models.EdaCredential, instance) + + +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.models.EventStream.objects.filter") +@patch("aap_eda.services.sync_certs.SyncCertificates") +def test_gw_handler_handles_missing_credentials( + mock_sync_class, mock_filter, mtls_credential_type +): + """Test signal handler handles MissingCredentials gracefully.""" + instance = Mock() + instance.id = 1 + instance.credential_type = mtls_credential_type + instance._request = Mock() + + # Mock EventStream objects + event_stream = Mock() + mock_filter.return_value = [event_stream] + + # Mock SyncCertificates to raise MissingCredentials + mock_sync_instance = Mock() + mock_sync_instance.update.side_effect = MissingCredentials( + "No credentials" + ) + mock_sync_class.return_value = mock_sync_instance + + # Should not raise exception + gw_handler(models.EdaCredential, instance) + + +# Parameterized tests + + +@pytest.mark.parametrize( + "status_code,should_raise", + [ + (status.HTTP_200_OK, False), + (status.HTTP_201_CREATED, False), + (status.HTTP_400_BAD_REQUEST, True), + (status.HTTP_500_INTERNAL_SERVER_ERROR, True), + ], +) +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.post") +def test_update_response_codes( + mock_post, + status_code, + should_raise, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test various HTTP response codes in update method.""" + mock_post.return_value.status_code = status_code + mock_post.return_value.text = "Error message" + + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object(sync, "_fetch_from_gateway", return_value={}): + if should_raise: + with pytest.raises(GatewayAPIError): + sync.update() + else: + sync.update() # Should not raise + + +@pytest.mark.parametrize( + "status_code,should_raise", + [ + (status.HTTP_200_OK, False), + (status.HTTP_204_NO_CONTENT, False), + (status.HTTP_404_NOT_FOUND, False), # Warning case + (status.HTTP_500_INTERNAL_SERVER_ERROR, True), + ], +) +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.delete") +def test_delete_response_codes( + mock_delete, + status_code, + should_raise, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test various HTTP response codes in delete method.""" + mock_delete.return_value.status_code = status_code + mock_delete.return_value.text = "Error message" + + sync = SyncCertificates(default_mtls_credential.id) + existing_object = {"id": 123} + + if should_raise: + with pytest.raises(GatewayAPIError): + sync._delete_from_gateway(existing_object) + else: + sync._delete_from_gateway(existing_object) # Should not raise + + +@pytest.mark.parametrize( + "status_code,should_raise", + [ + (status.HTTP_200_OK, False), + (status.HTTP_404_NOT_FOUND, False), + (status.HTTP_500_INTERNAL_SERVER_ERROR, True), + ], +) +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.get") +def test_fetch_response_codes( + mock_get_request, + status_code, + should_raise, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test various HTTP response codes in fetch method.""" + mock_get_request.return_value.status_code = status_code + mock_get_request.return_value.text = "Error message" + mock_get_request.return_value.json.return_value = { + "count": 0, + "results": [], + } + + sync = SyncCertificates(default_mtls_credential.id) + + if should_raise: + with pytest.raises(GatewayAPIError): + sync._fetch_from_gateway() + else: + result = sync._fetch_from_gateway() + assert result == {} + + +# Exception handling tests for new network error scenarios + + +@pytest.mark.parametrize( + "request_method,exception_class,error_message,expected_prefix," + "existing_object", + [ + # POST requests (new certificates) + ( + "post", + requests.exceptions.ConnectionError, + "Connection refused", + "Connection error", + {}, + ), + ( + "post", + requests.exceptions.Timeout, + "Request timed out", + "Request timeout", + {}, + ), + ( + "post", + requests.exceptions.RequestException, + "Generic request error", + "Request error", + {}, + ), + # PATCH requests (updating existing certificates) + ( + "patch", + requests.exceptions.ConnectionError, + "Network unreachable", + "Connection error", + {"id": 123, "sha256": "different-hash"}, + ), + ( + "patch", + requests.exceptions.Timeout, + "Read timeout occurred", + "Request timeout", + {"id": 123, "sha256": "different-hash"}, + ), + ( + "patch", + requests.exceptions.RequestException, + "SSL certificate error", + "Request error", + {"id": 123, "sha256": "different-hash"}, + ), + ], +) +@pytest.mark.django_db +def test_update_handles_network_exceptions( + request_method, + exception_class, + error_message, + expected_prefix, + existing_object, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test handling of network exceptions during update() method.""" + sync = SyncCertificates(default_mtls_credential.id) + + with patch.object( + sync, "_fetch_from_gateway", return_value=existing_object + ): + with patch( + f"aap_eda.services.sync_certs.requests.{request_method}" + ) as mock_request: + mock_request.side_effect = exception_class(error_message) + + with pytest.raises( + GatewayAPIError, match=f"{expected_prefix}: {error_message}" + ): + sync.update() + + +@pytest.mark.parametrize( + "exception_class,error_message,expected_prefix", + [ + ( + requests.exceptions.ConnectionError, + "Connection refused", + "Connection error", + ), + ( + requests.exceptions.Timeout, + "Delete timed out", + "Request timeout", + ), + ( + requests.exceptions.RequestException, + "HTTP adapter error", + "Request error", + ), + ], +) +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.delete") +def test_delete_from_gateway_handles_network_exceptions( + mock_delete, + exception_class, + error_message, + expected_prefix, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test handling of network exceptions in _delete_from_gateway().""" + mock_delete.side_effect = exception_class(error_message) + + sync = SyncCertificates(default_mtls_credential.id) + existing_object = {"id": 123} + + with pytest.raises( + GatewayAPIError, match=f"{expected_prefix}: {error_message}" + ): + sync._delete_from_gateway(existing_object) + + +@pytest.mark.parametrize( + "exception_class,error_message,expected_prefix", + [ + ( + requests.exceptions.ConnectionError, + "Connection refused", + "Connection error", + ), + ( + requests.exceptions.Timeout, + "Fetch timed out", + "Request timeout", + ), + ( + requests.exceptions.RequestException, + "DNS resolution failed", + "Request error", + ), + ], +) +@pytest.mark.django_db +@patch("aap_eda.services.sync_certs.requests.get") +def test_fetch_from_gateway_handles_network_exceptions( + mock_get_request, + exception_class, + error_message, + expected_prefix, + mock_settings, + default_mtls_credential, + mock_service_token, +): + """Test handling of network exceptions in _fetch_from_gateway().""" + mock_get_request.side_effect = exception_class(error_message) + + sync = SyncCertificates(default_mtls_credential.id) + + with pytest.raises( + GatewayAPIError, match=f"{expected_prefix}: {error_message}" + ): + sync._fetch_from_gateway() diff --git a/tests/unit/test_certificate_validation.py b/tests/unit/test_certificate_validation.py new file mode 100644 index 000000000..39d0a7d7d --- /dev/null +++ b/tests/unit/test_certificate_validation.py @@ -0,0 +1,632 @@ +"""Unit tests for certificate validation functionality.""" +from datetime import datetime, timedelta, timezone + +import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +from aap_eda.core.utils.credentials import ( + _normalize_x509_dn_whitespace, + _validate_certificate_subject_format, + _validate_pem_certificate, + validate_x509_subject_match, +) + + +@pytest.mark.parametrize( + "cert_data,expected_error", + [ + ("", "Certificate data cannot be empty"), + ("not a certificate", "Invalid PEM certificate format"), + ( + "-----BEGIN CERTIFICATE-----\ninvalid\n-----END CERTIFICATE-----", + "Invalid PEM certificate format", + ), + ], +) +def test_invalid_certificate_validation(cert_data, expected_error): + """Test validation of invalid certificate data.""" + errors = _validate_pem_certificate(cert_data) + assert len(errors) >= 1 + assert ( + expected_error in errors[0] + or "No valid certificates found" in errors[0] + ) + + +def test_valid_certificate_validation(): + """Test validation of valid certificate.""" + # Generate a test certificate + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Testing"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Local"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"), + x509.NameAttribute(NameOID.COMMON_NAME, "test.example.com"), + ] + ) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) + .sign(private_key, hashes.SHA256()) + ) + + # Convert to PEM + cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + # Should validate successfully + errors = _validate_pem_certificate(cert_pem) + assert len(errors) == 0 + + +def test_expired_certificate_validation(): + """Test validation of expired certificate.""" + # Generate an expired test certificate + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "expired.example.com"), + ] + ) + + # Create certificate that expired yesterday + yesterday = datetime.now(timezone.utc) - timedelta(days=1) + two_days_ago = yesterday - timedelta(days=1) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(two_days_ago) + .not_valid_after(yesterday) + .sign(private_key, hashes.SHA256()) + ) + + cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + errors = _validate_pem_certificate(cert_pem) + assert len(errors) == 1 + assert "Certificate has expired" in errors[0] + + +@pytest.mark.parametrize( + "subject", + [ + "", # Empty is allowed + "CN=example.com", + ( + "CN=server.example.com,OU=IT Department,O=Company Inc," + "L=City,ST=State,C=US" + ), + "CN=test,C=US", + "DC=com,DC=example,CN=server", + " CN = example.com , O = Test Org ", # Whitespace handling + ], +) +def test_subject_format_validation_valid(subject): + """Test validation of valid subject formats.""" + errors = _validate_certificate_subject_format(subject) + assert len(errors) == 0 + + +@pytest.mark.parametrize( + "subject,expected_error", + [ + ("INVALID=value", "Invalid X.509 DN format"), + ("CN test", "Invalid X.509 DN format"), + ("=value", "Invalid X.509 DN format"), + ("CN=Acme,Inc", "Invalid X.509 DN format"), + ("CN=test1,CN=test2", "Duplicate attributes not allowed"), + ("CN=test,C=USA", "Country name must be a 2 character country code"), + ], +) +def test_subject_format_validation_invalid(subject, expected_error): + """Test validation of invalid subject formats.""" + errors = _validate_certificate_subject_format(subject) + assert len(errors) >= 1 + assert expected_error in errors[0] + + +def test_subject_format_validation_mixed_case_attributes(): + """Test subject format validation with mixed case attributes.""" + # Mixed case should now be handled by X.509 standard parsing + errors = _validate_certificate_subject_format("CN=test,O=org") + assert len(errors) == 0 # Should be valid with X.509 parsing + + +def test_subject_format_validation_exception_handling(): + """Test subject format validation exception handling.""" + # This will test the except block in _validate_certificate_subject_format + # by providing a case that could cause an exception during parsing + errors = _validate_certificate_subject_format("CN=test,=") + assert len(errors) >= 1 + # Should contain parsing error + + +# New tests for X.509 subject matching functionality + + +def test_normalize_x509_dn_whitespace(): + """Test _normalize_x509_dn_whitespace function.""" + # Test basic whitespace normalization + result = _normalize_x509_dn_whitespace( + " CN = example.com , O = Test Org " + ) + assert result == "CN=example.com,O=Test Org" + + # Test empty string + result = _normalize_x509_dn_whitespace("") + assert result == "" + + # Test component without equals + result = _normalize_x509_dn_whitespace("CN=test, invalid ") + assert result == "CN=test,invalid" + + # Test multiple components + result = _normalize_x509_dn_whitespace("CN=test,O=org,C=US") + assert result == "CN=test,O=org,C=US" + + +def test_validate_x509_subject_match_exact(): + """Test validate_x509_subject_match with exact matching.""" + # Exact match + assert ( + validate_x509_subject_match("CN=example.com", "CN=example.com") is True + ) + + # Case insensitive for most attributes + assert ( + validate_x509_subject_match("CN=Example.com", "CN=example.com") is True + ) + + # Order independence + assert ( + validate_x509_subject_match("CN=test,O=org", "O=org,CN=test") is True + ) + + # Whitespace normalization + assert ( + validate_x509_subject_match("CN=test, O=org", "CN=test,O=org") is True + ) + + # Missing attribute in actual + assert validate_x509_subject_match("CN=test,O=org", "CN=test") is False + + # Different values + assert validate_x509_subject_match("CN=test1", "CN=test2") is False + + # Empty/None inputs + assert validate_x509_subject_match("", "CN=test") is False + assert validate_x509_subject_match("CN=test", "") is False + assert validate_x509_subject_match("", "") is False + + +def test_validate_x509_subject_match_wildcards(): + """Test validate_x509_subject_match with regex wildcard patterns.""" + # Basic wildcard (using .* regex pattern) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com", "CN=server.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com", "CN=api.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com", "CN=other.domain.com" + ) + is False + ) + + # Multiple wildcards + assert ( + validate_x509_subject_match("CN=.*\\..*", "CN=server.example.com") + is True + ) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com,O=R.*", + "O=R&D Development,CN=server.example.com", + ) + is True + ) + + # Wildcard with other attributes + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com,O=Test", "CN=server.example.com,O=Test" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=(agent1|agent2)\\.example\\.com", "CN=agent1.example.com" + ) + is True + ) + + # Wildcard mismatch + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com,O=Test", "CN=server.example.com,O=Other" + ) + is False + ) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com,O=Test", "CN=server.agent.com,O=Test" + ) + is False + ) + + +def test_validate_x509_subject_match_country_codes(): + """Test validate_x509_subject_match with country code handling.""" + # Country codes should be case insensitive + assert validate_x509_subject_match("C=us", "C=US") is True + assert validate_x509_subject_match("C=US", "C=us") is True + + # With other attributes + assert validate_x509_subject_match("CN=test,C=us", "CN=test,C=US") is True + + +def test_validate_x509_subject_match_domain_components(): + """Test validate_x509_subject_match with domain components.""" + # Multiple DC attributes should work + assert ( + validate_x509_subject_match( + "DC=com,DC=example,CN=server", "DC=com,DC=example,CN=server" + ) + is True + ) + + # Order should not matter for DC + assert ( + validate_x509_subject_match( + "CN=server,DC=example,DC=com", "DC=com,DC=example,CN=server" + ) + is True + ) + + +def test_validate_x509_subject_match_invalid_dn(): + """Test validate_x509_subject_match with invalid DN strings.""" + # Invalid DN format should return False + assert validate_x509_subject_match("INVALID=test", "CN=test") is False + assert validate_x509_subject_match("CN=test", "INVALID=test") is False + + # Malformed DN + assert validate_x509_subject_match("CN test", "CN=test") is False + + +def test_validate_x509_subject_match_edge_cases(): + """Test validate_x509_subject_match edge cases.""" + # Test with special characters that need escaping in regex + assert ( + validate_x509_subject_match( + "CN=test.example.com", "CN=test.example.com" + ) + is True + ) + + # Test with valid DN special characters + assert ( + validate_x509_subject_match( + "CN=test-server.example.com", "CN=test-server.example.com" + ) + is True + ) + + # Test very long DN strings + long_dn = ( + "CN=very-long-hostname-that-exceeds-normal-length.example.com," + "OU=Very Long Organizational Unit Name," + "O=Very Long Organization Name Inc," + "L=Very Long City Name,ST=Very Long State Name,C=US" + ) + assert validate_x509_subject_match(long_dn, long_dn) is True + + +# New tests for regex pattern matching functionality + + +def test_validate_x509_subject_match_regex_patterns(): + """Test validate_x509_subject_match with regex patterns.""" + # Character class patterns + assert ( + validate_x509_subject_match( + "CN=agent[1-9]\\.example\\.com", "CN=agent5.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=agent[1-9]\\.example\\.com", "CN=agent0.example.com" + ) + is False + ) + assert ( + validate_x509_subject_match( + "CN=server[a-z]\\.domain\\.com", "CN=serverx.domain.com" + ) + is True + ) + + # Alternation patterns + assert ( + validate_x509_subject_match( + "CN=(agent1|agent2)\\.example\\.com", "CN=agent1.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=(agent1|agent2)\\.example\\.com", "CN=agent2.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=(agent1|agent2)\\.example\\.com", "CN=agent3.example.com" + ) + is False + ) + + # Plus quantifier patterns + assert ( + validate_x509_subject_match( + "CN=server-[0-9]+\\.domain\\.com", "CN=server-123.domain.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=server-[0-9]+\\.domain\\.com", "CN=server-1.domain.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=server-[0-9]+\\.domain\\.com", "CN=server-.domain.com" + ) + is False + ) + + # Dot wildcard patterns + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com", "CN=anything.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com", "CN=multi.part.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=.*\\.example\\.com", "CN=test.other.com" + ) + is False + ) + + +def test_validate_x509_subject_match_regex_complex_patterns(): + """Test validate_x509_subject_match with complex regex patterns.""" + # Mixed patterns with multiple attributes + assert ( + validate_x509_subject_match( + "CN=app[1-3]\\.example\\.com,O=(Dev|Test|Prod)", + "CN=app2.example.com,O=Dev", + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=app[1-3]\\.example\\.com,O=(Dev|Test|Prod)", + "CN=app2.example.com,O=Staging", + ) + is False + ) + + # Optional groups + assert ( + validate_x509_subject_match( + "CN=server(-[0-9]+)?\\.domain\\.com", "CN=server.domain.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=server(-[0-9]+)?\\.domain\\.com", "CN=server-123.domain.com" + ) + is True + ) + + # Word boundaries + assert ( + validate_x509_subject_match( + "CN=\\btest\\b\\.example\\.com", "CN=test.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=\\btest\\b\\.example\\.com", "CN=testing.example.com" + ) + is False + ) + + # Case sensitivity with regex + assert ( + validate_x509_subject_match( + "CN=SERVER[0-9]\\.DOMAIN\\.COM", "CN=server5.domain.com" + ) + is True + ) + + +def test_validate_x509_subject_match_regex_special_characters(): + """Test validate_x509_subject_match with special regex characters.""" + # Escaped special characters + assert ( + validate_x509_subject_match( + "CN=test\\$server\\.example\\.com", "CN=test$server.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=server-api\\.example\\.com", "CN=server-api.example.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=app\\(v2\\)\\.example\\.com", "CN=app(v2).example.com" + ) + is True + ) + + # Brackets in actual certificate subject (literal matching) + assert ( + validate_x509_subject_match( + "CN=\\[agent1\\|agent2\\]\\.example\\.com", + "CN=[agent1|agent2].example.com", + ) + is True + ) + + # Character classes with special chars + assert ( + validate_x509_subject_match( + "CN=test[._-]server\\.com", "CN=test.server.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=test[._-]server\\.com", "CN=test_server.com" + ) + is True + ) + assert ( + validate_x509_subject_match( + "CN=test[._-]server\\.com", "CN=test-server.com" + ) + is True + ) + + +def test_validate_x509_subject_match_regex_anchoring(): + """Test that regex patterns are properly anchored.""" + # Should not match substring + assert ( + validate_x509_subject_match("CN=server", "CN=test-server.example.com") + is False + ) + + # Should match full string + assert ( + validate_x509_subject_match( + "CN=.*server.*", "CN=test-server.example.com" + ) + is True + ) + + # Ensure anchoring works with complex patterns + assert ( + validate_x509_subject_match("CN=server[0-9]", "CN=myserver5.com") + is False + ) + + +@pytest.mark.parametrize( + "pattern,subject,expected", + [ + # Basic regex patterns + ("CN=server[0-9]\\.com", "CN=server5.com", True), + ("CN=server[0-9]\\.com", "CN=serverX.com", False), + ("CN=(dev|prod)\\.example\\.com", "CN=dev.example.com", True), + ("CN=(dev|prod)\\.example\\.com", "CN=test.example.com", False), + # Quantifiers + ("CN=app[0-9]+\\.domain\\.com", "CN=app123.domain.com", True), + ("CN=app[0-9]+\\.domain\\.com", "CN=app.domain.com", False), + ("CN=test[a-z]?\\.com", "CN=test.com", True), + ("CN=test[a-z]?\\.com", "CN=testa.com", True), + # Complex patterns + ("CN=.*-v[0-9]+\\.example\\.com", "CN=api-v2.example.com", True), + ("CN=.*-v[0-9]+\\.example\\.com", "CN=api-v.example.com", False), + ("CN=\\w+\\.internal\\.com", "CN=service123.internal.com", True), + ("CN=\\w+\\.internal\\.com", "CN=service-123.internal.com", False), + # Multi-attribute patterns + ("CN=app[1-3]\\.com,O=Test.*", "CN=app2.com,O=Test Org", True), + ("CN=app[1-3]\\.com,O=Test.*", "CN=app2.com,O=Prod Org", False), + # Case insensitivity + ("CN=SERVER[0-9]\\.COM", "CN=server5.com", True), + ("CN=server[0-9]\\.com", "CN=SERVER5.COM", True), + # Special characters + ("CN=test\\$[0-9]\\.com", "CN=test$5.com", True), + ("CN=test-api\\.com", "CN=test-api.com", True), + ("CN=\\[prod\\]\\.com", "CN=[prod].com", True), + ], +) +def test_validate_x509_subject_match_regex_parametrized( + pattern, subject, expected +): + """Parametrized test for various regex patterns.""" + result = validate_x509_subject_match(pattern, subject) + assert result is expected + + +def test_validate_x509_subject_match_performance_edge_cases(): + """Test performance and edge cases with regex patterns.""" + # Very long patterns + long_pattern = "CN=" + "a" * 1000 + "[0-9]\\.com" + long_subject = "CN=" + "a" * 1000 + "5.com" + assert validate_x509_subject_match(long_pattern, long_subject) is True + + # Pattern with many alternations + many_alt_pattern = ( + "CN=(" + "|".join([f"server{i}" for i in range(100)]) + ")\\.com" + ) + assert ( + validate_x509_subject_match(many_alt_pattern, "CN=server50.com") + is True + ) + assert ( + validate_x509_subject_match(many_alt_pattern, "CN=server999.com") + is False + ) + + # Empty regex groups + assert ( + validate_x509_subject_match("CN=test()\\.com", "CN=test.com") is True + ) diff --git a/tests/unit/test_mtls_authentication.py b/tests/unit/test_mtls_authentication.py new file mode 100644 index 000000000..538812fe5 --- /dev/null +++ b/tests/unit/test_mtls_authentication.py @@ -0,0 +1,240 @@ +"""Unit tests for mTLS authentication functionality.""" + +import pytest +from rest_framework.exceptions import AuthenticationFailed + +from aap_eda.api.event_stream_authentication import MTLSAuthentication + + +def test_mtls_authentication_valid_exact_match(): + """Test mTLS authentication with exact subject match.""" + auth = MTLSAuthentication( + subject="CN=server.example.com,O=Test Org,C=US", + value="CN=server.example.com,O=Test Org,C=US", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_valid_case_insensitive(): + """Test mTLS authentication with case insensitive match.""" + auth = MTLSAuthentication( + subject="CN=Server.Example.Com,O=Test Org,C=us", + value="CN=server.example.com,O=Test Org,C=US", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_valid_order_independent(): + """Test mTLS authentication with different attribute order.""" + auth = MTLSAuthentication( + subject="CN=server.example.com,O=Test Org,C=US", + value="O=Test Org,CN=server.example.com,C=US", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_valid_whitespace_normalization(): + """Test mTLS authentication with whitespace differences.""" + auth = MTLSAuthentication( + subject="CN=server.example.com, O=Test Org, C=US", + value="CN=server.example.com,O=Test Org,C=US", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_valid_wildcard_match(): + """Test mTLS authentication with wildcard pattern.""" + auth = MTLSAuthentication( + subject="CN=.*\\.example\\.com,O=Test Org", + value="CN=server.example.com,O=Test Org", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_valid_multiple_wildcards(): + """Test mTLS authentication with multiple wildcard patterns.""" + auth = MTLSAuthentication( + subject="CN=.*\\..*\\.com,O=.*", value="CN=api.server.com,O=TestOrg" + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_invalid_subject_mismatch(): + """Test mTLS authentication with subject mismatch.""" + auth = MTLSAuthentication( + subject="CN=server.example.com,O=Test Org", + value="CN=different.example.com,O=Test Org", + ) + + with pytest.raises(AuthenticationFailed) as exc_info: + auth.authenticate() + + assert "does not match" in str(exc_info.value) + assert "CN=different.example.com,O=Test Org" in str(exc_info.value) + + +def test_mtls_authentication_invalid_missing_attribute(): + """Test mTLS authentication with missing required attribute.""" + auth = MTLSAuthentication( + subject="CN=server.example.com,O=Test Org", + value="CN=server.example.com", + ) + + with pytest.raises(AuthenticationFailed) as exc_info: + auth.authenticate() + + assert "does not match" in str(exc_info.value) + + +def test_mtls_authentication_invalid_wildcard_no_match(): + """Test mTLS authentication with wildcard that doesn't match.""" + auth = MTLSAuthentication( + subject="CN=.*\\.example\\.com", value="CN=server.different.com" + ) + + with pytest.raises(AuthenticationFailed) as exc_info: + auth.authenticate() + + assert "does not match" in str(exc_info.value) + + +def test_mtls_authentication_invalid_dn_format(): + """Test mTLS authentication with invalid DN format.""" + auth = MTLSAuthentication( + subject="INVALID=test", value="CN=server.example.com" + ) + + with pytest.raises(AuthenticationFailed) as exc_info: + auth.authenticate() + + assert "does not match" in str(exc_info.value) + + +def test_mtls_authentication_empty_subject_allowed(): + """Test mTLS authentication with empty subject (should pass).""" + auth = MTLSAuthentication(subject="", value="CN=server.example.com") + + # Should not raise exception when subject is empty + auth.authenticate() + + +def test_mtls_authentication_none_subject_allowed(): + """Test mTLS authentication with None subject (should pass).""" + auth = MTLSAuthentication(subject=None, value="CN=server.example.com") + + # Should not raise exception when subject is None + auth.authenticate() + + +def test_validate_subject_method_direct(): + """Test validate_subject method directly.""" + auth = MTLSAuthentication(subject="CN=test", value="CN=test") + + # Test valid cases + assert auth.validate_subject("CN=test", "CN=test") is True + assert auth.validate_subject("CN=Test", "CN=test") is True + assert ( + auth.validate_subject("CN=.*\\.example\\.com", "CN=api.example.com") + is True + ) + + # Test invalid cases + assert auth.validate_subject("CN=test1", "CN=test2") is False + assert auth.validate_subject("CN=test,O=org", "CN=test") is False + assert auth.validate_subject("", "CN=test") is False + + +def test_mtls_authentication_domain_components(): + """Test mTLS authentication with domain components.""" + auth = MTLSAuthentication( + subject="DC=com,DC=example,CN=server", + value="DC=com,DC=example,CN=server", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_domain_components_order_independence(): + """Test mTLS authentication with domain components in different order.""" + auth = MTLSAuthentication( + subject="CN=server,DC=example,DC=com", + value="DC=com,DC=example,CN=server", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_country_code_case_insensitive(): + """Test mTLS authentication with country code case differences.""" + auth = MTLSAuthentication(subject="CN=server,C=us", value="CN=server,C=US") + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_organization_case_insensitive(): + """Test mTLS authentication with organization case differences.""" + auth = MTLSAuthentication( + subject="CN=server,O=Test Organization", + value="CN=server,O=test organization", + ) + + # Should not raise exception (organization names are case-insensitive) + auth.authenticate() + + +def test_mtls_authentication_complex_dn_with_wildcards(): + """Test mTLS authentication with complex DN and wildcards.""" + subject = ( + "CN=.*\\.example\\.com,OU=IT Department,O=Company Inc," + "L=.*,ST=California,C=US" + ) + value = ( + "CN=api.example.com,OU=IT Department,O=Company Inc," + "L=San Francisco,ST=California,C=US" + ) + auth = MTLSAuthentication(subject=subject, value=value) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_edge_case_special_characters(): + """Test mTLS authentication with valid special characters in DN.""" + auth = MTLSAuthentication( + subject="CN=test-server.example.com,O=Test Co", + value="CN=test-server.example.com,O=Test Co", + ) + + # Should not raise exception + auth.authenticate() + + +def test_mtls_authentication_performance_long_dn(): + """Test mTLS authentication with very long DN strings.""" + long_subject = ( + "CN=very-long-hostname-that-exceeds-normal-length.example.com," + "OU=Very Long Organizational Unit Name That Contains Many Words," + "O=Very Long Organization Name Incorporated With Extra Words," + "L=Very Long City Name With Multiple Words," + "ST=Very Long State Or Province Name,C=US" + ) + + auth = MTLSAuthentication(subject=long_subject, value=long_subject) + + # Should not raise exception and should complete quickly + auth.authenticate()