Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcpgateway/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] =
logger = logging.getLogger(__name__)

if not credentials:
logger.debug("No credentials provided")
logger.warning("No credentials provided")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
Expand Down
30 changes: 22 additions & 8 deletions mcpgateway/routers/oauth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from sqlalchemy.orm import Session

# First-Party
from mcpgateway.auth import get_current_user
from mcpgateway.db import Gateway, get_db
from mcpgateway.middleware.rbac import get_current_user_with_permissions
from mcpgateway.schemas import EmailUserResponse
from mcpgateway.services.oauth_manager import OAuthError, OAuthManager
from mcpgateway.services.token_storage_service import TokenStorageService
Expand All @@ -35,7 +35,7 @@


@oauth_router.get("/authorize/{gateway_id}")
async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: EmailUserResponse = Depends(get_current_user), db: Session = Depends(get_db)) -> RedirectResponse:
async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: EmailUserResponse = Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> RedirectResponse:
"""Initiates the OAuth 2.0 Authorization Code flow for a specified gateway.

This endpoint retrieves the OAuth configuration for the given gateway, validates that
Expand Down Expand Up @@ -75,9 +75,9 @@ async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: E

# Initiate OAuth flow with user context
oauth_manager = OAuthManager(token_storage=TokenStorageService(db))
auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config, app_user_email=current_user.email)
auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config, app_user_email=current_user.get("email"))

logger.info(f"Initiated OAuth flow for gateway {gateway_id} by user {current_user.email}")
logger.info(f"Initiated OAuth flow for gateway {gateway_id} by user {current_user.get('email')}")

# Redirect user to OAuth provider
return RedirectResponse(url=auth_data["authorization_url"])
Expand Down Expand Up @@ -132,8 +132,22 @@ async def oauth_callback(
import json

try:
state_decoded = base64.urlsafe_b64decode(state.encode()).decode()
state_data = json.loads(state_decoded)
# Expect state as base64url(payload || signature) where the last 32 bytes
# are the signature. Decode to bytes first so we can split payload vs sig.
state_raw = base64.urlsafe_b64decode(state.encode())
if len(state_raw) <= 32:
raise ValueError("State too short to contain payload and signature")

# Split payload and signature. Signature is the last 32 bytes.
payload_bytes = state_raw[:-32]
# signature_bytes = state_raw[-32:]

# Parse the JSON payload only (not including signature bytes)
try:
state_data = json.loads(payload_bytes.decode())
except Exception as decode_exc:
raise ValueError(f"Failed to parse state payload JSON: {decode_exc}")

gateway_id = state_data.get("gateway_id")
if not gateway_id:
raise ValueError("No gateway_id in state")
Expand Down Expand Up @@ -403,7 +417,7 @@ async def get_oauth_status(gateway_id: str, db: Session = Depends(get_db)) -> di


@oauth_router.post("/fetch-tools/{gateway_id}")
async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserResponse = Depends(get_current_user), db: Session = Depends(get_db)) -> Dict[str, Any]:
async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserResponse = Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> Dict[str, Any]:
"""Fetch tools from MCP server after OAuth completion for Authorization Code flow.

Args:
Expand All @@ -422,7 +436,7 @@ async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserRespon
from mcpgateway.services.gateway_service import GatewayService

gateway_service = GatewayService()
result = await gateway_service.fetch_tools_after_oauth(db, gateway_id, current_user.email)
result = await gateway_service.fetch_tools_after_oauth(db, gateway_id, current_user.get("email"))
tools_count = len(result.get("tools", []))

return {"success": True, "message": f"Successfully fetched and created {tools_count} tools"}
Expand Down
29 changes: 25 additions & 4 deletions mcpgateway/services/oauth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,20 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo

state_data = json.loads(state_json)

# Parse expires_at as timezone-aware datetime. If the stored value
# is naive, assume UTC for compatibility.
try:
expires_at = datetime.fromisoformat(state_data["expires_at"])
except Exception:
# Fallback: try parsing without microseconds/offsets
expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S")

if expires_at.tzinfo is None:
# Assume UTC for naive timestamps
expires_at = expires_at.replace(tzinfo=timezone.utc)

# Check if state has expired
if datetime.fromisoformat(state_data["expires_at"]) < datetime.now(timezone.utc):
if expires_at < datetime.now(timezone.utc):
logger.warning(f"State has expired for gateway {gateway_id}")
return False

Expand Down Expand Up @@ -636,7 +648,12 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
return False

# Check if state has expired
if oauth_state.expires_at < datetime.now(timezone.utc):
# Ensure oauth_state.expires_at is timezone-aware. If naive, assume UTC.
expires_at = oauth_state.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)

if expires_at < datetime.now(timezone.utc):
logger.warning(f"State has expired for gateway {gateway_id}")
db.delete(oauth_state)
db.commit()
Expand Down Expand Up @@ -667,8 +684,12 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
logger.warning(f"State not found in memory for gateway {gateway_id}")
return False

