diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9e9389ac1..b4220e7e0 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -170,6 +170,9 @@ async def initialize(self) -> types.InitializeResult: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + if self._client_info is DEFAULT_CLIENT_INFO: + self._client_info = result.serverInfo + await self.send_notification(types.ClientNotification(types.InitializedNotification())) return result diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f2135e455..f479853fd 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -175,6 +175,8 @@ async def mock_server(): # Assert that the custom client info was sent assert received_client_info == custom_client_info + # Assert that the client info was not replaced with server info after initialization + assert session._client_info == custom_client_info @pytest.mark.anyio @@ -183,6 +185,7 @@ async def test_client_session_default_client_info(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_client_info = None + received_server_info = None async def mock_server(): nonlocal received_client_info @@ -231,10 +234,13 @@ async def mock_server(): server_to_client_receive, ): tg.start_soon(mock_server) - await session.initialize() + result = await session.initialize() + received_server_info = result.serverInfo # Assert that the default client info was sent assert received_client_info == DEFAULT_CLIENT_INFO + # Assert that the default client info was replaced with server info after initialization + assert session._client_info == received_server_info @pytest.mark.anyio