From d0c83e946c5cc37fa4be7eaa5595a50e3a2d4882 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 13:26:24 +0000 Subject: [PATCH 1/9] Fix notification params schemas not permitting `_meta` --- src/types.ts | 66 ++++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/src/types.ts b/src/types.ts index 0d55a75bb..3cd1cb06a 100644 --- a/src/types.ts +++ b/src/types.ts @@ -39,18 +39,18 @@ export const RequestSchema = z.object({ params: z.optional(BaseRequestParamsSchema), }); +const BaseNotificationParamsSchema = z + .object({ + /** + * This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. + */ + _meta: z.optional(z.object({}).passthrough()), + }) + .passthrough(); + export const NotificationSchema = z.object({ method: z.string(), - params: z.optional( - z - .object({ - /** - * This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(), - ), + params: z.optional(BaseNotificationParamsSchema), }); export const ResultSchema = z @@ -312,7 +312,7 @@ export const ProgressSchema = z */ export const ProgressNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/progress"), - params: ProgressSchema.extend({ + params: BaseNotificationParamsSchema.merge(ProgressSchema).extend({ /** * The progress token which was given in the initial request, used to associate this notification with the request that is proceeding. */ @@ -522,14 +522,12 @@ export const UnsubscribeRequestSchema = RequestSchema.extend({ */ export const ResourceUpdatedNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/resources/updated"), - params: z - .object({ - /** - * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. - */ - uri: z.string(), - }) - .passthrough(), + params: BaseNotificationParamsSchema.extend({ + /** + * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + */ + uri: z.string(), + }), }); /* Prompts */ @@ -786,22 +784,20 @@ export const SetLevelRequestSchema = RequestSchema.extend({ */ export const LoggingMessageNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/message"), - params: z - .object({ - /** - * The severity of this log message. - */ - level: LoggingLevelSchema, - /** - * An optional name of the logger issuing this message. - */ - logger: z.optional(z.string()), - /** - * The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. - */ - data: z.unknown(), - }) - .passthrough(), + params: BaseNotificationParamsSchema.extend({ + /** + * The severity of this log message. + */ + level: LoggingLevelSchema, + /** + * An optional name of the logger issuing this message. + */ + logger: z.optional(z.string()), + /** + * The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. + */ + data: z.unknown(), + }), }); /* Sampling */ From 347e8905a9447ee666fcd52eb961a412e80afaec Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 13:27:51 +0000 Subject: [PATCH 2/9] Add types for cancellation notifications --- src/types.ts | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/types.ts b/src/types.ts index 3cd1cb06a..5b9896066 100644 --- a/src/types.ts +++ b/src/types.ts @@ -151,6 +151,33 @@ export const JSONRPCMessageSchema = z.union([ */ export const EmptyResultSchema = ResultSchema.strict(); +/* Cancellation */ +/** + * This notification can be sent by either side to indicate that it is cancelling a previously-issued request. + * + * The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. + * + * This notification indicates that the result will be unused, so any associated processing SHOULD cease. + * + * A client MUST NOT attempt to cancel its `initialize` request. + */ +export const CancelledNotificationSchema = NotificationSchema.extend({ + method: z.literal("cancelled"), + params: BaseNotificationParamsSchema.extend({ + /** + * The ID of the request to cancel. + * + * This MUST correspond to the ID of a request previously issued in the same direction. + */ + requestId: RequestIdSchema, + + /** + * An optional string describing the reason for the cancellation. This MAY be logged or presented to the user. + */ + reason: z.string().optional(), + }), +}); + /* Initialization */ /** * Describes the name and version of an MCP implementation. @@ -1030,6 +1057,7 @@ export const ClientRequestSchema = z.union([ ]); export const ClientNotificationSchema = z.union([ + CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, RootsListChangedNotificationSchema, @@ -1049,6 +1077,7 @@ export const ServerRequestSchema = z.union([ ]); export const ServerNotificationSchema = z.union([ + CancelledNotificationSchema, ProgressNotificationSchema, LoggingMessageNotificationSchema, ResourceUpdatedNotificationSchema, @@ -1096,6 +1125,9 @@ export type JSONRPCMessage = z.infer; /* Empty result */ export type EmptyResult = z.infer; +/* Cancellation */ +export type CancelledNotification = z.infer; + /* Initialization */ export type Implementation = z.infer; export type ClientCapabilities = z.infer; From 54d146bbdc2321b1b2e6fcfe8502851fc25bc997 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 13:42:18 +0000 Subject: [PATCH 3/9] Support passing an AbortSignal to request() --- src/client/index.ts | 37 ++++++++++++++++----------------- src/server/index.ts | 10 ++++----- src/shared/protocol.ts | 46 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index e0df322b4..97e6ee072 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,7 +1,7 @@ import { - ProgressCallback, Protocol, ProtocolOptions, + RequestOptions, } from "../shared/protocol.js"; import { Transport } from "../shared/transport.js"; import { @@ -278,14 +278,11 @@ export class Client< return this.request({ method: "ping" }, EmptyResultSchema); } - async complete( - params: CompleteRequest["params"], - onprogress?: ProgressCallback, - ) { + async complete(params: CompleteRequest["params"], options?: RequestOptions) { return this.request( { method: "completion/complete", params }, CompleteResultSchema, - onprogress, + options, ); } @@ -298,56 +295,56 @@ export class Client< async getPrompt( params: GetPromptRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "prompts/get", params }, GetPromptResultSchema, - onprogress, + options, ); } async listPrompts( params?: ListPromptsRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "prompts/list", params }, ListPromptsResultSchema, - onprogress, + options, ); } async listResources( params?: ListResourcesRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "resources/list", params }, ListResourcesResultSchema, - onprogress, + options, ); } async listResourceTemplates( params?: ListResourceTemplatesRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "resources/templates/list", params }, ListResourceTemplatesResultSchema, - onprogress, + options, ); } async readResource( params: ReadResourceRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "resources/read", params }, ReadResourceResultSchema, - onprogress, + options, ); } @@ -370,23 +367,23 @@ export class Client< resultSchema: | typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "tools/call", params }, resultSchema, - onprogress, + options, ); } async listTools( params?: ListToolsRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "tools/list", params }, ListToolsResultSchema, - onprogress, + options, ); } diff --git a/src/server/index.ts b/src/server/index.ts index ecb525b5c..29a3207af 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,7 +1,7 @@ import { - ProgressCallback, Protocol, ProtocolOptions, + RequestOptions, } from "../shared/protocol.js"; import { ClientCapabilities, @@ -257,23 +257,23 @@ export class Server< async createMessage( params: CreateMessageRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "sampling/createMessage", params }, CreateMessageResultSchema, - onprogress, + options, ); } async listRoots( params?: ListRootsRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "roots/list", params }, ListRootsResultSchema, - onprogress, + options, ); } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 85610a9d6..e83fe44e5 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -35,6 +35,21 @@ export type ProtocolOptions = { enforceStrictCapabilities?: boolean; }; +/** + * Options that can be given per request. + */ +export type RequestOptions = { + /** + * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + */ + onprogress?: ProgressCallback; + + /** + * Can be used to cancel an in-flight request. This will cause an AbortError to be raised from request(). + */ + signal?: AbortSignal; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. @@ -285,14 +300,14 @@ export abstract class Protocol< protected abstract assertRequestHandlerCapability(method: string): void; /** - * Sends a request and wait for a response, with optional progress notifications in the meantime (if supported by the server). + * Sends a request and wait for a response. * * Do not use this method to emit notifications! Use notification() instead. */ request>( request: SendRequestT, resultSchema: T, - onprogress?: ProgressCallback, + options?: RequestOptions, ): Promise> { return new Promise((resolve, reject) => { if (!this._transport) { @@ -304,6 +319,8 @@ export abstract class Protocol< this.assertCapabilityForMethod(request.method); } + options?.signal?.throwIfAborted(); + const messageId = this._requestMessageId++; const jsonrpcRequest: JSONRPCRequest = { ...request, @@ -311,8 +328,8 @@ export abstract class Protocol< id: messageId, }; - if (onprogress) { - this._progressHandlers.set(messageId, onprogress); + if (options?.onprogress) { + this._progressHandlers.set(messageId, options.onprogress); jsonrpcRequest.params = { ...request.params, _meta: { progressToken: messageId }, @@ -320,6 +337,10 @@ export abstract class Protocol< } this._responseHandlers.set(messageId, (response) => { + if (options?.signal?.aborted) { + return; + } + if (response instanceof Error) { return reject(response); } @@ -332,6 +353,23 @@ export abstract class Protocol< } }); + options?.signal?.addEventListener("abort", () => { + const reason = options?.signal?.reason; + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + + this._transport?.send({ + jsonrpc: "2.0", + method: "cancelled", + params: { + requestId: messageId, + reason: String(reason), + }, + }); + + reject(reason); + }); + this._transport.send(jsonrpcRequest).catch(reject); }); } From 90ba8950d22e5df64a64e5ac64166384eeb1b8ad Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 13:57:04 +0000 Subject: [PATCH 4/9] Pass an AbortSignal to request handlers --- src/shared/protocol.ts | 57 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index e83fe44e5..752d2cc09 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,5 +1,6 @@ import { ZodLiteral, ZodObject, ZodType, z } from "zod"; import { + CancelledNotificationSchema, ErrorCode, JSONRPCError, JSONRPCNotification, @@ -12,6 +13,7 @@ import { ProgressNotification, ProgressNotificationSchema, Request, + RequestId, Result, } from "../types.js"; import { Transport } from "./transport.js"; @@ -50,6 +52,16 @@ export type RequestOptions = { signal?: AbortSignal; }; +/** + * Extra data given to request handlers. + */ +export type RequestHandlerExtra = { + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + signal: AbortSignal; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. @@ -61,10 +73,15 @@ export abstract class Protocol< > { private _transport?: Transport; private _requestMessageId = 0; - protected _requestHandlers: Map< + private _requestHandlers: Map< string, - (request: JSONRPCRequest) => Promise + ( + request: JSONRPCRequest, + extra: RequestHandlerExtra, + ) => Promise > = new Map(); + private _requestHandlerAbortControllers: Map = + new Map(); private _notificationHandlers: Map< string, (notification: JSONRPCNotification) => Promise @@ -100,6 +117,13 @@ export abstract class Protocol< fallbackNotificationHandler?: (notification: Notification) => Promise; constructor(private _options?: ProtocolOptions) { + this.setNotificationHandler(CancelledNotificationSchema, (notification) => { + const controller = this._requestHandlerAbortControllers.get( + notification.params.requestId, + ); + controller?.abort(notification.params.reason); + }); + this.setNotificationHandler(ProgressNotificationSchema, (notification) => { this._onprogress(notification as unknown as ProgressNotification); }); @@ -195,16 +219,27 @@ export abstract class Protocol< return; } - handler(request) + const abortController = new AbortController(); + this._requestHandlerAbortControllers.set(request.id, abortController); + + handler(request, { signal: abortController.signal }) .then( (result) => { - this._transport?.send({ + if (abortController.signal.aborted) { + return; + } + + return this._transport?.send({ result, jsonrpc: "2.0", id: request.id, }); }, (error) => { + if (abortController.signal.aborted) { + return; + } + return this._transport?.send({ jsonrpc: "2.0", id: request.id, @@ -219,7 +254,10 @@ export abstract class Protocol< ) .catch((error) => this._onerror(new Error(`Failed to send response: ${error}`)), - ); + ) + .finally(() => { + this._requestHandlerAbortControllers.delete(request.id); + }); } private _onprogress(notification: ProgressNotification): void { @@ -403,12 +441,15 @@ export abstract class Protocol< }>, >( requestSchema: T, - handler: (request: z.infer) => SendResultT | Promise, + handler: ( + request: z.infer, + extra: RequestHandlerExtra, + ) => SendResultT | Promise, ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request) => - Promise.resolve(handler(requestSchema.parse(request))), + this._requestHandlers.set(method, (request, extra) => + Promise.resolve(handler(requestSchema.parse(request), extra)), ); } From 1b08f203ea446391e3836ac213e3a715762acd23 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 14:02:51 +0000 Subject: [PATCH 5/9] Catch errors that occur when a handler is invoked, but before it returns a Promise --- src/shared/protocol.ts | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 752d2cc09..ab8421bfb 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -190,11 +190,14 @@ export abstract class Protocol< return; } - handler(notification).catch((error) => - this._onerror( - new Error(`Uncaught error in notification handler: ${error}`), - ), - ); + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. + Promise.resolve() + .then(() => handler(notification)) + .catch((error) => + this._onerror( + new Error(`Uncaught error in notification handler: ${error}`), + ), + ); } private _onrequest(request: JSONRPCRequest): void { @@ -222,7 +225,9 @@ export abstract class Protocol< const abortController = new AbortController(); this._requestHandlerAbortControllers.set(request.id, abortController); - handler(request, { signal: abortController.signal }) + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. + Promise.resolve() + .then(() => handler(request, { signal: abortController.signal })) .then( (result) => { if (abortController.signal.aborted) { From 15bd2dc7aa6a703d9e26eaae6700ecba0c2678c4 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 14:07:16 +0000 Subject: [PATCH 6/9] Fix cancelled notification method name --- src/types.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.ts b/src/types.ts index 5b9896066..5b9fb664e 100644 --- a/src/types.ts +++ b/src/types.ts @@ -162,7 +162,7 @@ export const EmptyResultSchema = ResultSchema.strict(); * A client MUST NOT attempt to cancel its `initialize` request. */ export const CancelledNotificationSchema = NotificationSchema.extend({ - method: z.literal("cancelled"), + method: z.literal("notifications/cancelled"), params: BaseNotificationParamsSchema.extend({ /** * The ID of the request to cancel. From 1451d0ea246b2445b53357f180020012b8d1fdc7 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 14:07:28 +0000 Subject: [PATCH 7/9] Explicitly allow cancelled notifications without any capability --- src/client/index.ts | 4 ++++ src/server/index.ts | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/client/index.ts b/src/client/index.ts index 97e6ee072..2c08bcc09 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -244,6 +244,10 @@ export class Client< // No specific capability required for initialized break; + case "notifications/cancelled": + // Cancellation notifications are always allowed + break; + case "notifications/progress": // Progress notifications are always allowed break; diff --git a/src/server/index.ts b/src/server/index.ts index 29a3207af..d15ad3c0d 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -157,6 +157,10 @@ export class Server< } break; + case "notifications/cancelled": + // Cancellation notifications are always allowed + break; + case "notifications/progress": // Progress notifications are always allowed break; From 6ec4c09325051d5296246ee6843dd8b5be69908e Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 14:11:11 +0000 Subject: [PATCH 8/9] Add tests for issuing cancellation notifications --- src/client/index.test.ts | 55 ++++++++++++++++++++++++++++++++ src/server/index.test.ts | 68 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 5610a6293..0d7eb9418 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -436,3 +436,58 @@ test("should typecheck", () => { }, }); }); + +test("should handle client cancelling a request", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + // Set up server to delay responding to listResources + server.setRequestHandler( + ListResourcesRequestSchema, + async (request, extra) => { + await new Promise((resolve) => setTimeout(resolve, 1000)); + return { + resources: [], + }; + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: {}, + }, + ); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const listResourcesPromise = client.listResources(undefined, { + signal: controller.signal, + }); + controller.abort("Cancelled by test"); + + // Request should be rejected + await expect(listResourcesPromise).rejects.toBe("Cancelled by test"); +}); diff --git a/src/server/index.test.ts b/src/server/index.test.ts index d30c670bc..0697cc5cb 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -407,3 +407,71 @@ test("should typecheck", () => { }, ); }); + +test("should handle server cancelling a request", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + // Set up client to delay responding to createMessage + client.setRequestHandler( + CreateMessageRequestSchema, + async (_request, extra) => { + await new Promise((resolve) => setTimeout(resolve, 1000)); + return { + model: "test", + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }; + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const createMessagePromise = server.createMessage( + { + messages: [], + maxTokens: 10, + }, + { + signal: controller.signal, + }, + ); + controller.abort("Cancelled by test"); + + // Request should be rejected + await expect(createMessagePromise).rejects.toBe("Cancelled by test"); +}); From 9882f1c919db73229965e0b897861d1e400ea960 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 14 Nov 2024 14:25:16 +0000 Subject: [PATCH 9/9] Add `abortAfterTimeout()` utility to make it easy to add timeouts --- src/shared/protocol.ts | 2 ++ src/utils.test.ts | 15 +++++++++++++++ src/utils.ts | 11 +++++++++++ 3 files changed, 28 insertions(+) create mode 100644 src/utils.test.ts create mode 100644 src/utils.ts diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index ab8421bfb..8103695d1 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -48,6 +48,8 @@ export type RequestOptions = { /** * Can be used to cancel an in-flight request. This will cause an AbortError to be raised from request(). + * + * Use abortAfterTimeout() to easily implement timeouts using this signal. */ signal?: AbortSignal; }; diff --git a/src/utils.test.ts b/src/utils.test.ts new file mode 100644 index 000000000..e4aa4e5fc --- /dev/null +++ b/src/utils.test.ts @@ -0,0 +1,15 @@ +import { abortAfterTimeout } from "./utils.js"; + +describe("abortAfterTimeout", () => { + it("should abort after timeout", () => { + const signal = abortAfterTimeout(0); + expect(signal.aborted).toBe(false); + + return new Promise((resolve) => { + setTimeout(() => { + expect(signal.aborted).toBe(true); + resolve(true); + }, 0); + }); + }); +}); diff --git a/src/utils.ts b/src/utils.ts new file mode 100644 index 000000000..11672ecd2 --- /dev/null +++ b/src/utils.ts @@ -0,0 +1,11 @@ +/** + * Returns an AbortSignal that will enter aborted state after `timeoutMs` milliseconds. + */ +export function abortAfterTimeout(timeoutMs: number): AbortSignal { + const controller = new AbortController(); + setTimeout(() => { + controller.abort(); + }, timeoutMs); + + return controller.signal; +}