diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index c1a23625b..92d0faa88 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -1018,7 +1018,8 @@ async def admin_list_gateways( ... updated_at=datetime.now(timezone.utc), ... is_active=True, ... auth_type=None, auth_username=None, auth_password=None, auth_token=None, - ... auth_header_key=None, auth_header_value=None + ... auth_header_key=None, auth_header_value=None, + ... slug="test-gateway" ... ) >>> >>> # Mock the gateway_service.list_gateways method @@ -1039,7 +1040,8 @@ async def admin_list_gateways( ... description="Another test", transport="HTTP", created_at=datetime.now(timezone.utc), ... updated_at=datetime.now(timezone.utc), enabled=False, ... auth_type=None, auth_username=None, auth_password=None, auth_token=None, - ... auth_header_key=None, auth_header_value=None + ... auth_header_key=None, auth_header_value=None, + ... slug="test-gateway" ... ) >>> gateway_service.list_gateways = AsyncMock(return_value=[ ... mock_gateway, # Return the GatewayRead objects, not pre-dumped dicts @@ -2168,7 +2170,8 @@ async def admin_get_gateway(gateway_id: str, db: Session = Depends(get_db), user ... description="Gateway for getting", transport="HTTP", ... created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ... enabled=True, auth_type=None, auth_username=None, auth_password=None, - ... auth_token=None, auth_header_key=None, auth_header_value=None + ... auth_token=None, auth_header_key=None, auth_header_value=None, + ... slug="test-gateway" ... ) >>> >>> # Mock the gateway_service.get_gateway method diff --git a/mcpgateway/config.py b/mcpgateway/config.py index cf8cfb0ba..261e5e0a5 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -535,6 +535,9 @@ def validate_database(self) -> None: # Rate limiting validation_max_requests_per_minute: int = 60 + # Masking value for all sensitive data + masked_auth_value: str = "*****" + def extract_using_jq(data, jq_filter=""): """ diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 75a60d1f9..3f755b03a 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -1882,6 +1882,7 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]: Raises: ValueError: If auth type is invalid """ + auth_type = values.data.get("auth_type") if auth_type == "basic": @@ -2032,6 +2033,11 @@ def _populate_auth(cls, values: Self) -> Dict[str, Any]: """ auth_type = values.auth_type auth_value_encoded = values.auth_value + + # Skip validation logic if masked value + if auth_value_encoded == settings.masked_auth_value: + return values + auth_value = decode_auth(auth_value_encoded) if auth_type == "basic": auth = auth_value.get("Authorization") @@ -2056,6 +2062,40 @@ def _populate_auth(cls, values: Self) -> Dict[str, Any]: return values + def masked(self) -> "GatewayRead": + """ + Return a masked version of the model instance with sensitive authentication fields hidden. + + This method creates a dictionary representation of the model data and replaces sensitive fields + such as `auth_value`, `auth_password`, `auth_token`, and `auth_header_value` with a masked + placeholder value defined in `settings.masked_auth_value`. Masking is only applied if the fields + are present and not already masked. + + Args: + None + + Returns: + GatewayRead: A new instance of the GatewayRead model with sensitive authentication-related fields + masked to prevent exposure of sensitive information. + + Notes: + - The `auth_value` field is only masked if it exists and its value is different from the masking + placeholder. + - Other sensitive fields (`auth_password`, `auth_token`, `auth_header_value`) are masked if present. + - Fields not related to authentication remain unchanged. + """ + masked_data = self.model_dump() + + # Only mask if auth_value is present and not already masked + if masked_data.get("auth_value") and masked_data["auth_value"] != settings.masked_auth_value: + masked_data["auth_value"] = settings.masked_auth_value + + masked_data["auth_password"] = settings.masked_auth_value if masked_data.get("auth_password") else None + masked_data["auth_token"] = settings.masked_auth_value if masked_data.get("auth_token") else None + masked_data["auth_header_value"] = settings.masked_auth_value if masked_data.get("auth_header_value") else None + + return GatewayRead.model_validate(masked_data) + class FederatedTool(BaseModelWithConfigDict): """Schema for tools provided by federated gateways. diff --git a/mcpgateway/services/gateway_service.py b/mcpgateway/services/gateway_service.py index bfa2519b7..ccfe6ac01 100644 --- a/mcpgateway/services/gateway_service.py +++ b/mcpgateway/services/gateway_service.py @@ -386,7 +386,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway # Notify subscribers await self._notify_gateway_added(db_gateway) - return GatewayRead.model_validate(db_gateway) + return GatewayRead.model_validate(db_gateway).masked() except* GatewayConnectionError as ge: if TYPE_CHECKING: ge: ExceptionGroup[GatewayConnectionError] @@ -436,7 +436,9 @@ async def list_gateways(self, db: Session, include_inactive: bool = False) -> Li >>> db = MagicMock() >>> gateway_obj = MagicMock() >>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj] - >>> GatewayRead.model_validate = MagicMock(return_value='gateway_read') + >>> mocked_gateway_read = MagicMock() + >>> mocked_gateway_read.masked.return_value = 'gateway_read' + >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read) >>> import asyncio >>> result = asyncio.run(service.list_gateways(db)) >>> result == ['gateway_read'] @@ -459,7 +461,7 @@ async def list_gateways(self, db: Session, include_inactive: bool = False) -> Li query = query.where(DbGateway.enabled) gateways = db.execute(query).scalars().all() - return [GatewayRead.model_validate(g) for g in gateways] + return [GatewayRead.model_validate(g).masked() for g in gateways] async def update_gateway(self, db: Session, gateway_id: str, gateway_update: GatewayUpdate, include_inactive: bool = True) -> GatewayRead: """Update a gateway. @@ -510,9 +512,20 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat if getattr(gateway, "auth_type", None) is not None: gateway.auth_type = gateway_update.auth_type + # If auth_type is empty, update the auth_value too + if gateway_update.auth_type == "": + gateway.auth_value = "" + # if auth_type is not None and only then check auth_value - if getattr(gateway, "auth_value", {}) != {}: - gateway.auth_value = gateway_update.auth_value + if getattr(gateway, "auth_value", "") != "": + token = gateway_update.auth_token + password = gateway_update.auth_password + header_value = gateway_update.auth_header_value + + if settings.masked_auth_value not in (token, password, header_value): + # Check if values differ from existing ones + if gateway.auth_value != gateway_update.auth_value: + gateway.auth_value = gateway_update.auth_value # Try to reinitialize connection if URL changed if gateway_update.url is not None: @@ -557,7 +570,7 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat await self._notify_gateway_updated(gateway) logger.info(f"Updated gateway: {gateway.name}") - return GatewayRead.model_validate(gateway) + return GatewayRead.model_validate(gateway).masked() except Exception as e: db.rollback() @@ -578,7 +591,6 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool GatewayNotFoundError: If the gateway is not found Examples: - >>> from mcpgateway.services.gateway_service import GatewayService >>> from unittest.mock import MagicMock >>> from mcpgateway.schemas import GatewayRead >>> service = GatewayService() @@ -586,7 +598,9 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool >>> gateway_mock = MagicMock() >>> gateway_mock.enabled = True >>> db.get.return_value = gateway_mock - >>> GatewayRead.model_validate = MagicMock(return_value='gateway_read') + >>> mocked_gateway_read = MagicMock() + >>> mocked_gateway_read.masked.return_value = 'gateway_read' + >>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read) >>> import asyncio >>> result = asyncio.run(service.get_gateway(db, 'gateway_id')) >>> result == 'gateway_read' @@ -620,7 +634,7 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") if gateway.enabled or include_inactive: - return GatewayRead.model_validate(gateway) + return GatewayRead.model_validate(gateway).masked() raise GatewayNotFoundError(f"Gateway not found: {gateway_id}") @@ -708,7 +722,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}") - return GatewayRead.model_validate(gateway) + return GatewayRead.model_validate(gateway).masked() except Exception as e: db.rollback() diff --git a/tests/unit/mcpgateway/services/test_gateway_service.py b/tests/unit/mcpgateway/services/test_gateway_service.py index fa20e50b9..0ba91b22f 100644 --- a/tests/unit/mcpgateway/services/test_gateway_service.py +++ b/tests/unit/mcpgateway/services/test_gateway_service.py @@ -17,7 +17,7 @@ # Standard from datetime import datetime, timezone -from unittest.mock import AsyncMock, MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock, patch # Third-Party import pytest @@ -146,7 +146,7 @@ class TestGatewayService: # ──────────────────────────────────────────────────────────────────── @pytest.mark.asyncio - async def test_register_gateway(self, gateway_service, test_db): + async def test_register_gateway(self, gateway_service, test_db, monkeypatch): """Successful gateway registration populates DB and returns data.""" # DB: no gateway with that name; no existing tools found test_db.execute = Mock( @@ -172,6 +172,18 @@ async def test_register_gateway(self, gateway_service, test_db): ) gateway_service._notify_gateway_added = AsyncMock() + # Patch GatewayRead.model_validate to return a mock with .masked() + mock_model = Mock() + mock_model.masked.return_value = mock_model + mock_model.name = "test_gateway" + mock_model.url = "http://example.com/gateway" + mock_model.description = "A test gateway" + + monkeypatch.setattr( + "mcpgateway.services.gateway_service.GatewayRead.model_validate", + lambda x: mock_model, + ) + gateway_create = GatewayCreate( name="test_gateway", url="http://example.com/gateway", @@ -236,10 +248,18 @@ async def test_register_gateway_connection_error(self, gateway_service, test_db) # ──────────────────────────────────────────────────────────────────── @pytest.mark.asyncio - async def test_list_gateways(self, gateway_service, mock_gateway, test_db): + async def test_list_gateways(self, gateway_service, mock_gateway, test_db, monkeypatch): """Listing gateways returns the active ones.""" + test_db.execute = Mock(return_value=_make_execute_result(scalars_list=[mock_gateway])) + mock_model = Mock() + mock_model.masked.return_value = mock_model + mock_model.name = "test_gateway" + + # Patch using full path string to GatewayRead.model_validate + monkeypatch.setattr("mcpgateway.services.gateway_service.GatewayRead.model_validate", lambda x: mock_model) + result = await gateway_service.list_gateways(test_db) test_db.execute.assert_called_once() @@ -249,6 +269,7 @@ async def test_list_gateways(self, gateway_service, mock_gateway, test_db): @pytest.mark.asyncio async def test_get_gateway(self, gateway_service, mock_gateway, test_db): """Gateway is fetched and returned by ID.""" + mock_gateway.masked = Mock(return_value=mock_gateway) test_db.get = Mock(return_value=mock_gateway) result = await gateway_service.get_gateway(test_db, 1) test_db.get.assert_called_once_with(DbGateway, 1) @@ -266,14 +287,24 @@ async def test_get_gateway_not_found(self, gateway_service, test_db): async def test_get_gateway_inactive(self, gateway_service, mock_gateway, test_db): """Inactive gateway is not returned unless explicitly asked for.""" mock_gateway.enabled = False + mock_gateway.id = 1 test_db.get = Mock(return_value=mock_gateway) - result = await gateway_service.get_gateway(test_db, 1, include_inactive=True) - assert result.id == 1 - assert result.enabled == False - test_db.get.reset_mock() - test_db.get = Mock(return_value=mock_gateway) - with pytest.raises(GatewayNotFoundError): - result = await gateway_service.get_gateway(test_db, 1, include_inactive=False) + + # Create a mock for GatewayRead with a masked method + mock_gateway_read = Mock() + mock_gateway_read.id = 1 + mock_gateway_read.enabled = False + mock_gateway_read.masked = Mock(return_value=mock_gateway_read) + + with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read): + result = await gateway_service.get_gateway(test_db, 1, include_inactive=True) + assert result.id == 1 + assert result.enabled == False + + # Now test the inactive = False path + test_db.get = Mock(return_value=mock_gateway) + with pytest.raises(GatewayNotFoundError): + await gateway_service.get_gateway(test_db, 1, include_inactive=False) # ──────────────────────────────────────────────────────────────────── # UPDATE @@ -288,22 +319,36 @@ async def test_update_gateway(self, gateway_service, mock_gateway, test_db): test_db.commit = Mock() test_db.refresh = Mock() + # Simulate successful gateway initialization gateway_service._initialize_gateway = AsyncMock( return_value=( - {"prompts": {"subscribe": True}, "resources": {"subscribe": True}, "tools": {"subscribe": True}}, + { + "prompts": {"subscribe": True}, + "resources": {"subscribe": True}, + "tools": {"subscribe": True}, + }, [], ) ) gateway_service._notify_gateway_updated = AsyncMock() + # Create the update payload gateway_update = GatewayUpdate( name="updated_gateway", url="http://example.com/updated", description="Updated description", ) - result = await gateway_service.update_gateway(test_db, 1, gateway_update) + # Create mock return for GatewayRead.model_validate().masked() + mock_gateway_read = MagicMock() + mock_gateway_read.name = "updated_gateway" + mock_gateway_read.masked.return_value = mock_gateway_read # Ensure .masked() returns the same object + + # Patch the model_validate call in the service + with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read): + result = await gateway_service.update_gateway(test_db, 1, gateway_update) + # Assertions test_db.commit.assert_called_once() test_db.refresh.assert_called_once() gateway_service._initialize_gateway.assert_called_once() @@ -354,6 +399,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d query_proxy.filter.return_value = filter_proxy test_db.query = Mock(return_value=query_proxy) + # Setup gateway service mocks gateway_service._notify_gateway_activated = AsyncMock() gateway_service._notify_gateway_deactivated = AsyncMock() gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, [])) @@ -362,12 +408,17 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d tool_service_stub.toggle_tool_status = AsyncMock() gateway_service.tool_service = tool_service_stub - result = await gateway_service.toggle_gateway_status(test_db, 1, activate=False) + # Patch model_validate to return a mock with .masked() + mock_gateway_read = MagicMock() + mock_gateway_read.masked.return_value = mock_gateway_read + + with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read): + result = await gateway_service.toggle_gateway_status(test_db, 1, activate=False) assert mock_gateway.enabled is False gateway_service._notify_gateway_deactivated.assert_called_once() assert tool_service_stub.toggle_tool_status.called - assert result.enabled is False + assert result == mock_gateway_read # ──────────────────────────────────────────────────────────────────── # DELETE diff --git a/tests/unit/mcpgateway/test_admin.py b/tests/unit/mcpgateway/test_admin.py index 6bf00f968..21e84e3c3 100644 --- a/tests/unit/mcpgateway/test_admin.py +++ b/tests/unit/mcpgateway/test_admin.py @@ -746,7 +746,8 @@ async def test_admin_list_gateways_with_auth_info(self, mock_list_gateways, mock "transport": "HTTP", "enabled": True, "auth_type": "bearer", - "auth_token": "hidden", # Should be masked + "auth_token": "Bearer hidden", # Should be masked + "auth_value": "Some value", } mock_list_gateways.return_value = [mock_gateway] @@ -764,6 +765,8 @@ async def test_admin_get_gateway_all_transports(self, mock_get_gateway, mock_db) mock_gateway.model_dump.return_value = { "id": f"gateway-{transport}", "transport": transport, + "name": f"Gateway {transport}", # Add this field + "url": f"https://gateway-{transport}.com", # Add this field } mock_get_gateway.return_value = mock_gateway