diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 91f8576d7..11e6d265d 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -424,7 +424,10 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req token_data["client_secret"] = self.context.client_info.client_secret return httpx.Request( - "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} + "POST", + token_url, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded", "Accept": "application/json"}, ) async def _handle_token_response(self, response: httpx.Response) -> None: @@ -478,7 +481,10 @@ async def _refresh_token(self) -> httpx.Request: refresh_data["client_secret"] = self.context.client_info.client_secret return httpx.Request( - "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"} + "POST", + token_url, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded", "Accept": "application/json"}, ) async def _handle_refresh_response(self, response: httpx.Response) -> bool: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index fb1a93e39..1890ed238 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -527,6 +527,7 @@ async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider) assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.headers["Accept"] == "application/json" # Check form data content = request.content.decode() @@ -552,6 +553,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, assert request.method == "POST" assert str(request.url) == "https://api.example.com/token" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.headers["Accept"] == "application/json" # Check form data content = request.content.decode()