diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 9e176980a..9b950db72 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -108,10 +108,6 @@ class OAuthContext: # State lock: anyio.Lock = field(default_factory=anyio.Lock) - # Discovery state for fallback support - discovery_base_url: str | None = None - discovery_pathname: str | None = None - def get_authorization_base_url(self, server_url: str) -> str: """Extract base URL by removing path component.""" parsed = urlparse(server_url) @@ -204,6 +200,43 @@ def __init__( ) self._initialized = False + def _build_protected_resource_discovery_urls(self, init_response: httpx.Response) -> list[str]: + """ + Build ordered list of URLs to try for protected resource metadata discovery. + + Per SEP-985, the client MUST: + 1. Try resource_metadata from WWW-Authenticate header (if present) + 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} + 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource + + Args: + init_response: The initial 401 response from the server + + Returns: + Ordered list of URLs to try for discovery + """ + urls: list[str] = [] + + # Priority 1: WWW-Authenticate header with resource_metadata parameter + www_auth_url = self._extract_resource_metadata_from_www_auth(init_response) + if www_auth_url: + urls.append(www_auth_url) + + # Priority 2-3: Well-known URIs (RFC 9728) + parsed = urlparse(self.context.server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # Priority 2: Path-based well-known URI (if server has a path component) + if parsed.path and parsed.path != "/": + path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") + urls.append(path_based_url) + + # Priority 3: Root-based well-known URI + root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") + urls.append(root_based_url) + + return urls + def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None: """ Extract field from WWW-Authenticate header. @@ -246,19 +279,15 @@ def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | N """ return self._extract_field_from_www_auth(init_response, "scope") - async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: - # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response - url = self._extract_resource_metadata_from_www_auth(init_response) - - if not url: - # Fallback to well-known discovery - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: + """ + Handle protected resource metadata discovery response. - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + Per SEP-985, supports fallback when discovery fails at one URL. - async def _handle_protected_resource_response(self, response: httpx.Response) -> None: - """Handle discovery response.""" + Returns: + True if metadata was successfully discovered, False if we should try next URL + """ if response.status_code == 200: try: content = await response.aread() @@ -266,10 +295,18 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> self.context.protected_resource_metadata = metadata if metadata.authorization_servers: self.context.auth_server_url = str(metadata.authorization_servers[0]) + return True except ValidationError: - pass + # Invalid metadata - try next URL + logger.warning(f"Invalid protected resource metadata at {response.request.url}") + return False + elif response.status_code == 404: + # Not found - try next URL in fallback chain + logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") + return False else: + # Other error - fail immediately raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}") def _select_scopes(self, init_response: httpx.Response) -> None: @@ -573,10 +610,20 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Perform full OAuth flow try: # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) - discovery_request = await self._discover_protected_resource(response) - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) + # Step 1: Discover protected resource metadata (SEP-985 with fallback support) + discovery_urls = self._build_protected_resource_discovery_urls(response) + discovery_success = False + for url in discovery_urls: + discovery_request = httpx.Request( + "GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} + ) + discovery_response = yield discovery_request + discovery_success = await self._handle_protected_resource_response(discovery_response) + if discovery_success: + break + + if not discovery_success: + raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found") # Step 2: Apply scope selection strategy self._select_scopes(response) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index a1a0a3fde..8cea6cefd 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -241,10 +241,10 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request( + async def test_build_protected_resource_discovery_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage ): - """Test protected resource discovery request building maintains backward compatibility.""" + """Test protected resource metadata discovery URL building with fallback.""" async def redirect_handler(url: str) -> None: pass @@ -265,20 +265,19 @@ async def callback_handler() -> tuple[str, str | None]: status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") ) - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" - assert "mcp-protocol-version" in request.headers + urls = provider._build_protected_resource_discovery_urls(init_response) + assert len(urls) == 1 + assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource" # Test with WWW-Authenticate header init_response.headers["WWW-Authenticate"] = ( 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' ) - request = await provider._discover_protected_resource(init_response) - assert request.method == "GET" - assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path" - assert "mcp-protocol-version" in request.headers + urls = provider._build_protected_resource_discovery_urls(init_response) + assert len(urls) == 2 + assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path" + assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): @@ -1034,6 +1033,182 @@ def test_build_metadata( ) +class TestSEP985Discovery: + """Test SEP-985 protected resource metadata discovery with fallback.""" + + @pytest.mark.anyio + async def test_path_based_fallback_when_no_www_authenticate( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Test with 401 response without WWW-Authenticate header + init_response = httpx.Response( + status_code=401, headers={}, request=httpx.Request("GET", "https://api.example.com/v1/mcp") + ) + + # Build discovery URLs + discovery_urls = provider._build_protected_resource_discovery_urls(init_response) + + # Should have path-based URL first, then root-based URL + assert len(discovery_urls) == 2 + assert discovery_urls[0] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" + + @pytest.mark.anyio + async def test_root_based_fallback_after_path_based_404( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that client falls back to root-based URI when path-based returns 404.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Ensure no tokens are stored + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + # Mock client info to skip DCR + provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Create a test request + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + + # Mock the auth flow + auth_flow = provider.async_auth_flow(test_request) + + # First request should be the original request without auth header + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send a 401 response without WWW-Authenticate header + response = httpx.Response(401, headers={}, request=test_request) + + # Next request should be to discover protected resource metadata (path-based) + discovery_request_1 = await auth_flow.asend(response) + assert str(discovery_request_1.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert discovery_request_1.method == "GET" + + # Send 404 response for path-based discovery + discovery_response_1 = httpx.Response(404, request=discovery_request_1) + + # Next request should be to root-based well-known URI + discovery_request_2 = await auth_flow.asend(discovery_response_1) + assert str(discovery_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource" + assert discovery_request_2.method == "GET" + + # Send successful discovery response + discovery_response_2 = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}' + ), + request=discovery_request_2, + ) + + # Mock the rest of the OAuth flow + provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + + # Next should be OAuth metadata discovery + oauth_metadata_request = await auth_flow.asend(discovery_response_2) + assert oauth_metadata_request.method == "GET" + + # Complete the flow + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token"}' + ), + request=oauth_metadata_request, + ) + + token_request = await auth_flow.asend(oauth_metadata_response) + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, ' + b'"refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_www_authenticate_takes_priority_over_well_known( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" + + async def redirect_handler(url: str) -> None: + pass + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + # Test with 401 response with WWW-Authenticate header + init_response = httpx.Response( + status_code=401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"' + }, + request=httpx.Request("GET", "https://api.example.com/v1/mcp"), + ) + + # Build discovery URLs + discovery_urls = provider._build_protected_resource_discovery_urls(init_response) + + # Should have WWW-Authenticate URL first, then fallback URLs + assert len(discovery_urls) == 3 + assert discovery_urls[0] == "https://custom.example.com/.well-known/oauth-protected-resource" + assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert discovery_urls[2] == "https://api.example.com/.well-known/oauth-protected-resource" + + class TestWWWAuthenticate: """Test WWW-Authenticate header parsing functionality."""