From 56678d8bd00f387926e7102757e0b58fa7a1edc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=90=BD=E6=9C=88?= <1569097443@qq.com> Date: Fri, 21 Nov 2025 21:23:14 +0800 Subject: [PATCH] Decouple Server from TransportInterface for reusability --- src/Server/Protocol.php | 95 +++++++++---------- src/Server/Transport/BaseTransport.php | 8 +- src/Server/Transport/InMemoryTransport.php | 4 +- .../Transport/ManagesTransportCallbacks.php | 2 +- src/Server/Transport/StdioTransport.php | 6 +- .../Transport/StreamableHttpTransport.php | 6 +- src/Server/Transport/TransportInterface.php | 2 +- tests/Unit/Server/ProtocolTest.php | 45 +++------ 8 files changed, 77 insertions(+), 91 deletions(-) diff --git a/src/Server/Protocol.php b/src/Server/Protocol.php index c3b42f58..0851e922 100644 --- a/src/Server/Protocol.php +++ b/src/Server/Protocol.php @@ -32,8 +32,8 @@ /** * @final * - * @phpstan-import-type McpFiber from \Mcp\Server\Transport\TransportInterface - * @phpstan-import-type FiberSuspend from \Mcp\Server\Transport\TransportInterface + * @phpstan-import-type McpFiber from TransportInterface + * @phpstan-import-type FiberSuspend from TransportInterface * * @author Christopher Hertel * @author Kyrian Obikwelu @@ -55,9 +55,6 @@ class Protocol /** Session key for active request meta */ public const SESSION_ACTIVE_REQUEST_META = '_mcp.active_request_meta'; - /** @var TransportInterface|null */ - private ?TransportInterface $transport = null; - /** * @param array>> $requestHandlers * @param array $notificationHandlers @@ -73,15 +70,7 @@ public function __construct( } /** - * @return TransportInterface - */ - public function getTransport(): TransportInterface - { - return $this->transport; - } - - /** - * Connect this protocol to a transport. + * Connect this protocol to transport. * * The protocol takes ownership of the transport and sets up all callbacks. * @@ -89,23 +78,17 @@ public function getTransport(): TransportInterface */ public function connect(TransportInterface $transport): void { - if ($this->transport) { - throw new \RuntimeException('Protocol already connected to a transport'); - } - - $this->transport = $transport; + $transport->onMessage($this->processInput(...)); - $this->transport->onMessage([$this, 'processInput']); + $transport->onSessionEnd($this->destroySession(...)); - $this->transport->onSessionEnd([$this, 'destroySession']); + $transport->setOutgoingMessagesProvider($this->consumeOutgoingMessages(...)); - $this->transport->setOutgoingMessagesProvider([$this, 'consumeOutgoingMessages']); + $transport->setPendingRequestsProvider($this->getPendingRequests(...)); - $this->transport->setPendingRequestsProvider([$this, 'getPendingRequests']); + $transport->setResponseFinder($this->checkResponse(...)); - $this->transport->setResponseFinder([$this, 'checkResponse']); - - $this->transport->setFiberYieldHandler([$this, 'handleFiberYield']); + $transport->setFiberYieldHandler($this->handleFiberYield(...)); $this->logger->info('Protocol connected to transport', ['transport' => $transport::class]); } @@ -114,8 +97,10 @@ public function connect(TransportInterface $transport): void * Handle an incoming message from the transport. * * This is called by the transport whenever ANY message arrives. + * + * @param TransportInterface $transport */ - public function processInput(string $input, ?Uuid $sessionId): void + public function processInput(TransportInterface $transport, string $input, ?Uuid $sessionId): void { $this->logger->info('Received message to process.', ['message' => $input]); @@ -126,21 +111,21 @@ public function processInput(string $input, ?Uuid $sessionId): void } catch (\JsonException $e) { $this->logger->warning('Failed to decode json message.', ['exception' => $e]); $error = Error::forParseError($e->getMessage()); - $this->sendResponse($error, null); + $this->sendResponse($transport, $error, null); return; } - $session = $this->resolveSession($sessionId, $messages); + $session = $this->resolveSession($transport, $sessionId, $messages); if (null === $session) { return; } foreach ($messages as $message) { if ($message instanceof InvalidInputMessageException) { - $this->handleInvalidMessage($message, $session); + $this->handleInvalidMessage($transport, $message, $session); } elseif ($message instanceof Request) { - $this->handleRequest($message, $session); + $this->handleRequest($transport, $message, $session); } elseif ($message instanceof Response || $message instanceof Error) { $this->handleResponse($message, $session); } elseif ($message instanceof Notification) { @@ -151,15 +136,25 @@ public function processInput(string $input, ?Uuid $sessionId): void $session->save(); } - private function handleInvalidMessage(InvalidInputMessageException $exception, SessionInterface $session): void + /** + * Handle an invalid message from the transport. + * + * @param TransportInterface $transport + */ + private function handleInvalidMessage(TransportInterface $transport, InvalidInputMessageException $exception, SessionInterface $session): void { $this->logger->warning('Failed to create message.', ['exception' => $exception]); $error = Error::forInvalidRequest($exception->getMessage()); - $this->sendResponse($error, $session); + $this->sendResponse($transport, $error, $session); } - private function handleRequest(Request $request, SessionInterface $session): void + /** + * Handle a request from the transport. + * + * @param TransportInterface $transport + */ + private function handleRequest(TransportInterface $transport, Request $request, SessionInterface $session): void { $this->logger->info('Handling request.', ['request' => $request]); @@ -192,24 +187,24 @@ private function handleRequest(Request $request, SessionInterface $session): voi } } - $this->transport->attachFiberToSession($fiber, $session->getId()); + $transport->attachFiberToSession($fiber, $session->getId()); return; } else { $finalResult = $fiber->getReturn(); - $this->sendResponse($finalResult, $session); + $this->sendResponse($transport, $finalResult, $session); } } catch (\InvalidArgumentException $e) { $this->logger->warning(\sprintf('Invalid argument: %s', $e->getMessage()), ['exception' => $e]); $error = Error::forInvalidParams($e->getMessage(), $request->getId()); - $this->sendResponse($error, $session); + $this->sendResponse($transport, $error, $session); } catch (\Throwable $e) { $this->logger->error(\sprintf('Uncaught exception: %s', $e->getMessage()), ['exception' => $e]); $error = Error::forInternalError($e->getMessage(), $request->getId()); - $this->sendResponse($error, $session); + $this->sendResponse($transport, $error, $session); } break; @@ -217,7 +212,7 @@ private function handleRequest(Request $request, SessionInterface $session): voi if (!$handlerFound) { $error = Error::forMethodNotFound(\sprintf('No handler found for method "%s".', $request::getMethod()), $request->getId()); - $this->sendResponse($error, $session); + $this->sendResponse($transport, $error, $session); } } @@ -299,10 +294,11 @@ public function sendNotification(Notification $notification, SessionInterface $s /** * Sends a response either immediately or queued for later delivery. * + * @param TransportInterface $transport * @param Response>|Error $response * @param array $context */ - private function sendResponse(Response|Error $response, ?SessionInterface $session, array $context = []): void + private function sendResponse(TransportInterface $transport, Response|Error $response, ?SessionInterface $session, array $context = []): void { if (null === $session) { $this->logger->info('Sending immediate response', [ @@ -327,7 +323,7 @@ private function sendResponse(Response|Error $response, ?SessionInterface $sessi } $context['type'] = 'response'; - $this->transport->send($encoded, $context); + $transport->send($encoded, $context); } else { $this->logger->info('Queueing server response', [ 'response_id' => $response->getId(), @@ -519,16 +515,17 @@ private function hasInitializeRequest(array $messages): bool /** * Resolves and validates the session based on the request context. * - * @param Uuid|null $sessionId The session ID from the transport - * @param array $messages The parsed messages + * @param TransportInterface $transport + * @param Uuid|null $sessionId The session ID from the transport + * @param array $messages The parsed messages */ - private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInterface + private function resolveSession(TransportInterface $transport, ?Uuid $sessionId, array $messages): ?SessionInterface { if ($this->hasInitializeRequest($messages)) { // Spec: An initialize request must not be part of a batch. if (\count($messages) > 1) { $error = Error::forInvalidRequest('The "initialize" request MUST NOT be part of a batch.'); - $this->sendResponse($error, null); + $this->sendResponse($transport, $error, null); return null; } @@ -536,7 +533,7 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte // Spec: An initialize request must not have a session ID. if ($sessionId) { $error = Error::forInvalidRequest('A session ID MUST NOT be sent with an "initialize" request.'); - $this->sendResponse($error, null); + $this->sendResponse($transport, $error, null); return null; } @@ -546,21 +543,21 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte 'session_id' => $session->getId()->toRfc4122(), ]); - $this->transport->setSessionId($session->getId()); + $transport->setSessionId($session->getId()); return $session; } if (!$sessionId) { $error = Error::forInvalidRequest('A valid session id is REQUIRED for non-initialize requests.'); - $this->sendResponse($error, null, ['status_code' => 400]); + $this->sendResponse($transport, $error, null, ['status_code' => 400]); return null; } if (!$this->sessionStore->exists($sessionId)) { $error = Error::forInvalidRequest('Session not found or has expired.'); - $this->sendResponse($error, null, ['status_code' => 404]); + $this->sendResponse($transport, $error, null, ['status_code' => 404]); return null; } diff --git a/src/Server/Transport/BaseTransport.php b/src/Server/Transport/BaseTransport.php index e5e495e4..58172352 100644 --- a/src/Server/Transport/BaseTransport.php +++ b/src/Server/Transport/BaseTransport.php @@ -26,9 +26,13 @@ * @phpstan-import-type FiberSuspend from TransportInterface * @phpstan-import-type McpFiber from TransportInterface * + * @template TResult + * + * @implements TransportInterface + * * @author Kyrian Obikwelu */ -abstract class BaseTransport +abstract class BaseTransport implements TransportInterface { use ManagesTransportCallbacks; @@ -126,7 +130,7 @@ protected function handleFiberYield(mixed $yielded, ?Uuid $sessionId): void protected function handleMessage(string $payload, ?Uuid $sessionId): void { if (\is_callable($this->messageListener)) { - ($this->messageListener)($payload, $sessionId); + ($this->messageListener)($this, $payload, $sessionId); } } diff --git a/src/Server/Transport/InMemoryTransport.php b/src/Server/Transport/InMemoryTransport.php index d567d096..b7796b7f 100644 --- a/src/Server/Transport/InMemoryTransport.php +++ b/src/Server/Transport/InMemoryTransport.php @@ -15,11 +15,11 @@ use Symfony\Component\Uid\Uuid; /** - * @implements TransportInterface + * @extends BaseTransport * * @author Tobias Nyholm */ -class InMemoryTransport extends BaseTransport implements TransportInterface +class InMemoryTransport extends BaseTransport { use ManagesTransportCallbacks; diff --git a/src/Server/Transport/ManagesTransportCallbacks.php b/src/Server/Transport/ManagesTransportCallbacks.php index a0d1aa6b..072d3f0e 100644 --- a/src/Server/Transport/ManagesTransportCallbacks.php +++ b/src/Server/Transport/ManagesTransportCallbacks.php @@ -26,7 +26,7 @@ * */ trait ManagesTransportCallbacks { - /** @var callable(string, ?Uuid): void */ + /** @var callable(TransportInterface, string, ?Uuid): void */ protected $messageListener; /** @var callable(Uuid): void */ diff --git a/src/Server/Transport/StdioTransport.php b/src/Server/Transport/StdioTransport.php index cb994b94..e9b3f7ee 100644 --- a/src/Server/Transport/StdioTransport.php +++ b/src/Server/Transport/StdioTransport.php @@ -15,11 +15,11 @@ use Psr\Log\LoggerInterface; /** - * @implements TransportInterface + * @extends BaseTransport * * @author Kyrian Obikwelu - * */ -class StdioTransport extends BaseTransport implements TransportInterface + */ +class StdioTransport extends BaseTransport { /** * @param resource $input diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index dcac42dd..7be530a5 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -21,11 +21,11 @@ use Symfony\Component\Uid\Uuid; /** - * @implements TransportInterface + * @extends BaseTransport * * @author Kyrian Obikwelu - * */ -class StreamableHttpTransport extends BaseTransport implements TransportInterface + */ +class StreamableHttpTransport extends BaseTransport { private ResponseFactoryInterface $responseFactory; private StreamFactoryInterface $streamFactory; diff --git a/src/Server/Transport/TransportInterface.php b/src/Server/Transport/TransportInterface.php index 5a874a76..58d09789 100644 --- a/src/Server/Transport/TransportInterface.php +++ b/src/Server/Transport/TransportInterface.php @@ -70,7 +70,7 @@ public function close(): void; * * The transport calls this whenever ANY message arrives, regardless of source. * - * @param callable(string $message, ?Uuid $sessionId): void $listener + * @param callable(TransportInterface $transport, string $message, ?Uuid $sessionId): void $listener */ public function onMessage(callable $listener): void; diff --git a/tests/Unit/Server/ProtocolTest.php b/tests/Unit/Server/ProtocolTest.php index fa949c38..224564ed 100644 --- a/tests/Unit/Server/ProtocolTest.php +++ b/tests/Unit/Server/ProtocolTest.php @@ -68,10 +68,9 @@ public function testNotificationHandledByMultipleHandlers(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "method": "notifications/initialized"}', $sessionId ); @@ -127,10 +126,9 @@ public function testRequestHandledByFirstMatchingHandler(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}', $sessionId ); @@ -166,10 +164,9 @@ public function testInitializeRequestWithSessionIdReturnsError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test", "version": "1.0"}}}', $sessionId ); @@ -198,9 +195,8 @@ public function testInitializeRequestInBatchReturnsError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $protocol->processInput( + $this->transport, '[{"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "test", "version": "1.0"}}}, {"jsonrpc": "2.0", "method": "ping", "id": 2}]', null ); @@ -231,9 +227,8 @@ public function testNonInitializeRequestWithoutSessionIdReturnsError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}', null ); @@ -266,10 +261,9 @@ public function testNonExistentSessionIdReturnsError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}', $sessionId ); @@ -298,9 +292,8 @@ public function testInvalidJsonReturnsParseError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $protocol->processInput( + $this->transport, 'invalid json', null ); @@ -343,10 +336,9 @@ public function testInvalidMessageStructureReturnsError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "params": {}}', $sessionId ); @@ -397,10 +389,9 @@ public function testRequestWithoutHandlerReturnsMethodNotFoundError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "ping"}', $sessionId ); @@ -456,10 +447,9 @@ public function testHandlerInvalidArgumentReturnsInvalidParamsError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "test"}}', $sessionId ); @@ -515,10 +505,9 @@ public function testHandlerUnexpectedExceptionReturnsInternalError(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "test"}}', $sessionId ); @@ -553,10 +542,9 @@ public function testNotificationHandlerExceptionsAreCaught(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "method": "notifications/initialized"}', $sessionId ); @@ -607,9 +595,8 @@ public function testSuccessfulRequestReturnsResponseWithSessionId(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}', $sessionId ); @@ -670,10 +657,9 @@ public function testBatchRequestsAreProcessed(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '[{"jsonrpc": "2.0", "method": "tools/list", "id": 1}, {"jsonrpc": "2.0", "method": "prompts/list", "id": 2}]', $sessionId ); @@ -706,10 +692,9 @@ public function testSessionIsSavedAfterProcessing(): void sessionStore: $this->sessionStore, ); - $protocol->connect($this->transport); - $sessionId = Uuid::v4(); $protocol->processInput( + $this->transport, '{"jsonrpc": "2.0", "method": "notifications/initialized"}', $sessionId );