Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions mcpgateway/routers/oauth_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,17 @@ async def oauth_callback(
# Get root path for URL construction
root_path = request.scope.get("root_path", "") if request else ""

# Extract gateway_id from state parameter
# Try new base64-encoded JSON format first
# Standard
import base64
import json
# Initialize OAuth manager early so we can decode the state payload consistently
token_storage = TokenStorageService(db)
oauth_manager = OAuthManager(token_storage=token_storage)

try:
state_decoded = base64.urlsafe_b64decode(state.encode()).decode()
state_data = json.loads(state_decoded)
state_data = oauth_manager._decode_state_payload(state) # pylint: disable=protected-access
gateway_id = state_data.get("gateway_id")
if not gateway_id:
raise ValueError("No gateway_id in state")
except Exception as e:
# Fallback to legacy format (gateway_id_random)
logger.warning(f"Failed to decode state as JSON, trying legacy format: {e}")
raise ValueError("Gateway ID missing from state")
except (OAuthError, ValueError) as exc:
logger.warning(f"Failed to decode OAuth state payload: {exc}")
if "_" not in state:
return HTMLResponse(content="<h1>❌ Invalid state parameter</h1>", status_code=400)
gateway_id = state.split("_")[0]
Expand Down Expand Up @@ -179,9 +175,6 @@ async def oauth_callback(
status_code=400,
)

# Complete OAuth flow
oauth_manager = OAuthManager(token_storage=TokenStorageService(db))

result = await oauth_manager.complete_authorization_code_flow(gateway_id, code, state, gateway.oauth_config)

logger.info(f"Completed OAuth flow for gateway {gateway_id}, user {result.get('user_id')}")
Expand Down
123 changes: 83 additions & 40 deletions mcpgateway/services/oauth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import secrets
from typing import Any, Dict, Optional
from urllib.parse import parse_qsl

# Third-Party
import aiohttp
Expand Down Expand Up @@ -388,15 +389,16 @@ async def initiate_authorization_code_flow(self, gateway_id: str, credentials: D
Dict containing authorization_url and state
"""

# Generate state parameter with user context for CSRF protection
state = self._generate_state(gateway_id, app_user_email)
# Generate state parameter with user context for CSRF protection and PKCE verifier
state, code_verifier = self._generate_state(gateway_id, app_user_email)
code_challenge = self._compute_code_challenge(code_verifier)

# Store state in session/cache for validation
if self.token_storage:
await self._store_authorization_state(gateway_id, state)

# Generate authorization URL
auth_url, _ = self._create_authorization_url(credentials, state)
auth_url, _ = self._create_authorization_url(credentials, state, code_challenge)

logger.info(f"Generated authorization URL for gateway {gateway_id}")

Expand All @@ -421,38 +423,22 @@ async def complete_authorization_code_flow(self, gateway_id: str, code: str, sta
if not await self._validate_authorization_state(gateway_id, state):
raise OAuthError("Invalid or expired state parameter - possible replay attack")

# Decode state to extract user context and verify HMAC
code_verifier = None
try:
# Decode base64
state_with_sig = base64.urlsafe_b64decode(state.encode())

# Split state and signature (HMAC-SHA256 is 32 bytes)
state_bytes = state_with_sig[:-32]
received_signature = state_with_sig[-32:]

# Verify HMAC signature
secret_key = self.settings.auth_encryption_secret.encode() if self.settings.auth_encryption_secret else b"default-secret-key"
expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()

if not hmac.compare_digest(received_signature, expected_signature):
raise OAuthError("Invalid state signature - possible CSRF attack")

# Parse state data
state_json = state_bytes.decode()
state_data = json.loads(state_json)
state_data = self._decode_state_payload(state)
app_user_email = state_data.get("app_user_email")
code_verifier = state_data.get("code_verifier")
state_gateway_id = state_data.get("gateway_id")

# Validate gateway ID matches
if state_gateway_id != gateway_id:
if state_gateway_id and state_gateway_id != gateway_id:
raise OAuthError("State parameter gateway mismatch")
except Exception as e:
# Fallback for legacy state format (gateway_id_random)
logger.warning(f"Failed to decode state JSON, trying legacy format: {e}")
except OAuthError as exc:
logger.warning(f"Failed to decode signed state payload, trying legacy format: {exc}")
app_user_email = None
state_gateway_id = None

# Exchange code for tokens
token_response = await self._exchange_code_for_tokens(credentials, code)
token_response = await self._exchange_code_for_tokens(credentials, code, state=state, code_verifier=code_verifier)

# Extract user information from token response
user_id = self._extract_user_id(token_response, credentials)
Expand Down Expand Up @@ -489,18 +475,28 @@ async def get_access_token_for_user(self, gateway_id: str, app_user_email: str)
return await self.token_storage.get_user_token(gateway_id, app_user_email)
return None

def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str:
def _generate_state(self, gateway_id: str, app_user_email: str = None, code_verifier: Optional[str] = None) -> tuple[str, str]:
"""Generate a unique state parameter with user context for CSRF protection.

Args:
gateway_id: ID of the gateway
app_user_email: MCP Gateway user email (optional but recommended)
code_verifier: Optional PKCE code verifier to embed in state

Returns:
Unique state string with embedded user context and HMAC signature
Tuple containing the encoded state string and the PKCE code verifier used
"""
if not code_verifier:
code_verifier = secrets.token_urlsafe(64)

# Include user email in state for secure user association
state_data = {"gateway_id": gateway_id, "app_user_email": app_user_email, "nonce": secrets.token_urlsafe(16), "timestamp": datetime.now(timezone.utc).isoformat()}
state_data = {
"gateway_id": gateway_id,
"app_user_email": app_user_email,
"nonce": secrets.token_urlsafe(16),
"timestamp": datetime.now(timezone.utc).isoformat(),
"code_verifier": code_verifier,
}

# Encode state as JSON
state_json = json.dumps(state_data, separators=(",", ":"))
Expand All @@ -514,7 +510,39 @@ def _generate_state(self, gateway_id: str, app_user_email: str = None) -> str:
state_with_sig = state_bytes + signature
state_encoded = base64.urlsafe_b64encode(state_with_sig).decode()

return state_encoded
return state_encoded, code_verifier

@staticmethod
def _compute_code_challenge(code_verifier: str) -> str:
"""Compute a PKCE S256 code challenge from a verifier."""

digest = hashlib.sha256(code_verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

def _decode_state_payload(self, state: str) -> Dict[str, Any]:
"""Decode and verify a state payload generated by this manager."""

try:
state_with_sig = base64.urlsafe_b64decode(state.encode())
except Exception as exc: # pylint: disable=broad-except
raise OAuthError("Invalid state parameter encoding") from exc

if len(state_with_sig) <= 32:
raise OAuthError("State parameter is malformed")

state_bytes = state_with_sig[:-32]
received_signature = state_with_sig[-32:]

secret_key = self.settings.auth_encryption_secret.encode() if self.settings.auth_encryption_secret else b"default-secret-key"
expected_signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest()

if not hmac.compare_digest(received_signature, expected_signature):
raise OAuthError("Invalid state signature - possible CSRF attack")

try:
return json.loads(state_bytes.decode())
except Exception as exc: # pylint: disable=broad-except
raise OAuthError("Invalid state payload") from exc

async def _store_authorization_state(self, gateway_id: str, state: str) -> None:
"""Store authorization state for validation with TTL.
Expand Down Expand Up @@ -683,12 +711,13 @@ async def _validate_authorization_state(self, gateway_id: str, state: str) -> bo
logger.debug(f"Successfully validated OAuth state from memory for gateway {gateway_id}")
return True

def _create_authorization_url(self, credentials: Dict[str, Any], state: str) -> tuple[str, str]:
def _create_authorization_url(self, credentials: Dict[str, Any], state: str, code_challenge: Optional[str] = None) -> tuple[str, str]:
"""Create authorization URL with state parameter.

Args:
credentials: OAuth configuration
state: State parameter for CSRF protection
code_challenge: Optional PKCE code challenge

Returns:
Tuple of (authorization_url, state)
Expand All @@ -701,17 +730,24 @@ def _create_authorization_url(self, credentials: Dict[str, Any], state: str) ->
# Create OAuth2 session
oauth = OAuth2Session(client_id, redirect_uri=redirect_uri, scope=scopes)

extra_params: Dict[str, Any] = {}
if code_challenge:
extra_params["code_challenge"] = code_challenge
extra_params["code_challenge_method"] = "S256"

# Generate authorization URL with state for CSRF protection
auth_url, state = oauth.authorization_url(authorization_url, state=state)
auth_url, state = oauth.authorization_url(authorization_url, state=state, **extra_params)

return auth_url, state

async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str) -> Dict[str, Any]:
async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str, state: Optional[str] = None, code_verifier: Optional[str] = None) -> Dict[str, Any]:
"""Exchange authorization code for tokens.

Args:
credentials: OAuth configuration
code: Authorization code from callback
state: Optional state parameter echoed back by provider
code_verifier: Optional PKCE code verifier for S256 exchange

Returns:
Token response dictionary
Expand Down Expand Up @@ -747,23 +783,30 @@ async def _exchange_code_for_tokens(self, credentials: Dict[str, Any], code: str
"client_secret": client_secret,
}

if state:
token_data["state"] = state

if code_verifier:
token_data["code_verifier"] = code_verifier

# Exchange code for token with retries
for attempt in range(self.max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(token_url, data=token_data, timeout=aiohttp.ClientTimeout(total=self.request_timeout)) as response:
async with session.post(
token_url,
data=token_data,
timeout=aiohttp.ClientTimeout(total=self.request_timeout),
headers={"Accept": "application/json"},
) as response:
response.raise_for_status()

# GitHub returns form-encoded responses, not JSON
content_type = response.headers.get("content-type", "")
if "application/x-www-form-urlencoded" in content_type:
# Parse form-encoded response
text_response = await response.text()
token_response = {}
for pair in text_response.split("&"):
if "=" in pair:
key, value = pair.split("=", 1)
token_response[key] = value
token_response = {key: value for key, value in parse_qsl(text_response, keep_blank_values=True)}
else:
# Try JSON response
try:
Expand Down
38 changes: 22 additions & 16 deletions tests/unit/mcpgateway/test_oauth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,23 +560,27 @@ async def test_initiate_authorization_code_flow_success(self):
}

with patch.object(manager, '_generate_state') as mock_generate_state:
mock_generate_state.return_value = "state123"
mock_generate_state.return_value = ("state123", "verifier123")

with patch.object(manager, '_store_authorization_state') as mock_store_state:
with patch.object(manager, '_create_authorization_url') as mock_create_url:
mock_create_url.return_value = ("https://oauth.example.com/authorize?state=state123", "state123")
with patch.object(manager, '_compute_code_challenge') as mock_code_challenge:
mock_code_challenge.return_value = "challenge123"

result = await manager.initiate_authorization_code_flow(gateway_id, credentials, app_user_email="[email protected]")
with patch.object(manager, '_store_authorization_state') as mock_store_state:
with patch.object(manager, '_create_authorization_url') as mock_create_url:
mock_create_url.return_value = ("https://oauth.example.com/authorize?state=state123", "state123")

expected = {
"authorization_url": "https://oauth.example.com/authorize?state=state123",
"state": "state123",
"gateway_id": "gateway123"
}
assert result == expected
mock_generate_state.assert_called_once_with(gateway_id, "[email protected]")
mock_store_state.assert_called_once_with(gateway_id, "state123")
mock_create_url.assert_called_once_with(credentials, "state123")
result = await manager.initiate_authorization_code_flow(gateway_id, credentials, app_user_email="[email protected]")

expected = {
"authorization_url": "https://oauth.example.com/authorize?state=state123",
"state": "state123",
"gateway_id": "gateway123"
}
assert result == expected
mock_generate_state.assert_called_once_with(gateway_id, "[email protected]")
mock_code_challenge.assert_called_once_with("verifier123")
mock_store_state.assert_called_once_with(gateway_id, "state123")
mock_create_url.assert_called_once_with(credentials, "state123", "challenge123")

@pytest.mark.asyncio
async def test_complete_authorization_code_flow_success(self):
Expand Down Expand Up @@ -726,7 +730,7 @@ def test_generate_state_format(self):

manager = OAuthManager()

state = manager._generate_state("gateway123", "[email protected]")
state, code_verifier = manager._generate_state("gateway123", "[email protected]")

# State is now base64 encoded JSON with HMAC signature
state_with_sig = base64.urlsafe_b64decode(state.encode())
Expand All @@ -747,10 +751,12 @@ def test_generate_state_format(self):
assert decoded["app_user_email"] == "[email protected]"
assert "nonce" in decoded
assert "timestamp" in decoded
assert decoded["code_verifier"] == code_verifier

# Should generate different states each time (different nonce)
state2 = manager._generate_state("gateway123", "[email protected]")
state2, verifier2 = manager._generate_state("gateway123", "[email protected]")
assert state != state2
assert code_verifier != verifier2

@pytest.mark.asyncio
async def test_store_authorization_state(self):
Expand Down
Loading