-
Notifications
You must be signed in to change notification settings - Fork 82
feat: add session management for streamableHttp [MCP-52] #379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
2be0d17
3a1a320
51da4ea
9f80487
e4f4d18
325a466
c103225
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; | ||
import logger, { LogId } from "./logger.js"; | ||
|
||
export class SessionStore { | ||
private sessions: { [sessionId: string]: StreamableHTTPServerTransport | undefined } = {}; | ||
|
||
getSession(sessionId: string): StreamableHTTPServerTransport | undefined { | ||
return this.sessions[sessionId]; | ||
} | ||
|
||
setSession(sessionId: string, transport: StreamableHTTPServerTransport): void { | ||
if (this.sessions[sessionId]) { | ||
throw new Error(`Session ${sessionId} already exists`); | ||
} | ||
this.sessions[sessionId] = transport; | ||
} | ||
|
||
async closeSession(sessionId: string, closeTransport: boolean = true): Promise<void> { | ||
if (!this.sessions[sessionId]) { | ||
throw new Error(`Session ${sessionId} not found`); | ||
} | ||
if (closeTransport) { | ||
const transport = this.sessions[sessionId]; | ||
if (!transport) { | ||
throw new Error(`Session ${sessionId} not found`); | ||
} | ||
try { | ||
await transport.close(); | ||
} catch (error) { | ||
logger.error( | ||
LogId.streamableHttpTransportSessionCloseFailure, | ||
"streamableHttpTransport", | ||
`Error closing transport ${sessionId}: ${error instanceof Error ? error.message : String(error)}` | ||
); | ||
} | ||
} | ||
delete this.sessions[sessionId]; | ||
} | ||
|
||
async closeAllSessions(): Promise<void> { | ||
await Promise.all( | ||
Object.values(this.sessions) | ||
.filter((transport) => transport !== undefined) | ||
.map((transport) => transport.close()) | ||
); | ||
this.sessions = {}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,132 @@ | ||
import express from "express"; | ||
import http from "http"; | ||
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; | ||
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; | ||
import { TransportRunnerBase } from "./base.js"; | ||
import { config } from "../common/config.js"; | ||
import logger, { LogId } from "../common/logger.js"; | ||
import { randomUUID } from "crypto"; | ||
import { SessionStore } from "../common/sessionStore.js"; | ||
|
||
const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; | ||
const JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED = -32601; | ||
const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; | ||
const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32002; | ||
const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32003; | ||
|
||
function promiseHandler( | ||
fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise<void> | ||
) { | ||
return (req: express.Request, res: express.Response, next: express.NextFunction) => { | ||
fn(req, res, next).catch(next); | ||
fn(req, res, next).catch((error) => { | ||
logger.error( | ||
LogId.streamableHttpTransportRequestFailure, | ||
"streamableHttpTransport", | ||
`Error handling request: ${error instanceof Error ? error.message : String(error)}` | ||
); | ||
res.status(400).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, | ||
message: `failed to handle request`, | ||
data: error instanceof Error ? error.message : String(error), | ||
}, | ||
}); | ||
}); | ||
}; | ||
} | ||
|
||
export class StreamableHttpRunner extends TransportRunnerBase { | ||
private httpServer: http.Server | undefined; | ||
private sessionStore: SessionStore = new SessionStore(); | ||
|
||
async start() { | ||
const app = express(); | ||
app.enable("trust proxy"); // needed for reverse proxy support | ||
app.use(express.urlencoded({ extended: true })); | ||
app.use(express.json()); | ||
|
||
const handleRequest = async (req: express.Request, res: express.Response) => { | ||
const sessionId = req.headers["mcp-session-id"] as string; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't the better way to write this be: const sessionId = req.headers["mcp-session-id"];
if (typeof sessionId !== 'string' || !sessionId) { Or at least use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch, I had an issue on my IDE before I think |
||
if (!sessionId) { | ||
res.status(400).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED, | ||
message: `session id is required`, | ||
}, | ||
}); | ||
return; | ||
} | ||
const transport = this.sessionStore.getSession(sessionId); | ||
if (!transport) { | ||
res.status(404).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND, | ||
message: `session not found`, | ||
}, | ||
}); | ||
return; | ||
} | ||
await transport.handleRequest(req, res, req.body); | ||
}; | ||
|
||
app.post( | ||
"/mcp", | ||
promiseHandler(async (req: express.Request, res: express.Response) => { | ||
const transport = new StreamableHTTPServerTransport({ | ||
sessionIdGenerator: undefined, | ||
}); | ||
const sessionId = req.headers["mcp-session-id"] as string; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as previous remark |
||
if (sessionId) { | ||
await handleRequest(req, res); | ||
return; | ||
} | ||
|
||
const server = this.setupServer(); | ||
if (!isInitializeRequest(req.body)) { | ||
res.status(400).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_INVALID_REQUEST, | ||
message: `invalid request`, | ||
}, | ||
}); | ||
return; | ||
} | ||
|
||
await server.connect(transport); | ||
const server = this.setupServer(); | ||
const transport = new StreamableHTTPServerTransport({ | ||
sessionIdGenerator: () => randomUUID().toString(), | ||
onsessioninitialized: (sessionId) => { | ||
this.sessionStore.setSession(sessionId, transport); | ||
}, | ||
onsessionclosed: async (sessionId) => { | ||
try { | ||
await this.sessionStore.closeSession(sessionId, false); | ||
} catch (error) { | ||
logger.error( | ||
LogId.streamableHttpTransportSessionCloseFailure, | ||
"streamableHttpTransport", | ||
`Error closing session: ${error instanceof Error ? error.message : String(error)}` | ||
); | ||
} | ||
}, | ||
}); | ||
|
||
res.on("close", () => { | ||
Promise.all([transport.close(), server.close()]).catch((error: unknown) => { | ||
transport.onclose = () => { | ||
server.close().catch((error) => { | ||
logger.error( | ||
LogId.streamableHttpTransportCloseFailure, | ||
"streamableHttpTransport", | ||
`Error closing server: ${error instanceof Error ? error.message : String(error)}` | ||
); | ||
}); | ||
}); | ||
}; | ||
|
||
try { | ||
await transport.handleRequest(req, res, req.body); | ||
} catch (error) { | ||
logger.error( | ||
LogId.streamableHttpTransportRequestFailure, | ||
"streamableHttpTransport", | ||
`Error handling request: ${error instanceof Error ? error.message : String(error)}` | ||
); | ||
res.status(400).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, | ||
message: `failed to handle request`, | ||
data: error instanceof Error ? error.message : String(error), | ||
}, | ||
}); | ||
} | ||
await server.connect(transport); | ||
|
||
await transport.handleRequest(req, res, req.body); | ||
}) | ||
); | ||
|
||
app.get("/mcp", (req: express.Request, res: express.Response) => { | ||
res.status(405).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, | ||
message: `method not allowed`, | ||
}, | ||
}); | ||
}); | ||
|
||
app.delete("/mcp", (req: express.Request, res: express.Response) => { | ||
res.status(405).json({ | ||
jsonrpc: "2.0", | ||
error: { | ||
code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, | ||
message: `method not allowed`, | ||
}, | ||
}); | ||
}); | ||
app.get("/mcp", promiseHandler(handleRequest)); | ||
app.delete("/mcp", promiseHandler(handleRequest)); | ||
|
||
this.httpServer = await new Promise<http.Server>((resolve, reject) => { | ||
const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { | ||
|
@@ -104,14 +146,17 @@ export class StreamableHttpRunner extends TransportRunnerBase { | |
} | ||
|
||
async close(): Promise<void> { | ||
await new Promise<void>((resolve, reject) => { | ||
this.httpServer?.close((err) => { | ||
if (err) { | ||
reject(err); | ||
return; | ||
} | ||
resolve(); | ||
}); | ||
}); | ||
await Promise.all([ | ||
this.sessionStore.closeAllSessions(), | ||
new Promise<void>((resolve, reject) => { | ||
this.httpServer?.close((err) => { | ||
if (err) { | ||
reject(err); | ||
return; | ||
} | ||
resolve(); | ||
}); | ||
}), | ||
]); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[q] Is there a reason why they value is nullable?
I don't see any location where the value (StreamableHTTPServerTransport) is set to
undefined
/removedUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for the case where a key is not found