diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 0dc582d4..40f22139 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -80,7 +80,7 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock).mockResolvedValueOnce({ ok: true, status: 200, - headers: new Headers({ "mcp-session-id": "test-session-id" }), + headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), }); await transport.send(message); @@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => { // We expect the 405 error to be caught and handled gracefully // This should not throw an error that breaks the transport await transport.start(); - await expect(transport.openSseStream()).rejects.toThrow('Failed to open SSE stream: Method Not Allowed'); + await expect(transport.openSseStream()).rejects.toThrow("Failed to open SSE stream: Method Not Allowed"); // Check that GET was attempted expect(global.fetch).toHaveBeenCalledWith( @@ -192,7 +192,7 @@ describe("StreamableHTTPClientTransport", () => { const stream = new ReadableStream({ start(controller) { // Send a server notification via SSE - const event = 'event: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n'; + const event = "event: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n"; controller.enqueue(encoder.encode(event)); } }); @@ -237,7 +237,7 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock) .mockResolvedValueOnce({ - ok: true, + ok: true, status: 200, headers: new Headers({ "content-type": "text/event-stream" }), body: makeStream("request1") @@ -263,13 +263,13 @@ describe("StreamableHTTPClientTransport", () => { // Both streams should have delivered their messages expect(messageSpy).toHaveBeenCalledTimes(2); - + // Verify received messages without assuming specific order expect(messageSpy.mock.calls.some(call => { const msg = call[0]; return msg.id === "request1" && msg.result?.id === "request1"; })).toBe(true); - + expect(messageSpy.mock.calls.some(call => { const msg = call[0]; return msg.id === "request2" && msg.result?.id === "request2"; @@ -281,7 +281,7 @@ describe("StreamableHTTPClientTransport", () => { const encoder = new TextEncoder(); const stream = new ReadableStream({ start(controller) { - const event = 'id: event-123\nevent: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n'; + const event = "id: event-123\nevent: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n"; controller.enqueue(encoder.encode(event)); controller.close(); } @@ -313,4 +313,67 @@ describe("StreamableHTTPClientTransport", () => { const lastCall = calls[calls.length - 1]; expect(lastCall[1].headers.get("last-event-id")).toBe("event-123"); }); -}); \ No newline at end of file + + it("should throw error when invalid content-type is received", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + const stream = new ReadableStream({ + start(controller) { + controller.enqueue("invalid text response"); + controller.close(); + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/plain" }), + body: stream + }); + + await transport.start(); + await expect(transport.send(message)).rejects.toThrow("Unexpected content type: text/plain"); + expect(errorSpy).toHaveBeenCalled(); + }); + + + it("should always send specified custom headers", async () => { + const requestInit = { + headers: { + "X-Custom-Header": "CustomValue" + } + }; + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + requestInit: requestInit + }); + + let actualReqInit: RequestInit = {}; + + ((global.fetch as jest.Mock)).mockImplementation( + async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }); + } + ); + + await transport.start(); + + await transport.openSseStream(); + expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); + + requestInit.headers["X-Custom-Header"] = "SecondCustomValue"; + + await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); + expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue"); + + expect(global.fetch).toHaveBeenCalledTimes(2); + }); +}); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 0c667e35..5ea537c7 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,7 +1,8 @@ import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { EventSourceParserStream } from 'eventsource-parser/stream'; +import { EventSourceParserStream } from "eventsource-parser/stream"; + export class StreamableHTTPError extends Error { constructor( public readonly code: number | undefined, @@ -17,16 +18,16 @@ export class StreamableHTTPError extends Error { export type StreamableHTTPClientTransportOptions = { /** * An OAuth client provider to use for authentication. - * + * * When an `authProvider` is specified and the connection is started: * 1. The connection is attempted with any existing access token from the `authProvider`. * 2. If the access token has expired, the `authProvider` is used to refresh the token. * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. - * + * * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection. - * + * * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. - * + * * `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. */ authProvider?: OAuthClientProvider; @@ -83,7 +84,7 @@ export class StreamableHTTPClientTransport implements Transport { return await this._startOrAuthStandaloneSSE(); } - private async _commonHeaders(): Promise { + private async _commonHeaders(): Promise { const headers: HeadersInit = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); @@ -96,24 +97,25 @@ export class StreamableHTTPClientTransport implements Transport { headers["mcp-session-id"] = this._sessionId; } - return headers; + return new Headers( + { ...headers, ...this._requestInit?.headers } + ); } private async _startOrAuthStandaloneSSE(): Promise { try { // Try to open an initial SSE stream with GET to listen for server messages // This is optional according to the spec - server may not support it - const commonHeaders = await this._commonHeaders(); - const headers = new Headers(commonHeaders); - headers.set('Accept', 'text/event-stream'); + const headers = await this._commonHeaders(); + headers.set("Accept", "text/event-stream"); // Include Last-Event-ID header for resumable streams if (this._lastEventId) { - headers.set('last-event-id', this._lastEventId); + headers.set("last-event-id", this._lastEventId); } const response = await fetch(this._url, { - method: 'GET', + method: "GET", headers, signal: this._abortController?.signal, }); @@ -124,12 +126,10 @@ export class StreamableHTTPClientTransport implements Transport { return await this._authThenStart(); } - const error = new StreamableHTTPError( + throw new StreamableHTTPError( response.status, `Failed to open SSE stream: ${response.statusText}`, ); - this.onerror?.(error); - throw error; } // Successful connection, handle the SSE stream as a standalone listener @@ -144,42 +144,32 @@ export class StreamableHTTPClientTransport implements Transport { if (!stream) { return; } - // Create a pipeline: binary stream -> text decoder -> SSE parser - const eventStream = stream - .pipeThrough(new TextDecoderStream()) - .pipeThrough(new EventSourceParserStream()); - const reader = eventStream.getReader(); const processStream = async () => { - try { - while (true) { - const { done, value: event } = await reader.read(); - if (done) { - break; - } - - // Update last event ID if provided - if (event.id) { - this._lastEventId = event.id; - } - - // Handle message events (default event type is undefined per docs) - // or explicit 'message' event type - if (!event.event || event.event === 'message') { - try { - const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - this.onmessage?.(message); - } catch (error) { - this.onerror?.(error as Error); - } + // Create a pipeline: binary stream -> text decoder -> SSE parser + const eventStream = stream + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()); + + for await (const event of eventStream) { + // Update last event ID if provided + if (event.id) { + this._lastEventId = event.id; + } + // Handle message events (default event type is undefined per docs) + // or explicit 'message' event type + if (!event.event || event.event === "message") { + try { + const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); } } - } catch (error) { - this.onerror?.(error as Error); } }; - processStream(); + processStream().catch(err => this.onerror?.(err)); } async start() { @@ -215,8 +205,7 @@ export class StreamableHTTPClientTransport implements Transport { async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise { try { - const commonHeaders = await this._commonHeaders(); - const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers }); + const headers = await this._commonHeaders(); headers.set("content-type", "application/json"); headers.set("accept", "application/json, text/event-stream"); @@ -261,20 +250,13 @@ export class StreamableHTTPClientTransport implements Transport { // Get original message(s) for detecting request IDs const messages = Array.isArray(message) ? message : [message]; - // Extract IDs from request messages for tracking responses - const requestIds = messages.filter(msg => 'method' in msg && 'id' in msg) - .map(msg => 'id' in msg ? msg.id : undefined) - .filter(id => id !== undefined); - - // If we have request IDs and an SSE response, create a unique stream ID - const hasRequests = requestIds.length > 0; + const hasRequests = messages.filter(msg => "method" in msg && "id" in msg && msg.id !== undefined).length > 0; // Check the response type const contentType = response.headers.get("content-type"); if (hasRequests) { if (contentType?.includes("text/event-stream")) { - // For streaming responses, create a unique stream ID based on request IDs this._handleSseStream(response.body); } else if (contentType?.includes("application/json")) { // For non-streaming servers, we might get direct JSON responses @@ -286,6 +268,11 @@ export class StreamableHTTPClientTransport implements Transport { for (const msg of responseMessages) { this.onmessage?.(msg); } + } else { + throw new StreamableHTTPError( + -1, + `Unexpected content type: ${contentType}`, + ); } } } catch (error) { @@ -296,7 +283,7 @@ export class StreamableHTTPClientTransport implements Transport { /** * Opens SSE stream to receive messages from the server. - * + * * This allows the server to push messages to the client without requiring the client * to first send a request via HTTP POST. Some servers may not support this feature. * If authentication is required but fails, this method will throw an UnauthorizedError. @@ -309,4 +296,4 @@ export class StreamableHTTPClientTransport implements Transport { } await this._startOrAuthStandaloneSSE(); } -} \ No newline at end of file +}