# Check if state has expired
if datetime.fromisoformat(state_data["expires_at"]) < datetime.now(timezone.utc):
# Parse and normalize expires_at to timezone-aware datetime
expires_at = datetime.fromisoformat(state_data["expires_at"])
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)

if expires_at < datetime.now(timezone.utc):
logger.warning(f"State has expired for gateway {gateway_id}")
del _oauth_states[state_key] # Clean up expired state
return False
Expand Down
44 changes: 37 additions & 7 deletions tests/unit/mcpgateway/routers/test_oauth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def mock_gateway(self):
def mock_current_user(self):
"""Create mock current user."""
user = Mock(spec=EmailUserResponse)
user.get = Mock(return_value="[email protected]")
user.email = "[email protected]"
user.full_name = "Test User"
user.is_active = True
Expand Down Expand Up @@ -106,7 +107,7 @@ async def test_initiate_oauth_flow_success(self, mock_db, mock_request, mock_gat

mock_oauth_manager_class.assert_called_once_with(token_storage=mock_token_storage)
mock_oauth_manager.initiate_authorization_code_flow.assert_called_once_with(
"gateway123", mock_gateway.oauth_config, app_user_email="[email protected]"
"gateway123", mock_gateway.oauth_config, app_user_email=mock_current_user.get("email")
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -194,9 +195,11 @@ async def test_oauth_callback_success(self, mock_db, mock_request, mock_gateway)
import base64
import json

# Setup state with new format
# Setup state with new format (payload + 32-byte signature)
state_data = {"gateway_id": "gateway123", "app_user_email": "[email protected]", "nonce": "abc123"}
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
payload = json.dumps(state_data).encode()
signature = b'x' * 32 # Mock 32-byte signature
state = base64.urlsafe_b64encode(payload + signature).decode()

mock_db.execute.return_value.scalar_one_or_none.return_value = mock_gateway

Expand Down Expand Up @@ -266,6 +269,27 @@ async def test_oauth_callback_invalid_state(self, mock_db, mock_request):
assert result.status_code == 400
assert "Invalid state parameter" in result.body.decode()

@pytest.mark.asyncio
async def test_oauth_callback_state_too_short(self, mock_db, mock_request):
"""Test OAuth callback with state that's too short to contain signature."""
# Standard
import base64

# Setup - create state with less than 32 bytes total
short_payload = b"short"
state = base64.urlsafe_b64encode(short_payload).decode()

# First-Party
from mcpgateway.routers.oauth_router import oauth_callback

# Execute
result = await oauth_callback(code="auth_code_123", state=state, request=mock_request, db=mock_db)

# Assert
assert isinstance(result, HTMLResponse)
assert result.status_code == 400
assert "Invalid state parameter" in result.body.decode()

@pytest.mark.asyncio
async def test_oauth_callback_gateway_not_found(self, mock_db, mock_request):
"""Test OAuth callback when gateway is not found."""
Expand All @@ -275,7 +299,9 @@ async def test_oauth_callback_gateway_not_found(self, mock_db, mock_request):

# Setup
state_data = {"gateway_id": "nonexistent", "app_user_email": "[email protected]"}
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
payload = json.dumps(state_data).encode()
signature = b'x' * 32 # Mock 32-byte signature
state = base64.urlsafe_b64encode(payload + signature).decode()

mock_db.execute.return_value.scalar_one_or_none.return_value = None

Expand All @@ -299,7 +325,9 @@ async def test_oauth_callback_no_oauth_config(self, mock_db, mock_request):

# Setup
state_data = {"gateway_id": "gateway123", "app_user_email": "[email protected]"}
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
payload = json.dumps(state_data).encode()
signature = b'x' * 32 # Mock 32-byte signature
state = base64.urlsafe_b64encode(payload + signature).decode()

mock_gateway = Mock(spec=Gateway)
mock_gateway.id = "gateway123"
Expand All @@ -326,7 +354,9 @@ async def test_oauth_callback_oauth_error(self, mock_db, mock_request, mock_gate

# Setup
state_data = {"gateway_id": "gateway123", "app_user_email": "[email protected]"}
state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode()
payload = json.dumps(state_data).encode()
signature = b'x' * 32 # Mock 32-byte signature
state = base64.urlsafe_b64encode(payload + signature).decode()

mock_db.execute.return_value.scalar_one_or_none.return_value = mock_gateway

Expand Down Expand Up @@ -412,7 +442,7 @@ async def test_fetch_tools_after_oauth_success(self, mock_db, mock_current_user)
# Assert
assert result["success"] is True
assert "Successfully fetched and created 3 tools" in result["message"]
mock_gateway_service.fetch_tools_after_oauth.assert_called_once_with(mock_db, "gateway123", "[email protected]")
mock_gateway_service.fetch_tools_after_oauth.assert_called_once_with(mock_db, "gateway123", mock_current_user.get("email"))

@pytest.mark.asyncio
async def test_fetch_tools_after_oauth_no_tools(self, mock_db, mock_current_user):
Expand Down
Loading