Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 46 additions & 49 deletions src/Server/Protocol.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
/**
* @final
*
* @phpstan-import-type McpFiber from \Mcp\Server\Transport\TransportInterface
* @phpstan-import-type FiberSuspend from \Mcp\Server\Transport\TransportInterface
* @phpstan-import-type McpFiber from TransportInterface
* @phpstan-import-type FiberSuspend from TransportInterface
*
* @author Christopher Hertel <[email protected]>
* @author Kyrian Obikwelu <[email protected]>
Expand All @@ -55,9 +55,6 @@ class Protocol
/** Session key for active request meta */
public const SESSION_ACTIVE_REQUEST_META = '_mcp.active_request_meta';

/** @var TransportInterface<mixed>|null */
private ?TransportInterface $transport = null;

/**
* @param array<int, RequestHandlerInterface<ResultInterface|array<string, mixed>>> $requestHandlers
* @param array<int, NotificationHandlerInterface> $notificationHandlers
Expand All @@ -73,39 +70,25 @@ public function __construct(
}

/**
* @return TransportInterface<mixed>
*/
public function getTransport(): TransportInterface
{
return $this->transport;
}

/**
* Connect this protocol to a transport.
* Connect this protocol to transport.
*
* The protocol takes ownership of the transport and sets up all callbacks.
*
* @param TransportInterface<mixed> $transport
*/
public function connect(TransportInterface $transport): void
{
if ($this->transport) {
throw new \RuntimeException('Protocol already connected to a transport');
}

$this->transport = $transport;
$transport->onMessage($this->processInput(...));

$this->transport->onMessage([$this, 'processInput']);
$transport->onSessionEnd($this->destroySession(...));

$this->transport->onSessionEnd([$this, 'destroySession']);
$transport->setOutgoingMessagesProvider($this->consumeOutgoingMessages(...));

$this->transport->setOutgoingMessagesProvider([$this, 'consumeOutgoingMessages']);
$transport->setPendingRequestsProvider($this->getPendingRequests(...));

$this->transport->setPendingRequestsProvider([$this, 'getPendingRequests']);
$transport->setResponseFinder($this->checkResponse(...));

$this->transport->setResponseFinder([$this, 'checkResponse']);

$this->transport->setFiberYieldHandler([$this, 'handleFiberYield']);
$transport->setFiberYieldHandler($this->handleFiberYield(...));

$this->logger->info('Protocol connected to transport', ['transport' => $transport::class]);
}
Expand All @@ -114,8 +97,10 @@ public function connect(TransportInterface $transport): void
* Handle an incoming message from the transport.
*
* This is called by the transport whenever ANY message arrives.
*
* @param TransportInterface<mixed> $transport
*/
public function processInput(string $input, ?Uuid $sessionId): void
public function processInput(TransportInterface $transport, string $input, ?Uuid $sessionId): void
{
$this->logger->info('Received message to process.', ['message' => $input]);

Expand All @@ -126,21 +111,21 @@ public function processInput(string $input, ?Uuid $sessionId): void
} catch (\JsonException $e) {
$this->logger->warning('Failed to decode json message.', ['exception' => $e]);
$error = Error::forParseError($e->getMessage());
$this->sendResponse($error, null);
$this->sendResponse($transport, $error, null);

return;
}

$session = $this->resolveSession($sessionId, $messages);
$session = $this->resolveSession($transport, $sessionId, $messages);
if (null === $session) {
return;
}

