diff --git a/mcpgateway/routers/oauth_router.py b/mcpgateway/routers/oauth_router.py
index 988bef373..9ad4e4270 100644
--- a/mcpgateway/routers/oauth_router.py
+++ b/mcpgateway/routers/oauth_router.py
@@ -125,21 +125,17 @@ async def oauth_callback(
# Get root path for URL construction
root_path = request.scope.get("root_path", "") if request else ""
- # Extract gateway_id from state parameter
- # Try new base64-encoded JSON format first
- # Standard
- import base64
- import json
+ # Initialize OAuth manager early so we can decode the state payload consistently
+ token_storage = TokenStorageService(db)
+ oauth_manager = OAuthManager(token_storage=token_storage)
try:
- state_decoded = base64.urlsafe_b64decode(state.encode()).decode()
- state_data = json.loads(state_decoded)
+ state_data = oauth_manager._decode_state_payload(state) # pylint: disable=protected-access
gateway_id = state_data.get("gateway_id")
if not gateway_id:
- raise ValueError("No gateway_id in state")
- except Exception as e:
- # Fallback to legacy format (gateway_id_random)
- logger.warning(f"Failed to decode state as JSON, trying legacy format: {e}")
+ raise ValueError("Gateway ID missing from state")
+ except (OAuthError, ValueError) as exc:
+ logger.warning(f"Failed to decode OAuth state payload: {exc}")
if "_" not in state:
return HTMLResponse(content="
❌ Invalid state parameter
", status_code=400)
gateway_id = state.split("_")[0]
@@ -179,9 +175,6 @@ async def oauth_callback(
status_code=400,
)
- # Complete OAuth flow
- oauth_manager = OAuthManager(token_storage=TokenStorageService(db))
-
result = await oauth_manager.complete_authorization_code_flow(gateway_id, code, state, gateway.oauth_config)
logger.info(f"Completed OAuth flow for gateway {gateway_id}, user {result.get('user_id')}")
diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py
index ec1597c68..18b108282 100644
--- a/mcpgateway/services/oauth_manager.py
+++ b/mcpgateway/services/oauth_manager.py
@@ -21,6 +21,7 @@
import logging
import secrets
from typing import Any, Dict, Optional
+from urllib.parse import parse_qsl
# Third-Party
import aiohttp
@@ -388,15 +389,16 @@ async def initiate_authorization_code_flow(self, gateway_id: str, credentials: D
Dict containing authorization_url and state
"""
- # Generate state parameter with user context for CSRF protection
- state = self._generate_state(gateway_id, app_user_email)
+ # Generate state parameter with user context for CSRF protection and PKCE verifier
+ state, code_verifier = self._generate_state(gateway_id, app_user_email)
+ code_challenge = self._compute_code_challenge(code_verifier)
# Store state in session/cache for validation
if self.token_storage:
await self._store_authorization_state(gateway_id, state)
# Generate authorization URL
- auth_url, _ = self._create_authorization_url(credentials, state)
+ auth_url, _ = self._create_authorization_url(credentials, state, code_challenge)
logger.info(f"Generated authorization URL for gateway {gateway_id}")
@@ -421,38 +423,22 @@ async def complete_authorization_code_flow(self, gateway_id: str, code: str, sta
if not await self._validate_authorization_state(gateway_id, state):
raise OAuthError("Invalid or expired state parameter - possible replay attack")
- # Decode state to extract user context and verify HMAC
+ code_verifier = None
try:
- # Decode base64
- state_with_sig = base64.urlsafe_b64decode(state.encode())
-
- # Split state and signature (HMAC-SHA256 is 32 bytes)
- state_bytes = state_with_sig[:-32]
- received_signature = state_with_sig[-32:]
-
- # Verify HMAC signature
- secret_key = self.settings.auth_encryption_secret.encode() if self.settings.auth_encryption_secret else b"default-secret-key"
- expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()
-
- if not hmac.compare_digest(received_signature, expected_signature):
- raise OAuthError("Invalid state signature - possible CSRF attack")
-
- # Parse state data
- state_json = state_bytes.decode()
- state_data = json.loads(state_json)
+ state_data = self._decode_state_payload(state)
app_user_email = state_data.get("app_user_email")
+ code_verifier = state_data.get("code_verifier")
state_gateway_id = state_data.get("gateway_id")
- # Validate gateway ID matches
- if state_gateway_id != gateway_id:
+ if state_gateway_id and state_gateway_id != gateway_id:
raise OAuthError("State parameter gateway mismatch")
- except Exception as e:
- # Fallback for legacy state format (gateway_id_random)
- logger.warning(f"Failed to decode state JSON, trying legacy format: {e}")
+ except OAuthError as exc:
+ logger.warning(f"Failed to decode signed state payload, trying legacy format: {exc}")
app_user_email = None
+ state_gateway_id = None
# Exchange code for tokens
- token_response = await self._exchange_code_for_tokens(credentials, code)
+ token_response = await self._exchange_code_for_tokens(credentials, code, state=state, code_verifier=code_verifier)
# Extract user information from token response
user_id = self._extract_user_id(token_response, credentials)
@@ -489,18 +475,28 @@ async def get_access_token_for_user(self, gateway_id: str, app_user_email: str)
return await self.token_storage.get_user_token(gateway_id, app_user_email)
return None
- def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str:
+ def _generate_state(self, gateway_id: str, app_user_email: str = None, code_verifier: Optional[str] = None) -> tuple[str, str]:
"""Generate a unique state parameter with user context for CSRF protection.
Args:
gateway_id: ID of the gateway
app_user_email: MCP Gateway user email (optional but recommended)
+ code_verifier: Optional PKCE code verifier to embed in state
Returns:
- Unique state string with embedded user context and HMAC signature
+ Tuple containing the encoded state string and the PKCE code verifier used
"""
+ if not code_verifier:
+ code_verifier = secrets.token_urlsafe(64)
+
# Include user email in state for secure user association
- state_data = {"gateway_id": gateway_id, "app_user_email": app_user_email, "nonce": secrets.token_urlsafe(16), "timestamp": datetime.now(timezone.utc).isoformat()}
+ state_data = {
+ "gateway_id": gateway_id,
+ "app_user_email": app_user_email,
+ "nonce": secrets.token_urlsafe(16),
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "code_verifier": code_verifier,
+ }
# Encode state as JSON
state_json = json.dumps(state_data, separators=(",", ":"))
@@ -514,7 +510,39 @@ def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str:
state_with_sig = state_bytes + signature
state_encoded = base64.urlsafe_b64encode(state_with_sig).decode()
- return state_encoded
+ return state_encoded, code_verifier
+
+ @staticmethod
+ def _compute_code_challenge(code_verifier: str) -> str:
+ """Compute a PKCE S256 code challenge from a verifier."""
+
+ digest = hashlib.sha256(code_verifier.encode()).digest()
+ return base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
+
+ def _decode_state_payload(self, state: str) -> Dict[str, Any]:
+ """Decode and verify a state payload generated by this manager."""
+
+ try:
+ state_with_sig = base64.urlsafe_b64decode(state.encode())
+ except Exception as exc: # pylint: disable=broad-except
+ raise OAuthError("Invalid state parameter encoding") from exc
+
+ if len(state_with_sig) <= 32:
+ raise OAuthError("State parameter is malformed")
+
+ state_bytes = state_with_sig[:-32]
+ received_signature = state_with_sig[-32:]
+
+ secret_key = self.settings.auth_encryption_secret.encode() if self.settings.auth_encryption_secret else b"default-secret-key"
+ expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()
+
+ if not hmac.compare_digest(received_signature, expected_signature):
+ raise OAuthError("Invalid state signature - possible CSRF attack")
+
+ try:
+ return json.loads(state_bytes.decode())
+ except Exception as exc: # pylint: disable=broad-except
+ raise OAuthError("Invalid state payload") from exc
async def _store_authorization_state(self, gateway_id: str, state: str) -> None:
"""Store authorization state for validation with TTL.
@@ -683,12 +711,13 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}")
return True
- def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
+ def _create_authorization_url(self, credentials: Dict[str, Any], state: str, code_challenge: Optional[str] = None) -> tuple[str, str]:
"""Create authorization URL with state parameter.
Args:
credentials: OAuth configuration
state: State parameter for CSRF protection
+ code_challenge: Optional PKCE code challenge
Returns:
Tuple of (authorization_url, state)
@@ -701,17 +730,24 @@ def _create_authorization_url(self, credentials: Dict[str, Any], state: str) ->
# Create OAuth2 session
oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)
+ extra_params: Dict[str, Any] = {}
+ if code_challenge:
+ extra_params["code_challenge"] = code_challenge
+ extra_params["code_challenge_method"] = "S256"
+
# Generate authorization URL with state for CSRF protection
- auth_url, state = oauth.authorization_url(authorization_url, state=state)
+ auth_url, state = oauth.authorization_url(authorization_url, state=state, **extra_params)
return auth_url, state
- async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str) -> Dict[str, Any]:
+ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, state: Optional[str] = None, code_verifier: Optional[str] = None) -> Dict[str, Any]:
"""Exchange authorization code for tokens.
Args:
credentials: OAuth configuration
code: Authorization code from callback
+ state: Optional state parameter echoed back by provider
+ code_verifier: Optional PKCE code verifier for S256 exchange
Returns:
Token response dictionary
@@ -747,11 +783,22 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str
"client_secret": client_secret,
}
+ if state:
+ token_data["state"] = state
+
+ if code_verifier:
+ token_data["code_verifier"] = code_verifier
+
# Exchange code for token with retries
for attempt in range(self.max_retries):
try:
async with aiohttp.ClientSession() as session:
- async with session.post(token_url, data=token_data, timeout=aiohttp.ClientTimeout(total=self.request_timeout)) as response:
+ async with session.post(
+ token_url,
+ data=token_data,
+ timeout=aiohttp.ClientTimeout(total=self.request_timeout),
+ headers={"Accept": "application/json"},
+ ) as response:
response.raise_for_status()
# GitHub returns form-encoded responses, not JSON
@@ -759,11 +806,7 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str
if "application/x-www-form-urlencoded" in content_type:
# Parse form-encoded response
text_response = await response.text()
- token_response = {}
- for pair in text_response.split("&"):
- if "=" in pair:
- key, value = pair.split("=", 1)
- token_response[key] = value
+ token_response = {key: value for key, value in parse_qsl(text_response, keep_blank_values=True)}
else:
# Try JSON response
try:
diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py
index ca41a7e05..0e848dab9 100644
--- a/tests/unit/mcpgateway/test_oauth_manager.py
+++ b/tests/unit/mcpgateway/test_oauth_manager.py
@@ -560,23 +560,27 @@ async def test_initiate_authorization_code_flow_success(self):
}
with patch.object(manager, '_generate_state') as mock_generate_state:
- mock_generate_state.return_value = "state123"
+ mock_generate_state.return_value = ("state123", "verifier123")
- with patch.object(manager, '_store_authorization_state') as mock_store_state:
- with patch.object(manager, '_create_authorization_url') as mock_create_url:
- mock_create_url.return_value = ("https://oauth.example.com/authorize?state=state123", "state123")
+ with patch.object(manager, '_compute_code_challenge') as mock_code_challenge:
+ mock_code_challenge.return_value = "challenge123"
- result = await manager.initiate_authorization_code_flow(gateway_id, credentials, app_user_email="test@example.com")
+ with patch.object(manager, '_store_authorization_state') as mock_store_state:
+ with patch.object(manager, '_create_authorization_url') as mock_create_url:
+ mock_create_url.return_value = ("https://oauth.example.com/authorize?state=state123", "state123")
- expected = {
- "authorization_url": "https://oauth.example.com/authorize?state=state123",
- "state": "state123",
- "gateway_id": "gateway123"
- }
- assert result == expected
- mock_generate_state.assert_called_once_with(gateway_id, "test@example.com")
- mock_store_state.assert_called_once_with(gateway_id, "state123")
- mock_create_url.assert_called_once_with(credentials, "state123")
+ result = await manager.initiate_authorization_code_flow(gateway_id, credentials, app_user_email="test@example.com")
+
+ expected = {
+ "authorization_url": "https://oauth.example.com/authorize?state=state123",
+ "state": "state123",
+ "gateway_id": "gateway123"
+ }
+ assert result == expected
+ mock_generate_state.assert_called_once_with(gateway_id, "test@example.com")
+ mock_code_challenge.assert_called_once_with("verifier123")
+ mock_store_state.assert_called_once_with(gateway_id, "state123")
+ mock_create_url.assert_called_once_with(credentials, "state123", "challenge123")
@pytest.mark.asyncio
async def test_complete_authorization_code_flow_success(self):
@@ -726,7 +730,7 @@ def test_generate_state_format(self):
manager = OAuthManager()
- state = manager._generate_state("gateway123", "test@example.com")
+ state, code_verifier = manager._generate_state("gateway123", "test@example.com")
# State is now base64 encoded JSON with HMAC signature
state_with_sig = base64.urlsafe_b64decode(state.encode())
@@ -747,10 +751,12 @@ def test_generate_state_format(self):
assert decoded["app_user_email"] == "test@example.com"
assert "nonce" in decoded
assert "timestamp" in decoded
+ assert decoded["code_verifier"] == code_verifier
# Should generate different states each time (different nonce)
- state2 = manager._generate_state("gateway123", "test@example.com")
+ state2, verifier2 = manager._generate_state("gateway123", "test@example.com")
assert state != state2
+ assert code_verifier != verifier2
@pytest.mark.asyncio
async def test_store_authorization_state(self):