Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 67 additions & 20 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ class OAuthContext:
# State
lock: anyio.Lock = field(default_factory=anyio.Lock)

# Discovery state for fallback support
discovery_base_url: str | None = None
discovery_pathname: str | None = None

def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
parsed = urlparse(server_url)
Expand Down Expand Up @@ -204,6 +200,43 @@ def __init__(
)
self._initialized = False

def _build_protected_resource_discovery_urls(self, init_response: httpx.Response) -> list[str]:
"""
Build ordered list of URLs to try for protected resource metadata discovery.
Per SEP-985, the client MUST:
1. Try resource_metadata from WWW-Authenticate header (if present)
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
Args:
init_response: The initial 401 response from the server
Returns:
Ordered list of URLs to try for discovery
"""
urls: list[str] = []

# Priority 1: WWW-Authenticate header with resource_metadata parameter
www_auth_url = self._extract_resource_metadata_from_www_auth(init_response)
if www_auth_url:
urls.append(www_auth_url)

# Priority 2-3: Well-known URIs (RFC 9728)
parsed = urlparse(self.context.server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"

# Priority 2: Path-based well-known URI (if server has a path component)
if parsed.path and parsed.path != "/":
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
urls.append(path_based_url)

# Priority 3: Root-based well-known URI
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
urls.append(root_based_url)

return urls

def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None:
"""
Extract field from WWW-Authenticate header.
Expand Down Expand Up @@ -246,30 +279,34 @@ def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | N
"""
return self._extract_field_from_www_auth(init_response, "scope")

async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
url = self._extract_resource_metadata_from_www_auth(init_response)

if not url:
# Fallback to well-known discovery
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
"""
Handle protected resource metadata discovery response.
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
Per SEP-985, supports fallback when discovery fails at one URL.
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
"""Handle discovery response."""
Returns:
True if metadata was successfully discovered, False if we should try next URL
"""
if response.status_code == 200:
try:
content = await response.aread()
metadata = ProtectedResourceMetadata.model_validate_json(content)
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers:
self.context.auth_server_url = str(metadata.authorization_servers[0])
return True

except ValidationError:
pass
# Invalid metadata - try next URL
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
return False
elif response.status_code == 404:
# Not found - try next URL in fallback chain
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
return False
else:
# Other error - fail immediately
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")

def _select_scopes(self, init_response: httpx.Response) -> None:
Expand Down Expand Up @@ -573,10 +610,20 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
discovery_request = await self._discover_protected_resource(response)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
discovery_urls = self._build_protected_resource_discovery_urls(response)
discovery_success = False
for url in discovery_urls:
discovery_request = httpx.Request(
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
)
discovery_response = yield discovery_request
discovery_success = await self._handle_protected_resource_response(discovery_response)
if discovery_success:
break

if not discovery_success:
raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")

# Step 2: Apply scope selection strategy
self._select_scopes(response)
Expand Down
195 changes: 185 additions & 10 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,10 @@ class TestOAuthFlow:
"""Test OAuth flow methods."""

@pytest.mark.anyio
async def test_discover_protected_resource_request(
async def test_build_protected_resource_discovery_urls(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
"""Test protected resource discovery request building maintains backward compatibility."""
"""Test protected resource metadata discovery URL building with fallback."""

async def redirect_handler(url: str) -> None:
pass
Expand All @@ -265,20 +265,19 @@ async def callback_handler() -> tuple[str, str | None]:
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert "mcp-protocol-version" in request.headers
urls = provider._build_protected_resource_discovery_urls(init_response)
assert len(urls) == 1
assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource"

# Test with WWW-Authenticate header
init_response.headers["WWW-Authenticate"] = (
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
assert "mcp-protocol-version" in request.headers
urls = provider._build_protected_resource_discovery_urls(init_response)
assert len(urls) == 2
assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path"
assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource"

@pytest.mark.anyio
def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider):
Expand Down Expand Up @@ -1034,6 +1033,182 @@ def test_build_metadata(
)


class TestSEP985Discovery:
"""Test SEP-985 protected resource metadata discovery with fallback."""

@pytest.mark.anyio
async def test_path_based_fallback_when_no_www_authenticate(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
"""Test that client falls back to path-based well-known URI when WWW-Authenticate is absent."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Test with 401 response without WWW-Authenticate header
init_response = httpx.Response(
status_code=401, headers={}, request=httpx.Request("GET", "https://api.example.com/v1/mcp")
)

# Build discovery URLs
discovery_urls = provider._build_protected_resource_discovery_urls(init_response)

# Should have path-based URL first, then root-based URL
assert len(discovery_urls) == 2
assert discovery_urls[0] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource"

@pytest.mark.anyio
async def test_root_based_fallback_after_path_based_404(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
"""Test that client falls back to root-based URI when path-based returns 404."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Ensure no tokens are stored
provider.context.current_tokens = None
provider.context.token_expiry_time = None
provider._initialized = True

# Mock client info to skip DCR
provider.context.client_info = OAuthClientInformationFull(
client_id="existing_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

# Create a test request
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")

# Mock the auth flow
auth_flow = provider.async_auth_flow(test_request)

# First request should be the original request without auth header
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Send a 401 response without WWW-Authenticate header
response = httpx.Response(401, headers={}, request=test_request)

# Next request should be to discover protected resource metadata (path-based)
discovery_request_1 = await auth_flow.asend(response)
assert str(discovery_request_1.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
assert discovery_request_1.method == "GET"

# Send 404 response for path-based discovery
discovery_response_1 = httpx.Response(404, request=discovery_request_1)

# Next request should be to root-based well-known URI
discovery_request_2 = await auth_flow.asend(discovery_response_1)
assert str(discovery_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert discovery_request_2.method == "GET"

# Send successful discovery response
discovery_response_2 = httpx.Response(
200,
content=(
b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}'
),
request=discovery_request_2,
)

# Mock the rest of the OAuth flow
provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))

# Next should be OAuth metadata discovery
oauth_metadata_request = await auth_flow.asend(discovery_response_2)
assert oauth_metadata_request.method == "GET"

# Complete the flow
oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token"}'
),
request=oauth_metadata_request,
)

token_request = await auth_flow.asend(oauth_metadata_response)
token_response = httpx.Response(
200,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=token_request,
)

final_request = await auth_flow.asend(token_response)
final_response = httpx.Response(200, request=final_request)
try:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass

@pytest.mark.anyio
async def test_www_authenticate_takes_priority_over_well_known(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
"""Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Test with 401 response with WWW-Authenticate header
init_response = httpx.Response(
status_code=401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"'
},
request=httpx.Request("GET", "https://api.example.com/v1/mcp"),
)

# Build discovery URLs
discovery_urls = provider._build_protected_resource_discovery_urls(init_response)

# Should have WWW-Authenticate URL first, then fallback URLs
assert len(discovery_urls) == 3
assert discovery_urls[0] == "https://custom.example.com/.well-known/oauth-protected-resource"
assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
assert discovery_urls[2] == "https://api.example.com/.well-known/oauth-protected-resource"


class TestWWWAuthenticate:
"""Test WWW-Authenticate header parsing functionality."""

Expand Down