foreach ($messages as $message) {
if ($message instanceof InvalidInputMessageException) {
$this->handleInvalidMessage($message, $session);
$this->handleInvalidMessage($transport, $message, $session);
} elseif ($message instanceof Request) {
$this->handleRequest($message, $session);
$this->handleRequest($transport, $message, $session);
} elseif ($message instanceof Response || $message instanceof Error) {
$this->handleResponse($message, $session);
} elseif ($message instanceof Notification) {
Expand All @@ -151,15 +136,25 @@ public function processInput(string $input, ?Uuid $sessionId): void
$session->save();
}

private function handleInvalidMessage(InvalidInputMessageException $exception, SessionInterface $session): void
/**
* Handle an invalid message from the transport.
*
* @param TransportInterface<mixed> $transport
*/
private function handleInvalidMessage(TransportInterface $transport, InvalidInputMessageException $exception, SessionInterface $session): void
{
$this->logger->warning('Failed to create message.', ['exception' => $exception]);

$error = Error::forInvalidRequest($exception->getMessage());
$this->sendResponse($error, $session);
$this->sendResponse($transport, $error, $session);
}

private function handleRequest(Request $request, SessionInterface $session): void
/**
* Handle a request from the transport.
*
* @param TransportInterface<mixed> $transport
*/
private function handleRequest(TransportInterface $transport, Request $request, SessionInterface $session): void
{
$this->logger->info('Handling request.', ['request' => $request]);

Expand Down Expand Up @@ -192,32 +187,32 @@ private function handleRequest(Request $request, SessionInterface $session): voi
}
}

$this->transport->attachFiberToSession($fiber, $session->getId());
$transport->attachFiberToSession($fiber, $session->getId());

return;
} else {
$finalResult = $fiber->getReturn();

$this->sendResponse($finalResult, $session);
$this->sendResponse($transport, $finalResult, $session);
}
} catch (\InvalidArgumentException $e) {
$this->logger->warning(\sprintf('Invalid argument: %s', $e->getMessage()), ['exception' => $e]);

$error = Error::forInvalidParams($e->getMessage(), $request->getId());
$this->sendResponse($error, $session);
$this->sendResponse($transport, $error, $session);
} catch (\Throwable $e) {
$this->logger->error(\sprintf('Uncaught exception: %s', $e->getMessage()), ['exception' => $e]);

$error = Error::forInternalError($e->getMessage(), $request->getId());
$this->sendResponse($error, $session);
$this->sendResponse($transport, $error, $session);
}

break;
}

if (!$handlerFound) {
$error = Error::forMethodNotFound(\sprintf('No handler found for method "%s".', $request::getMethod()), $request->getId());
$this->sendResponse($error, $session);
$this->sendResponse($transport, $error, $session);
}
}

