22Tests for refactored OAuth client authentication implementation.
33"""
44
5+ import base64
6+ import json
57import time
68from unittest import mock
9+ from urllib .parse import unquote
710
811import httpx
912import pytest
1013from inline_snapshot import Is , snapshot
1114from pydantic import AnyHttpUrl , AnyUrl
1215
13- from mcp .client .auth import OAuthClientProvider , PKCEParameters
14- from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken , ProtectedResourceMetadata
16+ from mcp .client .auth import OAuthClientProvider , OAuthRegistrationError , PKCEParameters
17+ from mcp .shared .auth import (
18+ OAuthClientInformationFull ,
19+ OAuthClientMetadata ,
20+ OAuthMetadata ,
21+ OAuthToken ,
22+ ProtectedResourceMetadata ,
23+ )
1524
1625
1726class MockTokenStorage :
@@ -415,6 +424,43 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli
415424 request = await oauth_provider ._register_client ()
416425 assert request is None
417426
427+ @pytest .mark .anyio
428+ async def test_register_client_none_auth_method_with_server_metadata (self , oauth_provider : OAuthClientProvider ):
429+ """Test that token_endpoint_auth_method=None selects from server's supported methods."""
430+ # Set server metadata with specific supported methods
431+ oauth_provider .context .oauth_metadata = OAuthMetadata (
432+ issuer = AnyHttpUrl ("https://auth.example.com" ),
433+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
434+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
435+ token_endpoint_auth_methods_supported = ["client_secret_post" ],
436+ )
437+ # Ensure client_metadata has None for token_endpoint_auth_method
438+ assert oauth_provider .context .client_metadata .token_endpoint_auth_method is None
439+
440+ request = await oauth_provider ._register_client ()
441+ assert request is not None
442+
443+ body = json .loads (request .content )
444+ assert body ["token_endpoint_auth_method" ] == "client_secret_post"
445+
446+ @pytest .mark .anyio
447+ async def test_register_client_none_auth_method_no_compatible (self , oauth_provider : OAuthClientProvider ):
448+ """Test that registration raises error when no compatible auth methods."""
449+ # Set server metadata with unsupported methods only
450+ oauth_provider .context .oauth_metadata = OAuthMetadata (
451+ issuer = AnyHttpUrl ("https://auth.example.com" ),
452+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
453+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
454+ token_endpoint_auth_methods_supported = ["private_key_jwt" , "client_secret_jwt" ],
455+ )
456+ assert oauth_provider .context .client_metadata .token_endpoint_auth_method is None
457+
458+ with pytest .raises (OAuthRegistrationError ) as exc_info :
459+ await oauth_provider ._register_client ()
460+
461+ assert "No compatible authentication methods" in str (exc_info .value )
462+ assert "private_key_jwt" in str (exc_info .value )
463+
418464 @pytest .mark .anyio
419465 async def test_token_exchange_request (self , oauth_provider : OAuthClientProvider ):
420466 """Test token exchange request building."""
@@ -423,6 +469,7 @@ async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider)
423469 client_id = "test_client" ,
424470 client_secret = "test_secret" ,
425471 redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
472+ token_endpoint_auth_method = "client_secret_post" ,
426473 )
427474
428475 request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
@@ -448,6 +495,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
448495 client_id = "test_client" ,
449496 client_secret = "test_secret" ,
450497 redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
498+ token_endpoint_auth_method = "client_secret_post" ,
451499 )
452500
453501 request = await oauth_provider ._refresh_token ()
@@ -463,6 +511,114 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
463511 assert "client_id=test_client" in content
464512 assert "client_secret=test_secret" in content
465513
514+ @pytest .mark .anyio
515+ async def test_basic_auth_token_exchange (self , oauth_provider : OAuthClientProvider ):
516+ """Test token exchange with client_secret_basic authentication."""
517+ # Set up OAuth metadata to support basic auth
518+ oauth_provider .context .oauth_metadata = OAuthMetadata (
519+ issuer = AnyHttpUrl ("https://auth.example.com" ),
520+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
521+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
522+ token_endpoint_auth_methods_supported = ["client_secret_basic" , "client_secret_post" ],
523+ )
524+
525+ client_id_raw = "test@client" # Include special character to test URL encoding
526+ client_secret_raw = "test:secret" # Include colon to test URL encoding
527+
528+ oauth_provider .context .client_info = OAuthClientInformationFull (
529+ client_id = client_id_raw ,
530+ client_secret = client_secret_raw ,
531+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
532+ token_endpoint_auth_method = "client_secret_basic" ,
533+ )
534+
535+ request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
536+
537+ # Should use basic auth (registered method)
538+ assert "Authorization" in request .headers
539+ assert request .headers ["Authorization" ].startswith ("Basic " )
540+
541+ # Decode and verify credentials are properly URL-encoded
542+ encoded_creds = request .headers ["Authorization" ][6 :] # Remove "Basic " prefix
543+ decoded = base64 .b64decode (encoded_creds ).decode ()
544+ client_id , client_secret = decoded .split (":" , 1 )
545+
546+ # Check URL encoding was applied
547+ assert client_id == "test%40client" # @ should be encoded as %40
548+ assert client_secret == "test%3Asecret" # : should be encoded as %3A
549+
550+ # Verify decoded values match original
551+ assert unquote (client_id ) == client_id_raw
552+ assert unquote (client_secret ) == client_secret_raw
553+
554+ # client_secret should NOT be in body for basic auth
555+ content = request .content .decode ()
556+ assert "client_secret=" not in content
557+ assert "client_id=test%40client" in content # client_id still in body
558+
559+ @pytest .mark .anyio
560+ async def test_basic_auth_refresh_token (self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
561+ """Test token refresh with client_secret_basic authentication."""
562+ oauth_provider .context .current_tokens = valid_tokens
563+
564+ # Set up OAuth metadata to only support basic auth
565+ oauth_provider .context .oauth_metadata = OAuthMetadata (
566+ issuer = AnyHttpUrl ("https://auth.example.com" ),
567+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
568+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
569+ token_endpoint_auth_methods_supported = ["client_secret_basic" ],
570+ )
571+
572+ client_id = "test_client"
573+ client_secret = "test_secret"
574+ oauth_provider .context .client_info = OAuthClientInformationFull (
575+ client_id = client_id ,
576+ client_secret = client_secret ,
577+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
578+ token_endpoint_auth_method = "client_secret_basic" ,
579+ )
580+
581+ request = await oauth_provider ._refresh_token ()
582+
583+ assert "Authorization" in request .headers
584+ assert request .headers ["Authorization" ].startswith ("Basic " )
585+
586+ encoded_creds = request .headers ["Authorization" ][6 :]
587+ decoded = base64 .b64decode (encoded_creds ).decode ()
588+ assert decoded == f"{ client_id } :{ client_secret } "
589+
590+ # client_secret should NOT be in body
591+ content = request .content .decode ()
592+ assert "client_secret=" not in content
593+
594+ @pytest .mark .anyio
595+ async def test_none_auth_method (self , oauth_provider : OAuthClientProvider ):
596+ """Test 'none' authentication method (public client)."""
597+ oauth_provider .context .oauth_metadata = OAuthMetadata (
598+ issuer = AnyHttpUrl ("https://auth.example.com" ),
599+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
600+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
601+ token_endpoint_auth_methods_supported = ["none" ],
602+ )
603+
604+ client_id = "public_client"
605+ oauth_provider .context .client_info = OAuthClientInformationFull (
606+ client_id = client_id ,
607+ client_secret = None , # No secret for public client
608+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
609+ token_endpoint_auth_method = "none" ,
610+ )
611+
612+ request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
613+
614+ # Should NOT have Authorization header
615+ assert "Authorization" not in request .headers
616+
617+ # Should NOT have client_secret in body
618+ content = request .content .decode ()
619+ assert "client_secret=" not in content
620+ assert "client_id=public_client" in content
621+
466622
467623class TestProtectedResourceMetadata :
468624 """Test protected resource handling."""
0 commit comments