diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 3073d0af..1d037b98 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -62,6 +62,162 @@ describe("protocol tests", () => { await transport.close(); expect(oncloseMock).toHaveBeenCalled(); }); + + describe("progress notification timeout behavior", () => { + beforeEach(() => { + jest.useFakeTimers(); + }); + afterEach(() => { + jest.useRealTimers(); + }); + + test("should reset timeout when progress notification is received", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 50, + total: 100, + }, + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, + }); + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: "success" }); + }); + + test("should respect maxTotalTimeout", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + maxTotalTimeout: 150, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); + + // First progress notification should work + jest.advanceTimersByTime(80); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 50, + total: 100, + }, + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100, + }); + jest.advanceTimersByTime(80); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 75, + total: 100, + }, + }); + } + await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded"); + expect(onProgressMock).toHaveBeenCalledTimes(1); + }); + + test("should timeout if no progress received within timeout period", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 100, + resetTimeoutOnProgress: true, + }); + jest.advanceTimersByTime(101); + await expect(requestPromise).rejects.toThrow("Request timed out"); + }); + + test("should handle multiple progress notifications correctly", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock, + }); + + // Simulate multiple progress updates + for (let i = 1; i <= 3; i++) { + jest.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: i * 25, + total: 100, + }, + }); + } + await Promise.resolve(); + expect(onProgressMock).toHaveBeenNthCalledWith(i, { + progress: i * 25, + total: 100, + }); + } + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: "success" }); + }); + }); }); describe("mergeCapabilities", () => { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index a4f211c6..97213bf0 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -64,6 +64,20 @@ export type RequestOptions = { * If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout. */ timeout?: number; + + /** + * If true, receiving a progress notification will reset the request timeout. + * This is useful for long-running operations that send periodic progress updates. + * Default: false + */ + resetTimeoutOnProgress?: boolean; + + /** + * Maximum total time (in milliseconds) to wait for a response. + * If exceeded, an McpError with code `RequestTimeout` will be raised, regardless of progress notifications. + * If not specified, there is no maximum total timeout. + */ + maxTotalTimeout?: number; }; /** @@ -76,6 +90,17 @@ export type RequestHandlerExtra = { signal: AbortSignal; }; +/** + * Information about a request's timeout state + */ +type TimeoutInfo = { + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + onTimeout: () => void; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. @@ -105,6 +130,7 @@ export abstract class Protocol< (response: JSONRPCResponse | Error) => void > = new Map(); private _progressHandlers: Map = new Map(); + private _timeoutInfo: Map = new Map(); /** * Callback for when the connection is closed for any reason. @@ -149,6 +175,48 @@ export abstract class Protocol< ); } + private _setupTimeout( + messageId: number, + timeout: number, + maxTotalTimeout: number | undefined, + onTimeout: () => void + ) { + this._timeoutInfo.set(messageId, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout, + onTimeout + }); + } + + private _resetTimeout(messageId: number): boolean { + const info = this._timeoutInfo.get(messageId); + if (!info) return false; + + const totalElapsed = Date.now() - info.startTime; + if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { + this._timeoutInfo.delete(messageId); + throw new McpError( + ErrorCode.RequestTimeout, + "Maximum total timeout exceeded", + { maxTotalTimeout: info.maxTotalTimeout, totalElapsed } + ); + } + + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + return true; + } + + private _cleanupTimeout(messageId: number) { + const info = this._timeoutInfo.get(messageId); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(messageId); + } + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -281,22 +349,30 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; - const handler = this._progressHandlers.get(Number(progressToken)); - if (handler === undefined) { - this._onerror( - new Error( - `Received a progress notification for an unknown token: ${JSON.stringify(notification)}`, - ), - ); + const messageId = Number(progressToken); + + const handler = this._progressHandlers.get(messageId); + if (!handler) { + this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); return; } + const responseHandler = this._responseHandlers.get(messageId); + if (this._timeoutInfo.has(messageId) && responseHandler) { + try { + this._resetTimeout(messageId); + } catch (error) { + responseHandler(error as Error); + return; + } + } + handler(params); } private _onresponse(response: JSONRPCResponse | JSONRPCError): void { - const messageId = response.id; - const handler = this._responseHandlers.get(Number(messageId)); + const messageId = Number(response.id); + const handler = this._responseHandlers.get(messageId); if (handler === undefined) { this._onerror( new Error( @@ -306,8 +382,10 @@ export abstract class Protocol< return; } - this._responseHandlers.delete(Number(messageId)); - this._progressHandlers.delete(Number(messageId)); + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); + if ("result" in response) { handler(response); } else { @@ -393,32 +471,10 @@ export abstract class Protocol< }; } - let timeoutId: ReturnType | undefined = undefined; - - this._responseHandlers.set(messageId, (response) => { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); - } - - if (options?.signal?.aborted) { - return; - } - - if (response instanceof Error) { - return reject(response); - } - - try { - const result = resultSchema.parse(response.result); - resolve(result); - } catch (error) { - reject(error); - } - }); - const cancel = (reason: unknown) => { this._responseHandlers.delete(messageId); this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); this._transport ?.send({ @@ -436,30 +492,38 @@ export abstract class Protocol< reject(reason); }; - options?.signal?.addEventListener("abort", () => { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); + this._responseHandlers.set(messageId, (response) => { + if (options?.signal?.aborted) { + return; } + if (response instanceof Error) { + return reject(response); + } + + try { + const result = resultSchema.parse(response.result); + resolve(result); + } catch (error) { + reject(error); + } + }); + + options?.signal?.addEventListener("abort", () => { cancel(options?.signal?.reason); }); const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; - timeoutId = setTimeout( - () => - cancel( - new McpError(ErrorCode.RequestTimeout, "Request timed out", { - timeout, - }), - ), - timeout, - ); + const timeoutHandler = () => cancel(new McpError( + ErrorCode.RequestTimeout, + "Request timed out", + { timeout } + )); - this._transport.send(jsonrpcRequest).catch((error) => { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); - } + this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler); + this._transport.send(jsonrpcRequest).catch((error) => { + this._cleanupTimeout(messageId); reject(error); }); });