Skip to content

Commit 3390e49

Browse files
cbcoutinhoclaudepcarleton
authored
Implement SEP-985: OAuth Protected Resource Metadata discovery fallback (#1548)
Co-authored-by: Claude <[email protected]> Co-authored-by: Paul Carleton <[email protected]>
1 parent de2289d commit 3390e49

File tree

2 files changed

+252
-30
lines changed

2 files changed

+252
-30
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,6 @@ class OAuthContext:
108108
# State
109109
lock: anyio.Lock = field(default_factory=anyio.Lock)
110110

111-
# Discovery state for fallback support
112-
discovery_base_url: str | None = None
113-
discovery_pathname: str | None = None
114-
115111
def get_authorization_base_url(self, server_url: str) -> str:
116112
"""Extract base URL by removing path component."""
117113
parsed = urlparse(server_url)
@@ -204,6 +200,43 @@ def __init__(
204200
)
205201
self._initialized = False
206202

203+
def _build_protected_resource_discovery_urls(self, init_response: httpx.Response) -> list[str]:
204+
"""
205+
Build ordered list of URLs to try for protected resource metadata discovery.
206+
207+
Per SEP-985, the client MUST:
208+
1. Try resource_metadata from WWW-Authenticate header (if present)
209+
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
210+
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
211+
212+
Args:
213+
init_response: The initial 401 response from the server
214+
215+
Returns:
216+
Ordered list of URLs to try for discovery
217+
"""
218+
urls: list[str] = []
219+
220+
# Priority 1: WWW-Authenticate header with resource_metadata parameter
221+
www_auth_url = self._extract_resource_metadata_from_www_auth(init_response)
222+
if www_auth_url:
223+
urls.append(www_auth_url)
224+
225+
# Priority 2-3: Well-known URIs (RFC 9728)
226+
parsed = urlparse(self.context.server_url)
227+
base_url = f"{parsed.scheme}://{parsed.netloc}"
228+
229+
# Priority 2: Path-based well-known URI (if server has a path component)
230+
if parsed.path and parsed.path != "/":
231+
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
232+
urls.append(path_based_url)
233+
234+
# Priority 3: Root-based well-known URI
235+
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
236+
urls.append(root_based_url)
237+
238+
return urls
239+
207240
def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None:
208241
"""
209242
Extract field from WWW-Authenticate header.
@@ -246,30 +279,34 @@ def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | N
246279
"""
247280
return self._extract_field_from_www_auth(init_response, "scope")
248281

249-
async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
250-
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
251-
url = self._extract_resource_metadata_from_www_auth(init_response)
252-
253-
if not url:
254-
# Fallback to well-known discovery
255-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
256-
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
282+
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
283+
"""
284+
Handle protected resource metadata discovery response.
257285
258-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
286+
Per SEP-985, supports fallback when discovery fails at one URL.
259287
260-
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
261-
"""Handle discovery response."""
288+
Returns:
289+
True if metadata was successfully discovered, False if we should try next URL
290+
"""
262291
if response.status_code == 200:
263292
try:
264293
content = await response.aread()
265294
metadata = ProtectedResourceMetadata.model_validate_json(content)
266295
self.context.protected_resource_metadata = metadata
267296
if metadata.authorization_servers:
268297
self.context.auth_server_url = str(metadata.authorization_servers[0])
298+
return True
269299

270300
except ValidationError:
271-
pass
301+
# Invalid metadata - try next URL
302+
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
303+
return False
304+
elif response.status_code == 404:
305+
# Not found - try next URL in fallback chain
306+
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
307+
return False
272308
else:
309+
# Other error - fail immediately
273310
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")
274311

275312
def _select_scopes(self, init_response: httpx.Response) -> None:
@@ -573,10 +610,20 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
573610
# Perform full OAuth flow
574611
try:
575612
# OAuth flow must be inline due to generator constraints
576-
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
577-
discovery_request = await self._discover_protected_resource(response)
578-
discovery_response = yield discovery_request
579-
await self._handle_protected_resource_response(discovery_response)
613+
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
614+
discovery_urls = self._build_protected_resource_discovery_urls(response)
615+
discovery_success = False
616+
for url in discovery_urls:
617+
discovery_request = httpx.Request(
618+
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
619+
)
620+
discovery_response = yield discovery_request
621+
discovery_success = await self._handle_protected_resource_response(discovery_response)
622+
if discovery_success:
623+
break
624+
625+
if not discovery_success:
626+
raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")
580627

581628
# Step 2: Apply scope selection strategy
582629
self._select_scopes(response)

tests/client/test_auth.py

Lines changed: 185 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,10 @@ class TestOAuthFlow:
241241
"""Test OAuth flow methods."""
242242

243243
@pytest.mark.anyio
244-
async def test_discover_protected_resource_request(
244+
async def test_build_protected_resource_discovery_urls(
245245
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
246246
):
247-
"""Test protected resource discovery request building maintains backward compatibility."""
247+
"""Test protected resource metadata discovery URL building with fallback."""
248248

249249
async def redirect_handler(url: str) -> None:
250250
pass
@@ -265,20 +265,19 @@ async def callback_handler() -> tuple[str, str | None]:
265265
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
266266
)
267267

268-
request = await provider._discover_protected_resource(init_response)
269-
assert request.method == "GET"
270-
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
271-
assert "mcp-protocol-version" in request.headers
268+
urls = provider._build_protected_resource_discovery_urls(init_response)
269+
assert len(urls) == 1
270+
assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource"
272271

273272
# Test with WWW-Authenticate header
274273
init_response.headers["WWW-Authenticate"] = (
275274
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
276275
)
277276

278-
request = await provider._discover_protected_resource(init_response)
279-
assert request.method == "GET"
280-
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
281-
assert "mcp-protocol-version" in request.headers
277+
urls = provider._build_protected_resource_discovery_urls(init_response)
278+
assert len(urls) == 2
279+
assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path"
280+
assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource"
282281

283282
@pytest.mark.anyio
284283
def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider):
@@ -1034,6 +1033,182 @@ def test_build_metadata(
10341033
)
10351034