Expand Down Expand Up @@ -299,10 +294,11 @@ public function sendNotification(Notification $notification, SessionInterface $s
/**
* Sends a response either immediately or queued for later delivery.
*
* @param TransportInterface<mixed> $transport
* @param Response<ResultInterface|array<string, mixed>>|Error $response
* @param array<string, mixed> $context
*/
private function sendResponse(Response|Error $response, ?SessionInterface $session, array $context = []): void
private function sendResponse(TransportInterface $transport, Response|Error $response, ?SessionInterface $session, array $context = []): void
{
if (null === $session) {
$this->logger->info('Sending immediate response', [
Expand All @@ -327,7 +323,7 @@ private function sendResponse(Response|Error $response, ?SessionInterface $sessi
}

$context['type'] = 'response';
$this->transport->send($encoded, $context);
$transport->send($encoded, $context);
} else {
$this->logger->info('Queueing server response', [
'response_id' => $response->getId(),
Expand Down Expand Up @@ -519,24 +515,25 @@ private function hasInitializeRequest(array $messages): bool
/**
* Resolves and validates the session based on the request context.
*
* @param Uuid|null $sessionId The session ID from the transport
* @param array<int,mixed> $messages The parsed messages
* @param TransportInterface<mixed> $transport
* @param Uuid|null $sessionId The session ID from the transport
* @param array<int,mixed> $messages The parsed messages
*/
private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInterface
private function resolveSession(TransportInterface $transport, ?Uuid $sessionId, array $messages): ?SessionInterface
{
if ($this->hasInitializeRequest($messages)) {
// Spec: An initialize request must not be part of a batch.
if (\count($messages) > 1) {
$error = Error::forInvalidRequest('The "initialize" request MUST NOT be part of a batch.');
$this->sendResponse($error, null);
$this->sendResponse($transport, $error, null);

return null;
}

// Spec: An initialize request must not have a session ID.
if ($sessionId) {
$error = Error::forInvalidRequest('A session ID MUST NOT be sent with an "initialize" request.');
$this->sendResponse($error, null);
$this->sendResponse($transport, $error, null);

return null;
}
Expand All @@ -546,21 +543,21 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte
'session_id' => $session->getId()->toRfc4122(),
]);

$this->transport->setSessionId($session->getId());
$transport->setSessionId($session->getId());

return $session;
}

if (!$sessionId) {
$error = Error::forInvalidRequest('A valid session id is REQUIRED for non-initialize requests.');
$this->sendResponse($error, null, ['status_code' => 400]);
$this->sendResponse($transport, $error, null, ['status_code' => 400]);

return null;
}

if (!$this->sessionStore->exists($sessionId)) {
$error = Error::forInvalidRequest('Session not found or has expired.');
$this->sendResponse($error, null, ['status_code' => 404]);
$this->sendResponse($transport, $error, null, ['status_code' => 404]);

return null;
}
Expand Down
8 changes: 6 additions & 2 deletions src/Server/Transport/BaseTransport.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
* @phpstan-import-type FiberSuspend from TransportInterface
* @phpstan-import-type McpFiber from TransportInterface
*
* @template TResult
*
* @implements TransportInterface<TResult>
*
* @author Kyrian Obikwelu <[email protected]>
*/
abstract class BaseTransport
abstract class BaseTransport implements TransportInterface
{
use ManagesTransportCallbacks;

Expand Down Expand Up @@ -126,7 +130,7 @@ protected function handleFiberYield(mixed $yielded, ?Uuid $sessionId): void
protected function handleMessage(string $payload, ?Uuid $sessionId): void
{
if (\is_callable($this->messageListener)) {
($this->messageListener)($payload, $sessionId);
($this->messageListener)($this, $payload, $sessionId);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Server/Transport/InMemoryTransport.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
use Symfony\Component\Uid\Uuid;

/**
* @implements TransportInterface<null>
* @extends BaseTransport<null>
*
* @author Tobias Nyholm <[email protected]>
*/
class InMemoryTransport extends BaseTransport implements TransportInterface
class InMemoryTransport extends BaseTransport
{
use ManagesTransportCallbacks;

Expand Down
2 changes: 1 addition & 1 deletion src/Server/Transport/ManagesTransportCallbacks.php
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
* */
trait ManagesTransportCallbacks
{
/** @var callable(string, ?Uuid): void */
/** @var callable(TransportInterface<mixed>, string, ?Uuid): void */
protected $messageListener;

/** @var callable(Uuid): void */
Expand Down
6 changes: 3 additions & 3 deletions src/Server/Transport/StdioTransport.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
use Psr\Log\LoggerInterface;

/**
* @implements TransportInterface<int>
* @extends BaseTransport<int>
*
* @author Kyrian Obikwelu <[email protected]>
* */
class StdioTransport extends BaseTransport implements TransportInterface
*/
class StdioTransport extends BaseTransport
{
/**
* @param resource $input
Expand Down
6 changes: 3 additions & 3 deletions src/Server/Transport/StreamableHttpTransport.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
use Symfony\Component\Uid\Uuid;

/**
* @implements TransportInterface<ResponseInterface>
* @extends BaseTransport<ResponseInterface>
*
* @author Kyrian Obikwelu <[email protected]>
* */
class StreamableHttpTransport extends BaseTransport implements TransportInterface
*/
class StreamableHttpTransport extends BaseTransport
{
private ResponseFactoryInterface $responseFactory;
private StreamFactoryInterface $streamFactory;
Expand Down
2 changes: 1 addition & 1 deletion src/Server/Transport/TransportInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public function close(): void;
*
* The transport calls this whenever ANY message arrives, regardless of source.
*
* @param callable(string $message, ?Uuid $sessionId): void $listener
* @param callable(TransportInterface<TResult> $transport, string $message, ?Uuid $sessionId): void $listener
*/
public function onMessage(callable $listener): void;

Expand Down
Loading