diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index e0141da19..05bc8f3bc 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -27,9 +27,11 @@ class MockTransport implements Transport { describe("protocol tests", () => { let protocol: Protocol; let transport: MockTransport; + let sendSpy: jest.SpyInstance; beforeEach(() => { transport = new MockTransport(); + sendSpy = jest.spyOn(transport, 'send'); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} @@ -63,6 +65,130 @@ describe("protocol tests", () => { expect(oncloseMock).toHaveBeenCalled(); }); + describe("_meta preservation with onprogress", () => { + test("should preserve existing _meta when adding progressToken", async () => { + await protocol.connect(transport); + const request = { + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue", + anotherField: 123 + } + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock, + }); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue", + anotherField: 123, + progressToken: expect.any(Number) + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + + test("should create _meta with progressToken when no _meta exists", async () => { + await protocol.connect(transport); + const request = { + method: "example", + params: { + data: "test" + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock, + }); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + data: "test", + _meta: { + progressToken: expect.any(Number) + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + + test("should not modify _meta when onprogress is not provided", async () => { + await protocol.connect(transport); + const request = { + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue" + } + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + + protocol.request(request, mockSchema); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue" + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + + test("should handle params being undefined with onprogress", async () => { + await protocol.connect(transport); + const request = { + method: "example" + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock, + }); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + _meta: { + progressToken: expect.any(Number) + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + }); + describe("progress notification timeout behavior", () => { beforeEach(() => { jest.useFakeTimers(); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 4694929d7..a04f26eb2 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -541,7 +541,10 @@ export abstract class Protocol< this._progressHandlers.set(messageId, options.onprogress); jsonrpcRequest.params = { ...request.params, - _meta: { progressToken: messageId }, + _meta: { + ...(request.params?._meta || {}), + progressToken: messageId + }, }; }