diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py index 988bef373..9ad4e4270 100644 --- a/mcpgateway/routers/oauth_router.py +++ b/mcpgateway/routers/oauth_router.py @@ -125,21 +125,17 @@ async def oauth_callback( # Get root path for URL construction root_path = request.scope.get("root_path", "") if request else "" - # Extract gateway_id from state parameter - # Try new base64-encoded JSON format first - # Standard - import base64 - import json + # Initialize OAuth manager early so we can decode the state payload consistently + token_storage = TokenStorageService(db) + oauth_manager = OAuthManager(token_storage=token_storage) try: - state_decoded = base64.urlsafe_b64decode(state.encode()).decode() - state_data = json.loads(state_decoded) + state_data = oauth_manager._decode_state_payload(state) # pylint: disable=protected-access gateway_id = state_data.get("gateway_id") if not gateway_id: - raise ValueError("No gateway_id in state") - except Exception as e: - # Fallback to legacy format (gateway_id_random) - logger.warning(f"Failed to decode state as JSON, trying legacy format: {e}") + raise ValueError("Gateway ID missing from state") + except (OAuthError, ValueError) as exc: + logger.warning(f"Failed to decode OAuth state payload: {exc}") if "_" not in state: return HTMLResponse(content="

❌ Invalid state parameter

", status_code=400) gateway_id = state.split("_")[0] @@ -179,9 +175,6 @@ async def oauth_callback( status_code=400, ) - # Complete OAuth flow - oauth_manager = OAuthManager(token_storage=TokenStorageService(db)) - result = await oauth_manager.complete_authorization_code_flow(gateway_id, code, state, gateway.oauth_config) logger.info(f"Completed OAuth flow for gateway {gateway_id}, user {result.get('user_id')}") diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py index ec1597c68..18b108282 100644 --- a/mcpgateway/services/oauth_manager.py +++ b/mcpgateway/services/oauth_manager.py @@ -21,6 +21,7 @@ import logging import secrets from typing import Any, Dict, Optional +from urllib.parse import parse_qsl # Third-Party import aiohttp @@ -388,15 +389,16 @@ async def initiate_authorization_code_flow(self, gateway_id: str, credentials: D Dict containing authorization_url and state """ - # Generate state parameter with user context for CSRF protection - state = self._generate_state(gateway_id, app_user_email) + # Generate state parameter with user context for CSRF protection and PKCE verifier + state, code_verifier = self._generate_state(gateway_id, app_user_email) + code_challenge = self._compute_code_challenge(code_verifier) # Store state in session/cache for validation if self.token_storage: await self._store_authorization_state(gateway_id, state) # Generate authorization URL - auth_url, _ = self._create_authorization_url(credentials, state) + auth_url, _ = self._create_authorization_url(credentials, state, code_challenge) logger.info(f"Generated authorization URL for gateway {gateway_id}") @@ -421,38 +423,22 @@ async def complete_authorization_code_flow(self, gateway_id: str, code: str, sta if not await self._validate_authorization_state(gateway_id, state): raise OAuthError("Invalid or expired state parameter - possible replay attack") - # Decode state to extract user context and verify HMAC + code_verifier = None try: - # Decode base64 - state_with_sig = base64.urlsafe_b64decode(state.encode()) - - # Split state and signature (HMAC-SHA256 is 32 bytes) - state_bytes = state_with_sig[:-32] - received_signature = state_with_sig[-32:] - - # Verify HMAC signature - secret_key = self.settings.auth_encryption_secret.encode() if self.settings.auth_encryption_secret else b"default-secret-key" - expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest() - - if not hmac.compare_digest(received_signature, expected_signature): - raise OAuthError("Invalid state signature - possible CSRF attack") - - # Parse state data - state_json = state_bytes.decode() - state_data = json.loads(state_json) + state_data = self._decode_state_payload(state) app_user_email = state_data.get("app_user_email") + code_verifier = state_data.get("code_verifier") state_gateway_id = state_data.get("gateway_id") - # Validate gateway ID matches - if state_gateway_id != gateway_id: + if state_gateway_id and state_gateway_id != gateway_id: raise OAuthError("State parameter gateway mismatch") - except Exception as e: - # Fallback for legacy state format (gateway_id_random) - logger.warning(f"Failed to decode state JSON, trying legacy format: {e}") + except OAuthError as exc: + logger.warning(f"Failed to decode signed state payload, trying legacy format: {exc}") app_user_email = None + state_gateway_id = None # Exchange code for tokens - token_response = await self._exchange_code_for_tokens(credentials, code) + token_response = await self._exchange_code_for_tokens(credentials, code, state=state, code_verifier=code_verifier) # Extract user information from token response user_id = self._extract_user_id(token_response, credentials) @@ -489,18 +475,28 @@ async def get_access_token_for_user(self, gateway_id: str, app_user_email: str) return await self.token_storage.get_user_token(gateway_id, app_user_email) return None - def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str: + def _generate_state(self, gateway_id: str, app_user_email: str = None, code_verifier: Optional[str] = None) -> tuple[str, str]: """Generate a unique state parameter with user context for CSRF protection. Args: gateway_id: ID of the gateway app_user_email: MCP Gateway user email (optional but recommended) + code_verifier: Optional PKCE code verifier to embed in state Returns: - Unique state string with embedded user context and HMAC signature + Tuple containing the encoded state string and the PKCE code verifier used """ + if not code_verifier: + code_verifier = secrets.token_urlsafe(64) + # Include user email in state for secure user association - state_data = {"gateway_id": gateway_id, "app_user_email": app_user_email, "nonce": secrets.token_urlsafe(16), "timestamp": datetime.now(timezone.utc).isoformat()} + state_data = { + "gateway_id": gateway_id, + "app_user_email": app_user_email, + "nonce": secrets.token_urlsafe(16), + "timestamp": datetime.now(timezone.utc).isoformat(), + "code_verifier": code_verifier, + } # Encode state as JSON state_json = json.dumps(state_data, separators=(",", ":")) @@ -514,7 +510,39 @@ def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str: state_with_sig = state_bytes + signature state_encoded = base64.urlsafe_b64encode(state_with_sig).decode() - return state_encoded + return state_encoded, code_verifier + + @staticmethod + def _compute_code_challenge(code_verifier: str) -> str: + """Compute a PKCE S256 code challenge from a verifier.""" + + digest = hashlib.sha256(code_verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + def _decode_state_payload(self, state: str) -> Dict[str, Any]: + """Decode and verify a state payload generated by this manager.""" + + try: + state_with_sig = base64.urlsafe_b64decode(state.encode()) + except Exception as exc: # pylint: disable=broad-except + raise OAuthError("Invalid state parameter encoding") from exc + + if len(state_with_sig) <= 32: + raise OAuthError("State parameter is malformed") + + state_bytes = state_with_sig[:-32] + received_signature = state_with_sig[-32:] + + secret_key = self.settings.auth_encryption_secret.encode() if self.settings.auth_encryption_secret else b"default-secret-key" + expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest() + + if not hmac.compare_digest(received_signature, expected_signature): + raise OAuthError("Invalid state signature - possible CSRF attack") + + try: + return json.loads(state_bytes.decode()) + except Exception as exc: # pylint: disable=broad-except + raise OAuthError("Invalid state payload") from exc async def _store_authorization_state(self, gateway_id: str, state: str) -> None: """Store authorization state for validation with TTL. @@ -683,12 +711,13 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}") return True - def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]: + def _create_authorization_url(self, credentials: Dict[str, Any], state: str, code_challenge: Optional[str] = None) -> tuple[str, str]: """Create authorization URL with state parameter. Args: credentials: OAuth configuration state: State parameter for CSRF protection + code_challenge: Optional PKCE code challenge Returns: Tuple of (authorization_url, state) @@ -701,17 +730,24 @@ def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> # Create OAuth2 session oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes) + extra_params: Dict[str, Any] = {} + if code_challenge: + extra_params["code_challenge"] = code_challenge + extra_params["code_challenge_method"] = "S256" + # Generate authorization URL with state for CSRF protection - auth_url, state = oauth.authorization_url(authorization_url, state=state) + auth_url, state = oauth.authorization_url(authorization_url, state=state, **extra_params) return auth_url, state - async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str) -> Dict[str, Any]: + async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, state: Optional[str] = None, code_verifier: Optional[str] = None) -> Dict[str, Any]: """Exchange authorization code for tokens. Args: credentials: OAuth configuration code: Authorization code from callback + state: Optional state parameter echoed back by provider + code_verifier: Optional PKCE code verifier for S256 exchange Returns: Token response dictionary @@ -747,11 +783,22 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str "client_secret": client_secret, } + if state: + token_data["state"] = state + + if code_verifier: + token_data["code_verifier"] = code_verifier + # Exchange code for token with retries for attempt in range(self.max_retries): try: async with aiohttp.ClientSession() as session: - async with session.post(token_url, data=token_data, timeout=aiohttp.ClientTimeout(total=self.request_timeout)) as response: + async with session.post( + token_url, + data=token_data, + timeout=aiohttp.ClientTimeout(total=self.request_timeout), + headers={"Accept": "application/json"}, + ) as response: response.raise_for_status() # GitHub returns form-encoded responses, not JSON @@ -759,11 +806,7 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str if "application/x-www-form-urlencoded" in content_type: # Parse form-encoded response text_response = await response.text() - token_response = {} - for pair in text_response.split("&"): - if "=" in pair: - key, value = pair.split("=", 1) - token_response[key] = value + token_response = {key: value for key, value in parse_qsl(text_response, keep_blank_values=True)} else: # Try JSON response try: diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index ca41a7e05..0e848dab9 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -560,23 +560,27 @@ async def test_initiate_authorization_code_flow_success(self): } with patch.object(manager, '_generate_state') as mock_generate_state: - mock_generate_state.return_value = "state123" + mock_generate_state.return_value = ("state123", "verifier123") - with patch.object(manager, '_store_authorization_state') as mock_store_state: - with patch.object(manager, '_create_authorization_url') as mock_create_url: - mock_create_url.return_value = ("https://oauth.example.com/authorize?state=state123", "state123") + with patch.object(manager, '_compute_code_challenge') as mock_code_challenge: + mock_code_challenge.return_value = "challenge123" - result = await manager.initiate_authorization_code_flow(gateway_id, credentials, app_user_email="test@example.com") + with patch.object(manager, '_store_authorization_state') as mock_store_state: + with patch.object(manager, '_create_authorization_url') as mock_create_url: + mock_create_url.return_value = ("https://oauth.example.com/authorize?state=state123", "state123") - expected = { - "authorization_url": "https://oauth.example.com/authorize?state=state123", - "state": "state123", - "gateway_id": "gateway123" - } - assert result == expected - mock_generate_state.assert_called_once_with(gateway_id, "test@example.com") - mock_store_state.assert_called_once_with(gateway_id, "state123") - mock_create_url.assert_called_once_with(credentials, "state123") + result = await manager.initiate_authorization_code_flow(gateway_id, credentials, app_user_email="test@example.com") + + expected = { + "authorization_url": "https://oauth.example.com/authorize?state=state123", + "state": "state123", + "gateway_id": "gateway123" + } + assert result == expected + mock_generate_state.assert_called_once_with(gateway_id, "test@example.com") + mock_code_challenge.assert_called_once_with("verifier123") + mock_store_state.assert_called_once_with(gateway_id, "state123") + mock_create_url.assert_called_once_with(credentials, "state123", "challenge123") @pytest.mark.asyncio async def test_complete_authorization_code_flow_success(self): @@ -726,7 +730,7 @@ def test_generate_state_format(self): manager = OAuthManager() - state = manager._generate_state("gateway123", "test@example.com") + state, code_verifier = manager._generate_state("gateway123", "test@example.com") # State is now base64 encoded JSON with HMAC signature state_with_sig = base64.urlsafe_b64decode(state.encode()) @@ -747,10 +751,12 @@ def test_generate_state_format(self): assert decoded["app_user_email"] == "test@example.com" assert "nonce" in decoded assert "timestamp" in decoded + assert decoded["code_verifier"] == code_verifier # Should generate different states each time (different nonce) - state2 = manager._generate_state("gateway123", "test@example.com") + state2, verifier2 = manager._generate_state("gateway123", "test@example.com") assert state != state2 + assert code_verifier != verifier2 @pytest.mark.asyncio async def test_store_authorization_state(self):