Skip to content
Merged
9 changes: 6 additions & 3 deletions mcpgateway/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,8 @@ async def admin_list_gateways(
... updated_at=datetime.now(timezone.utc),
... is_active=True,
... auth_type=None, auth_username=None, auth_password=None, auth_token=None,
... auth_header_key=None, auth_header_value=None
... auth_header_key=None, auth_header_value=None,
... slug="test-gateway"
... )
>>>
>>> # Mock the gateway_service.list_gateways method
Expand All @@ -1039,7 +1040,8 @@ async def admin_list_gateways(
... description="Another test", transport="HTTP", created_at=datetime.now(timezone.utc),
... updated_at=datetime.now(timezone.utc), enabled=False,
... auth_type=None, auth_username=None, auth_password=None, auth_token=None,
... auth_header_key=None, auth_header_value=None
... auth_header_key=None, auth_header_value=None,
... slug="test-gateway"
... )
>>> gateway_service.list_gateways = AsyncMock(return_value=[
... mock_gateway, # Return the GatewayRead objects, not pre-dumped dicts
Expand Down Expand Up @@ -2168,7 +2170,8 @@ async def admin_get_gateway(gateway_id: str, db: Session = Depends(get_db), user
... description="Gateway for getting", transport="HTTP",
... created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
... enabled=True, auth_type=None, auth_username=None, auth_password=None,
... auth_token=None, auth_header_key=None, auth_header_value=None
... auth_token=None, auth_header_key=None, auth_header_value=None,
... slug="test-gateway"
... )
>>>
>>> # Mock the gateway_service.get_gateway method
Expand Down
3 changes: 3 additions & 0 deletions mcpgateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,9 @@ def validate_database(self) -> None:
# Rate limiting
validation_max_requests_per_minute: int = 60

# Masking value for all sensitive data
masked_auth_value: str = "*****"


def extract_using_jq(data, jq_filter=""):
"""
Expand Down
40 changes: 40 additions & 0 deletions mcpgateway/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1882,6 +1882,7 @@ def _process_auth_fields(values: Dict[str, Any]) -> Optional[Dict[str, Any]]:
Raises:
ValueError: If auth type is invalid
"""

auth_type = values.data.get("auth_type")

if auth_type == "basic":
Expand Down Expand Up @@ -2032,6 +2033,11 @@ def _populate_auth(cls, values: Self) -> Dict[str, Any]:
"""
auth_type = values.auth_type
auth_value_encoded = values.auth_value

# Skip validation logic if masked value
if auth_value_encoded == settings.masked_auth_value:
return values

auth_value = decode_auth(auth_value_encoded)
if auth_type == "basic":
auth = auth_value.get("Authorization")
Expand All @@ -2056,6 +2062,40 @@ def _populate_auth(cls, values: Self) -> Dict[str, Any]:

return values

def masked(self) -> "GatewayRead":
"""
Return a masked version of the model instance with sensitive authentication fields hidden.

This method creates a dictionary representation of the model data and replaces sensitive fields
such as `auth_value`, `auth_password`, `auth_token`, and `auth_header_value` with a masked
placeholder value defined in `settings.masked_auth_value`. Masking is only applied if the fields
are present and not already masked.

Args:
None

Returns:
GatewayRead: A new instance of the GatewayRead model with sensitive authentication-related fields
masked to prevent exposure of sensitive information.

Notes:
- The `auth_value` field is only masked if it exists and its value is different from the masking
placeholder.
- Other sensitive fields (`auth_password`, `auth_token`, `auth_header_value`) are masked if present.
- Fields not related to authentication remain unchanged.
"""
masked_data = self.model_dump()

# Only mask if auth_value is present and not already masked
if masked_data.get("auth_value") and masked_data["auth_value"] != settings.masked_auth_value:
masked_data["auth_value"] = settings.masked_auth_value

masked_data["auth_password"] = settings.masked_auth_value if masked_data.get("auth_password") else None
masked_data["auth_token"] = settings.masked_auth_value if masked_data.get("auth_token") else None
masked_data["auth_header_value"] = settings.masked_auth_value if masked_data.get("auth_header_value") else None

return GatewayRead.model_validate(masked_data)


class FederatedTool(BaseModelWithConfigDict):
"""Schema for tools provided by federated gateways.
Expand Down
34 changes: 24 additions & 10 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
# Notify subscribers
await self._notify_gateway_added(db_gateway)

return GatewayRead.model_validate(db_gateway)
return GatewayRead.model_validate(db_gateway).masked()
except* GatewayConnectionError as ge:
if TYPE_CHECKING:
ge: ExceptionGroup[GatewayConnectionError]
Expand Down Expand Up @@ -436,7 +436,9 @@ async def list_gateways(self, db: Session, include_inactive: bool = False) -> Li
>>> db = MagicMock()
>>> gateway_obj = MagicMock()
>>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
>>> GatewayRead.model_validate = MagicMock(return_value='gateway_read')
>>> mocked_gateway_read = MagicMock()
>>> mocked_gateway_read.masked.return_value = 'gateway_read'
>>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
>>> import asyncio
>>> result = asyncio.run(service.list_gateways(db))
>>> result == ['gateway_read']
Expand All @@ -459,7 +461,7 @@ async def list_gateways(self, db: Session, include_inactive: bool = False) -> Li
query = query.where(DbGateway.enabled)

gateways = db.execute(query).scalars().all()
return [GatewayRead.model_validate(g) for g in gateways]
return [GatewayRead.model_validate(g).masked() for g in gateways]

async def update_gateway(self, db: Session, gateway_id: str, gateway_update: GatewayUpdate, include_inactive: bool = True) -> GatewayRead:
"""Update a gateway.
Expand Down Expand Up @@ -510,9 +512,20 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat
if getattr(gateway, "auth_type", None) is not None:
gateway.auth_type = gateway_update.auth_type

# If auth_type is empty, update the auth_value too
if gateway_update.auth_type == "":
gateway.auth_value = ""

# if auth_type is not None and only then check auth_value
if getattr(gateway, "auth_value", {}) != {}:
gateway.auth_value = gateway_update.auth_value
if getattr(gateway, "auth_value", "") != "":
token = gateway_update.auth_token
password = gateway_update.auth_password
header_value = gateway_update.auth_header_value

if settings.masked_auth_value not in (token, password, header_value):
# Check if values differ from existing ones
if gateway.auth_value != gateway_update.auth_value:
gateway.auth_value = gateway_update.auth_value

# Try to reinitialize connection if URL changed
if gateway_update.url is not None:
Expand Down Expand Up @@ -557,7 +570,7 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat
await self._notify_gateway_updated(gateway)

logger.info(f"Updated gateway: {gateway.name}")
return GatewayRead.model_validate(gateway)
return GatewayRead.model_validate(gateway).masked()

except Exception as e:
db.rollback()
Expand All @@ -578,15 +591,16 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool
GatewayNotFoundError: If the gateway is not found

Examples:
>>> from mcpgateway.services.gateway_service import GatewayService
>>> from unittest.mock import MagicMock
>>> from mcpgateway.schemas import GatewayRead
>>> service = GatewayService()
>>> db = MagicMock()
>>> gateway_mock = MagicMock()
>>> gateway_mock.enabled = True
>>> db.get.return_value = gateway_mock
>>> GatewayRead.model_validate = MagicMock(return_value='gateway_read')
>>> mocked_gateway_read = MagicMock()
>>> mocked_gateway_read.masked.return_value = 'gateway_read'
>>> GatewayRead.model_validate = MagicMock(return_value=mocked_gateway_read)
>>> import asyncio
>>> result = asyncio.run(service.get_gateway(db, 'gateway_id'))
>>> result == 'gateway_read'
Expand Down Expand Up @@ -620,7 +634,7 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool
raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

if gateway.enabled or include_inactive:
return GatewayRead.model_validate(gateway)
return GatewayRead.model_validate(gateway).masked()

raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")

Expand Down Expand Up @@ -708,7 +722,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bo

logger.info(f"Gateway status: {gateway.name} - {'enabled' if activate else 'disabled'} and {'accessible' if reachable else 'inaccessible'}")

return GatewayRead.model_validate(gateway)
return GatewayRead.model_validate(gateway).masked()

except Exception as e:
db.rollback()
Expand Down
79 changes: 65 additions & 14 deletions tests/unit/mcpgateway/services/test_gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# Standard
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch

# Third-Party
import pytest
Expand Down Expand Up @@ -146,7 +146,7 @@ class TestGatewayService:
# ────────────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_register_gateway(self, gateway_service, test_db):
async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
"""Successful gateway registration populates DB and returns data."""
# DB: no gateway with that name; no existing tools found
test_db.execute = Mock(
Expand All @@ -172,6 +172,18 @@ async def test_register_gateway(self, gateway_service, test_db):
)
gateway_service._notify_gateway_added = AsyncMock()

# Patch GatewayRead.model_validate to return a mock with .masked()
mock_model = Mock()
mock_model.masked.return_value = mock_model
mock_model.name = "test_gateway"
mock_model.url = "http://example.com/gateway"
mock_model.description = "A test gateway"

monkeypatch.setattr(
"mcpgateway.services.gateway_service.GatewayRead.model_validate",
lambda x: mock_model,
)

gateway_create = GatewayCreate(
name="test_gateway",
url="http://example.com/gateway",
Expand Down Expand Up @@ -236,10 +248,18 @@ async def test_register_gateway_connection_error(self, gateway_service, test_db)
# ────────────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_list_gateways(self, gateway_service, mock_gateway, test_db):
async def test_list_gateways(self, gateway_service, mock_gateway, test_db, monkeypatch):
"""Listing gateways returns the active ones."""

test_db.execute = Mock(return_value=_make_execute_result(scalars_list=[mock_gateway]))

mock_model = Mock()
mock_model.masked.return_value = mock_model
mock_model.name = "test_gateway"

# Patch using full path string to GatewayRead.model_validate
monkeypatch.setattr("mcpgateway.services.gateway_service.GatewayRead.model_validate", lambda x: mock_model)

result = await gateway_service.list_gateways(test_db)

test_db.execute.assert_called_once()
Expand All @@ -249,6 +269,7 @@ async def test_list_gateways(self, gateway_service, mock_gateway, test_db):
@pytest.mark.asyncio
async def test_get_gateway(self, gateway_service, mock_gateway, test_db):
"""Gateway is fetched and returned by ID."""
mock_gateway.masked = Mock(return_value=mock_gateway)
test_db.get = Mock(return_value=mock_gateway)
result = await gateway_service.get_gateway(test_db, 1)
test_db.get.assert_called_once_with(DbGateway, 1)
Expand All @@ -266,14 +287,24 @@ async def test_get_gateway_not_found(self, gateway_service, test_db):
async def test_get_gateway_inactive(self, gateway_service, mock_gateway, test_db):
"""Inactive gateway is not returned unless explicitly asked for."""
mock_gateway.enabled = False
mock_gateway.id = 1
test_db.get = Mock(return_value=mock_gateway)
result = await gateway_service.get_gateway(test_db, 1, include_inactive=True)
assert result.id == 1
assert result.enabled == False
test_db.get.reset_mock()
test_db.get = Mock(return_value=mock_gateway)
with pytest.raises(GatewayNotFoundError):
result = await gateway_service.get_gateway(test_db, 1, include_inactive=False)

# Create a mock for GatewayRead with a masked method
mock_gateway_read = Mock()
mock_gateway_read.id = 1
mock_gateway_read.enabled = False
mock_gateway_read.masked = Mock(return_value=mock_gateway_read)

with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
result = await gateway_service.get_gateway(test_db, 1, include_inactive=True)
assert result.id == 1
assert result.enabled == False

# Now test the inactive = False path
test_db.get = Mock(return_value=mock_gateway)
with pytest.raises(GatewayNotFoundError):
await gateway_service.get_gateway(test_db, 1, include_inactive=False)

# ────────────────────────────────────────────────────────────────────
# UPDATE
Expand All @@ -288,22 +319,36 @@ async def test_update_gateway(self, gateway_service, mock_gateway, test_db):
test_db.commit = Mock()
test_db.refresh = Mock()

# Simulate successful gateway initialization
gateway_service._initialize_gateway = AsyncMock(
return_value=(
{"prompts": {"subscribe": True}, "resources": {"subscribe": True}, "tools": {"subscribe": True}},
{
"prompts": {"subscribe": True},
"resources": {"subscribe": True},
"tools": {"subscribe": True},
},
[],
)
)
gateway_service._notify_gateway_updated = AsyncMock()

# Create the update payload
gateway_update = GatewayUpdate(
name="updated_gateway",
url="http://example.com/updated",
description="Updated description",
)

result = await gateway_service.update_gateway(test_db, 1, gateway_update)
# Create mock return for GatewayRead.model_validate().masked()
mock_gateway_read = MagicMock()
mock_gateway_read.name = "updated_gateway"
mock_gateway_read.masked.return_value = mock_gateway_read # Ensure .masked() returns the same object

# Patch the model_validate call in the service
with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
result = await gateway_service.update_gateway(test_db, 1, gateway_update)

# Assertions
test_db.commit.assert_called_once()
test_db.refresh.assert_called_once()
gateway_service._initialize_gateway.assert_called_once()
Expand Down Expand Up @@ -354,6 +399,7 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d
query_proxy.filter.return_value = filter_proxy
test_db.query = Mock(return_value=query_proxy)

# Setup gateway service mocks
gateway_service._notify_gateway_activated = AsyncMock()
gateway_service._notify_gateway_deactivated = AsyncMock()
gateway_service._initialize_gateway = AsyncMock(return_value=({"prompts": {}}, []))
Expand All @@ -362,12 +408,17 @@ async def test_toggle_gateway_status(self, gateway_service, mock_gateway, test_d
tool_service_stub.toggle_tool_status = AsyncMock()
gateway_service.tool_service = tool_service_stub

result = await gateway_service.toggle_gateway_status(test_db, 1, activate=False)
# Patch model_validate to return a mock with .masked()
mock_gateway_read = MagicMock()
mock_gateway_read.masked.return_value = mock_gateway_read

with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
result = await gateway_service.toggle_gateway_status(test_db, 1, activate=False)

assert mock_gateway.enabled is False
gateway_service._notify_gateway_deactivated.assert_called_once()
assert tool_service_stub.toggle_tool_status.called
assert result.enabled is False
assert result == mock_gateway_read

# ────────────────────────────────────────────────────────────────────
# DELETE
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/mcpgateway/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,8 @@ async def test_admin_list_gateways_with_auth_info(self, mock_list_gateways, mock
"transport": "HTTP",
"enabled": True,
"auth_type": "bearer",
"auth_token": "hidden", # Should be masked
"auth_token": "Bearer hidden", # Should be masked
"auth_value": "Some value",
}

mock_list_gateways.return_value = [mock_gateway]
Expand All @@ -764,6 +765,8 @@ async def test_admin_get_gateway_all_transports(self, mock_get_gateway, mock_db)
mock_gateway.model_dump.return_value = {
"id": f"gateway-{transport}",
"transport": transport,
"name": f"Gateway {transport}", # Add this field
"url": f"https://gateway-{transport}.com", # Add this field
}
mock_get_gateway.return_value = mock_gateway

Expand Down
Loading