10361035

1036+
class TestSEP985Discovery:
1037+
"""Test SEP-985 protected resource metadata discovery with fallback."""
1038+
1039+
@pytest.mark.anyio
1040+
async def test_path_based_fallback_when_no_www_authenticate(
1041+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1042+
):
1043+
"""Test that client falls back to path-based well-known URI when WWW-Authenticate is absent."""
1044+
1045+
async def redirect_handler(url: str) -> None:
1046+
pass
1047+
1048+
async def callback_handler() -> tuple[str, str | None]:
1049+
return "test_auth_code", "test_state"
1050+
1051+
provider = OAuthClientProvider(
1052+
server_url="https://api.example.com/v1/mcp",
1053+
client_metadata=client_metadata,
1054+
storage=mock_storage,
1055+
redirect_handler=redirect_handler,
1056+
callback_handler=callback_handler,
1057+
)
1058+
1059+
# Test with 401 response without WWW-Authenticate header
1060+
init_response = httpx.Response(
1061+
status_code=401, headers={}, request=httpx.Request("GET", "https://api.example.com/v1/mcp")
1062+
)
1063+
1064+
# Build discovery URLs
1065+
discovery_urls = provider._build_protected_resource_discovery_urls(init_response)
1066+
1067+
# Should have path-based URL first, then root-based URL
1068+
assert len(discovery_urls) == 2
1069+
assert discovery_urls[0] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1070+
assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource"
1071+
1072+
@pytest.mark.anyio
1073+
async def test_root_based_fallback_after_path_based_404(
1074+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1075+
):
1076+
"""Test that client falls back to root-based URI when path-based returns 404."""
1077+
1078+
async def redirect_handler(url: str) -> None:
1079+
pass
1080+
1081+
async def callback_handler() -> tuple[str, str | None]:
1082+
return "test_auth_code", "test_state"
1083+
1084+
provider = OAuthClientProvider(
1085+
server_url="https://api.example.com/v1/mcp",
1086+
client_metadata=client_metadata,
1087+
storage=mock_storage,
1088+
redirect_handler=redirect_handler,
1089+
callback_handler=callback_handler,
1090+
)
1091+
1092+
# Ensure no tokens are stored
1093+
provider.context.current_tokens = None
1094+
provider.context.token_expiry_time = None
1095+
provider._initialized = True
1096+
1097+
# Mock client info to skip DCR
1098+
provider.context.client_info = OAuthClientInformationFull(
1099+
client_id="existing_client",
1100+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
1101+
)
1102+
1103+
# Create a test request
1104+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
1105+
1106+
# Mock the auth flow
1107+
auth_flow = provider.async_auth_flow(test_request)
1108+
1109+
# First request should be the original request without auth header
1110+
request = await auth_flow.__anext__()
1111+
assert "Authorization" not in request.headers
1112+
1113+
# Send a 401 response without WWW-Authenticate header
1114+
response = httpx.Response(401, headers={}, request=test_request)
1115+
1116+
# Next request should be to discover protected resource metadata (path-based)
1117+
discovery_request_1 = await auth_flow.asend(response)
1118+
assert str(discovery_request_1.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1119+
assert discovery_request_1.method == "GET"
1120+
1121+
# Send 404 response for path-based discovery
1122+
discovery_response_1 = httpx.Response(404, request=discovery_request_1)
1123+
1124+
# Next request should be to root-based well-known URI
1125+
discovery_request_2 = await auth_flow.asend(discovery_response_1)
1126+
assert str(discovery_request_2.url) == "https://api.example.com/.well-known/oauth-protected-resource"
1127+
assert discovery_request_2.method == "GET"
1128+
1129+
# Send successful discovery response
1130+
discovery_response_2 = httpx.Response(
1131+
200,
1132+
content=(
1133+
b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}'
1134+
),
1135+
request=discovery_request_2,
1136+
)
1137+
1138+
# Mock the rest of the OAuth flow
1139+
provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
1140+
1141+
# Next should be OAuth metadata discovery
1142+
oauth_metadata_request = await auth_flow.asend(discovery_response_2)
1143+
assert oauth_metadata_request.method == "GET"
1144+
1145+
# Complete the flow
1146+
oauth_metadata_response = httpx.Response(
1147+
200,
1148+
content=(
1149+
b'{"issuer": "https://auth.example.com", '
1150+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
1151+
b'"token_endpoint": "https://auth.example.com/token"}'
1152+
),
1153+
request=oauth_metadata_request,
1154+
)
1155+
1156+
token_request = await auth_flow.asend(oauth_metadata_response)
1157+
token_response = httpx.Response(
1158+
200,
1159+
content=(
1160+
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
1161+
b'"refresh_token": "new_refresh_token"}'
1162+
),
1163+
request=token_request,
1164+
)
1165+
1166+
final_request = await auth_flow.asend(token_response)
1167+
final_response = httpx.Response(200, request=final_request)
1168+
try:
1169+
await auth_flow.asend(final_response)
1170+
except StopAsyncIteration:
1171+
pass
1172+
1173+
@pytest.mark.anyio
1174+
async def test_www_authenticate_takes_priority_over_well_known(
1175+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1176+
):
1177+
"""Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs."""
1178+
1179+
async def redirect_handler(url: str) -> None:
1180+
pass
1181+
1182+
async def callback_handler() -> tuple[str, str | None]:
1183+
return "test_auth_code", "test_state"
1184+
1185+
provider = OAuthClientProvider(
1186+
server_url="https://api.example.com/v1/mcp",
1187+
client_metadata=client_metadata,
1188+
storage=mock_storage,
1189+
redirect_handler=redirect_handler,
1190+
callback_handler=callback_handler,
1191+
)
1192+
1193+
# Test with 401 response with WWW-Authenticate header
1194+
init_response = httpx.Response(
1195+
status_code=401,
1196+
headers={
1197+
"WWW-Authenticate": 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"'
1198+
},
1199+
request=httpx.Request("GET", "https://api.example.com/v1/mcp"),
1200+
)
1201+
1202+
# Build discovery URLs
1203+
discovery_urls = provider._build_protected_resource_discovery_urls(init_response)
1204+
1205+
# Should have WWW-Authenticate URL first, then fallback URLs
1206+
assert len(discovery_urls) == 3
1207+
assert discovery_urls[0] == "https://custom.example.com/.well-known/oauth-protected-resource"
1208+
assert discovery_urls[1] == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1209+
assert discovery_urls[2] == "https://api.example.com/.well-known/oauth-protected-resource"
1210+
1211+
10371212
class TestWWWAuthenticate:
10381213
"""Test WWW-Authenticate header parsing functionality."""
10391214

0 commit comments

Comments
 (0)