diff --git a/mcpgateway/auth.py b/mcpgateway/auth.py index 070064786..41988a439 100644 --- a/mcpgateway/auth.py +++ b/mcpgateway/auth.py @@ -67,7 +67,7 @@ async def get_current_user(credentials: Optional[HTTPAuthorizationCredentials] = logger = logging.getLogger(__name__) if not credentials: - logger.debug("No credentials provided") + logger.warning("No credentials provided") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required", diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py index 988bef373..9d84089dc 100644 --- a/mcpgateway/routers/oauth_router.py +++ b/mcpgateway/routers/oauth_router.py @@ -23,8 +23,8 @@ from sqlalchemy.orm import Session # First-Party -from mcpgateway.auth import get_current_user from mcpgateway.db import Gateway, get_db +from mcpgateway.middleware.rbac import get_current_user_with_permissions from mcpgateway.schemas import EmailUserResponse from mcpgateway.services.oauth_manager import OAuthError, OAuthManager from mcpgateway.services.token_storage_service import TokenStorageService @@ -35,7 +35,7 @@ @oauth_router.get("/authorize/{gateway_id}") -async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: EmailUserResponse = Depends(get_current_user), db: Session = Depends(get_db)) -> RedirectResponse: +async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: EmailUserResponse = Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> RedirectResponse: """Initiates the OAuth 2.0 Authorization Code flow for a specified gateway. This endpoint retrieves the OAuth configuration for the given gateway, validates that @@ -75,9 +75,9 @@ async def initiate_oauth_flow(gateway_id: str, request: Request, current_user: E # Initiate OAuth flow with user context oauth_manager = OAuthManager(token_storage=TokenStorageService(db)) - auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config, app_user_email=current_user.email) + auth_data = await oauth_manager.initiate_authorization_code_flow(gateway_id, gateway.oauth_config, app_user_email=current_user.get("email")) - logger.info(f"Initiated OAuth flow for gateway {gateway_id} by user {current_user.email}") + logger.info(f"Initiated OAuth flow for gateway {gateway_id} by user {current_user.get('email')}") # Redirect user to OAuth provider return RedirectResponse(url=auth_data["authorization_url"]) @@ -132,8 +132,22 @@ async def oauth_callback( import json try: - state_decoded = base64.urlsafe_b64decode(state.encode()).decode() - state_data = json.loads(state_decoded) + # Expect state as base64url(payload || signature) where the last 32 bytes + # are the signature. Decode to bytes first so we can split payload vs sig. + state_raw = base64.urlsafe_b64decode(state.encode()) + if len(state_raw) <= 32: + raise ValueError("State too short to contain payload and signature") + + # Split payload and signature. Signature is the last 32 bytes. + payload_bytes = state_raw[:-32] + # signature_bytes = state_raw[-32:] + + # Parse the JSON payload only (not including signature bytes) + try: + state_data = json.loads(payload_bytes.decode()) + except Exception as decode_exc: + raise ValueError(f"Failed to parse state payload JSON: {decode_exc}") + gateway_id = state_data.get("gateway_id") if not gateway_id: raise ValueError("No gateway_id in state") @@ -403,7 +417,7 @@ async def get_oauth_status(gateway_id: str, db: Session = Depends(get_db)) -> di @oauth_router.post("/fetch-tools/{gateway_id}") -async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserResponse = Depends(get_current_user), db: Session = Depends(get_db)) -> Dict[str, Any]: +async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserResponse = Depends(get_current_user_with_permissions), db: Session = Depends(get_db)) -> Dict[str, Any]: """Fetch tools from MCP server after OAuth completion for Authorization Code flow. Args: @@ -422,7 +436,7 @@ async def fetch_tools_after_oauth(gateway_id: str, current_user: EmailUserRespon from mcpgateway.services.gateway_service import GatewayService gateway_service = GatewayService() - result = await gateway_service.fetch_tools_after_oauth(db, gateway_id, current_user.email) + result = await gateway_service.fetch_tools_after_oauth(db, gateway_id, current_user.get("email")) tools_count = len(result.get("tools", [])) return {"success": True, "message": f"Successfully fetched and created {tools_count} tools"} diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py index ec1597c68..d94ce0659 100644 --- a/mcpgateway/services/oauth_manager.py +++ b/mcpgateway/services/oauth_manager.py @@ -604,8 +604,20 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo state_data = json.loads(state_json) + # Parse expires_at as timezone-aware datetime. If the stored value + # is naive, assume UTC for compatibility. + try: + expires_at = datetime.fromisoformat(state_data["expires_at"]) + except Exception: + # Fallback: try parsing without microseconds/offsets + expires_at = datetime.strptime(state_data["expires_at"], "%Y-%m-%dT%H:%M:%S") + + if expires_at.tzinfo is None: + # Assume UTC for naive timestamps + expires_at = expires_at.replace(tzinfo=timezone.utc) + # Check if state has expired - if datetime.fromisoformat(state_data["expires_at"]) < datetime.now(timezone.utc): + if expires_at < datetime.now(timezone.utc): logger.warning(f"State has expired for gateway {gateway_id}") return False @@ -636,7 +648,12 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo return False # Check if state has expired - if oauth_state.expires_at < datetime.now(timezone.utc): + # Ensure oauth_state.expires_at is timezone-aware. If naive, assume UTC. + expires_at = oauth_state.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + + if expires_at < datetime.now(timezone.utc): logger.warning(f"State has expired for gateway {gateway_id}") db.delete(oauth_state) db.commit() @@ -667,8 +684,12 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo logger.warning(f"State not found in memory for gateway {gateway_id}") return False - # Check if state has expired - if datetime.fromisoformat(state_data["expires_at"]) < datetime.now(timezone.utc): + # Parse and normalize expires_at to timezone-aware datetime + expires_at = datetime.fromisoformat(state_data["expires_at"]) + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + + if expires_at < datetime.now(timezone.utc): logger.warning(f"State has expired for gateway {gateway_id}") del _oauth_states[state_key] # Clean up expired state return False diff --git a/tests/unit/mcpgateway/routers/test_oauth_router.py b/tests/unit/mcpgateway/routers/test_oauth_router.py index b3c4178f0..50ead388b 100644 --- a/tests/unit/mcpgateway/routers/test_oauth_router.py +++ b/tests/unit/mcpgateway/routers/test_oauth_router.py @@ -66,6 +66,7 @@ def mock_gateway(self): def mock_current_user(self): """Create mock current user.""" user = Mock(spec=EmailUserResponse) + user.get = Mock(return_value="test@example.com") user.email = "test@example.com" user.full_name = "Test User" user.is_active = True @@ -106,7 +107,7 @@ async def test_initiate_oauth_flow_success(self, mock_db, mock_request, mock_gat mock_oauth_manager_class.assert_called_once_with(token_storage=mock_token_storage) mock_oauth_manager.initiate_authorization_code_flow.assert_called_once_with( - "gateway123", mock_gateway.oauth_config, app_user_email="test@example.com" + "gateway123", mock_gateway.oauth_config, app_user_email=mock_current_user.get("email") ) @pytest.mark.asyncio @@ -194,9 +195,11 @@ async def test_oauth_callback_success(self, mock_db, mock_request, mock_gateway) import base64 import json - # Setup state with new format + # Setup state with new format (payload + 32-byte signature) state_data = {"gateway_id": "gateway123", "app_user_email": "test@example.com", "nonce": "abc123"} - state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode() + payload = json.dumps(state_data).encode() + signature = b'x' * 32 # Mock 32-byte signature + state = base64.urlsafe_b64encode(payload + signature).decode() mock_db.execute.return_value.scalar_one_or_none.return_value = mock_gateway @@ -266,6 +269,27 @@ async def test_oauth_callback_invalid_state(self, mock_db, mock_request): assert result.status_code == 400 assert "Invalid state parameter" in result.body.decode() + @pytest.mark.asyncio + async def test_oauth_callback_state_too_short(self, mock_db, mock_request): + """Test OAuth callback with state that's too short to contain signature.""" + # Standard + import base64 + + # Setup - create state with less than 32 bytes total + short_payload = b"short" + state = base64.urlsafe_b64encode(short_payload).decode() + + # First-Party + from mcpgateway.routers.oauth_router import oauth_callback + + # Execute + result = await oauth_callback(code="auth_code_123", state=state, request=mock_request, db=mock_db) + + # Assert + assert isinstance(result, HTMLResponse) + assert result.status_code == 400 + assert "Invalid state parameter" in result.body.decode() + @pytest.mark.asyncio async def test_oauth_callback_gateway_not_found(self, mock_db, mock_request): """Test OAuth callback when gateway is not found.""" @@ -275,7 +299,9 @@ async def test_oauth_callback_gateway_not_found(self, mock_db, mock_request): # Setup state_data = {"gateway_id": "nonexistent", "app_user_email": "test@example.com"} - state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode() + payload = json.dumps(state_data).encode() + signature = b'x' * 32 # Mock 32-byte signature + state = base64.urlsafe_b64encode(payload + signature).decode() mock_db.execute.return_value.scalar_one_or_none.return_value = None @@ -299,7 +325,9 @@ async def test_oauth_callback_no_oauth_config(self, mock_db, mock_request): # Setup state_data = {"gateway_id": "gateway123", "app_user_email": "test@example.com"} - state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode() + payload = json.dumps(state_data).encode() + signature = b'x' * 32 # Mock 32-byte signature + state = base64.urlsafe_b64encode(payload + signature).decode() mock_gateway = Mock(spec=Gateway) mock_gateway.id = "gateway123" @@ -326,7 +354,9 @@ async def test_oauth_callback_oauth_error(self, mock_db, mock_request, mock_gate # Setup state_data = {"gateway_id": "gateway123", "app_user_email": "test@example.com"} - state = base64.urlsafe_b64encode(json.dumps(state_data).encode()).decode() + payload = json.dumps(state_data).encode() + signature = b'x' * 32 # Mock 32-byte signature + state = base64.urlsafe_b64encode(payload + signature).decode() mock_db.execute.return_value.scalar_one_or_none.return_value = mock_gateway @@ -412,7 +442,7 @@ async def test_fetch_tools_after_oauth_success(self, mock_db, mock_current_user) # Assert assert result["success"] is True assert "Successfully fetched and created 3 tools" in result["message"] - mock_gateway_service.fetch_tools_after_oauth.assert_called_once_with(mock_db, "gateway123", "test@example.com") + mock_gateway_service.fetch_tools_after_oauth.assert_called_once_with(mock_db, "gateway123", mock_current_user.get("email")) @pytest.mark.asyncio async def test_fetch_tools_after_oauth_no_tools(self, mock_db, mock_current_user):