From ff982e0c54d91f7297725a79c3a18511ee926ef3 Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Mon, 14 Jul 2025 16:57:20 +0100 Subject: [PATCH 1/7] feat: add streamable http [MCP-55] (#359) --- .github/workflows/check.yml | 2 +- Dockerfile | 1 + README.md | 79 +++++++++++--- package-lock.json | 99 +++++++++++++++++ package.json | 2 + src/common/config.ts | 10 +- src/common/logger.ts | 44 +++----- src/index.ts | 13 ++- src/server.ts | 54 +++++++-- .../EJsonTransport.ts => transports/stdio.ts} | 2 +- src/transports/streamableHttp.ts | 103 ++++++++++++++++++ tests/integration/helpers.ts | 1 + tests/unit/{ => common}/apiClient.test.ts | 4 +- tests/unit/{ => common}/session.test.ts | 4 +- tests/unit/{ => helpers}/indexCheck.test.ts | 2 +- .../stdio.test.ts} | 6 +- 16 files changed, 357 insertions(+), 69 deletions(-) rename src/{helpers/EJsonTransport.ts => transports/stdio.ts} (96%) create mode 100644 src/transports/streamableHttp.ts rename tests/unit/{ => common}/apiClient.test.ts (98%) rename tests/unit/{ => common}/session.test.ts (95%) rename tests/unit/{ => helpers}/indexCheck.test.ts (98%) rename tests/unit/{EJsonTransport.test.ts => transports/stdio.test.ts} (93%) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 71a5b657..7304823f 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -55,4 +55,4 @@ jobs: rm -rf node_modules npm pkg set scripts.prepare="exit 0" npm install --omit=dev - - run: npx -y @modelcontextprotocol/inspector --cli --method tools/list -- node dist/index.js --connectionString "mongodb://localhost" + - run: npx -y @modelcontextprotocol/inspector --cli --method tools/list -- node dist/index.js diff --git a/Dockerfile b/Dockerfile index 05da379f..d842f633 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,7 @@ RUN addgroup -S mcp && adduser -S mcp -G mcp RUN npm install -g mongodb-mcp-server@${VERSION} USER mcp WORKDIR /home/mcp +ENV MDB_MCP_LOGGERS=stderr,mcp ENTRYPOINT ["mongodb-mcp-server"] LABEL maintainer="MongoDB Inc " LABEL description="MongoDB MCP Server" diff --git a/README.md b/README.md index d7537387..006fbef4 100644 --- a/README.md +++ b/README.md @@ -225,6 +225,27 @@ With Atlas API credentials: } ``` +#### Option 6: Running as an HTTP Server + +You can run the MongoDB MCP Server as an HTTP server instead of the default stdio transport. This is useful if you want to interact with the server over HTTP, for example from a web client or to expose the server on a specific port. + +To start the server with HTTP transport, use the `--transport http` option: + +```shell +npx -y mongodb-mcp-server --transport http +``` + +By default, the server will listen on `http://127.0.0.1:3000`. You can customize the host and port using the `--httpHost` and `--httpPort` options: + +```shell +npx -y mongodb-mcp-server --transport http --httpHost=0.0.0.0 --httpPort=8080 +``` + +- `--httpHost` (default: 127.0.0.1): The host to bind the HTTP server. +- `--httpPort` (default: 3000): The port number for the HTTP server. + +> **Note:** The default transport is `stdio`, which is suitable for integration with most MCP clients. Use `http` transport if you need to interact with the server over HTTP. + ## 🛠️ Supported Tools ### Tool List @@ -278,23 +299,53 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow ### Configuration Options -| Option | Description | -| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `apiClientId` | Atlas API client ID for authentication. Required for running Atlas tools. | -| `apiClientSecret` | Atlas API client secret for authentication. Required for running Atlas tools. | -| `connectionString` | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | -| `logPath` | Folder to store logs. | -| `disabledTools` | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | -| `indexCheck` | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | -| `telemetry` | When set to disabled, disables telemetry collection. | +| Option | Default | Description | +| ------------------ | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `apiClientId` | | Atlas API client ID for authentication. Required for running Atlas tools. | +| `apiClientSecret` | | Atlas API client secret for authentication. Required for running Atlas tools. | +| `connectionString` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | +| `loggers` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | +| `logPath` | see note\* | Folder to store logs. | +| `disabledTools` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | +| `readOnly` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | +| `indexCheck` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | +| `telemetry` | enabled | When set to disabled, disables telemetry collection. | +| `transport` | stdio | Either 'stdio' or 'http'. | +| `httpPort` | 3000 | Port number. | +| `httpHost` | 127.0.0.1 | Host to bind the http server. | + +#### Logger Options + +The `loggers` configuration option controls where logs are sent. You can specify one or more logger types as a comma-separated list. The available options are: + +- `mcp`: Sends logs to the MCP client (if supported by the client/transport). +- `disk`: Writes logs to disk files. Log files are stored in the log path (see `logPath` above). +- `stderr`: Outputs logs to standard error (stderr), useful for debugging or when running in containers. + +**Default:** `disk,mcp` (logs are written to disk and sent to the MCP client). + +You can combine multiple loggers, e.g. `--loggers disk,stderr` or `export MDB_MCP_LOGGERS="mcp,stderr"`. + +##### Example: Set logger via environment variable + +```shell +export MDB_MCP_LOGGERS="disk,stderr" +``` + +##### Example: Set logger via command-line argument + +```shell +npx -y mongodb-mcp-server --loggers mcp,stderr +``` + +##### Log File Location -#### Log Path +When using the `disk` logger, log files are stored in: -Default log location is as follows: +- **Windows:** `%LOCALAPPDATA%\mongodb\mongodb-mcp\.app-logs` +- **macOS/Linux:** `~/.mongodb/mongodb-mcp/.app-logs` -- Windows: `%LOCALAPPDATA%\mongodb\mongodb-mcp\.app-logs` -- macOS/Linux: `~/.mongodb/mongodb-mcp/.app-logs` +You can override the log directory with the `logPath` option. #### Disabled Tools diff --git a/package-lock.json b/package-lock.json index 29132ba3..3919b575 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "@mongodb-js/devtools-connect": "^3.7.2", "@mongosh/service-provider-node-driver": "^3.6.0", "bson": "^6.10.4", + "express": "^5.1.0", "lru-cache": "^11.1.0", "mongodb": "^6.17.0", "mongodb-connection-string-url": "^3.0.2", @@ -34,6 +35,7 @@ "@jest/globals": "^30.0.4", "@modelcontextprotocol/inspector": "^0.16.0", "@redocly/cli": "^1.34.4", + "@types/express": "^5.0.1", "@types/jest": "^30.0.0", "@types/node": "^24.0.12", "@types/simple-oauth2": "^5.0.7", @@ -5906,6 +5908,27 @@ "@babel/types": "^7.20.7" } }, + "node_modules/@types/body-parser": { + "version": "1.19.5", + "resolved": "https://registry.npmjs.org/@types/body-parser/-/body-parser-1.19.5.tgz", + "integrity": "sha512-fB3Zu92ucau0iQ0JMCFQE7b/dv8Ot07NI3KaZIkIUNXq82k4eBAqUaneXfleGY9JWskeS9y+u0nXMyspcuQrCg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/connect": "*", + "@types/node": "*" + } + }, + "node_modules/@types/connect": { + "version": "3.4.38", + "resolved": "https://registry.npmjs.org/@types/connect/-/connect-3.4.38.tgz", + "integrity": "sha512-K6uROf1LD88uDQqJCktA4yzL1YYAK6NgfsI0v/mTgyPKWsX1CnJ0XPSDhViejru1GcRkLWb8RlzFYJRqGUbaug==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/estree": { "version": "1.0.7", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.7.tgz", @@ -5913,6 +5936,38 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/express": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/@types/express/-/express-5.0.1.tgz", + "integrity": "sha512-UZUw8vjpWFXuDnjFTh7/5c2TWDlQqeXHi6hcN7F2XSVT5P+WmUnnbFS3KA6Jnc6IsEqI2qCVu2bK0R0J4A8ZQQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/body-parser": "*", + "@types/express-serve-static-core": "^5.0.0", + "@types/serve-static": "*" + } + }, + "node_modules/@types/express-serve-static-core": { + "version": "5.0.6", + "resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-5.0.6.tgz", + "integrity": "sha512-3xhRnjJPkULekpSzgtoNYYcTWgEZkp4myc+Saevii5JPnHNvHMRlBSHDbs7Bh1iPPoVTERHEZXyhyLbMEsExsA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*", + "@types/qs": "*", + "@types/range-parser": "*", + "@types/send": "*" + } + }, + "node_modules/@types/http-errors": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/@types/http-errors/-/http-errors-2.0.4.tgz", + "integrity": "sha512-D0CFMMtydbJAegzOyHjtiKPLlvnm3iTZyZRSZoLq2mRhDdmLfIWOCYPfQJ4cu2erKghU++QvjcUjp/5h7hESpA==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/istanbul-lib-coverage": { "version": "2.0.6", "resolved": "https://registry.npmjs.org/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.6.tgz", @@ -6006,6 +6061,13 @@ "dev": true, "license": "MIT" }, + "node_modules/@types/mime": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/@types/mime/-/mime-1.3.5.tgz", + "integrity": "sha512-/pyBZWSLD2n0dcHE3hq8s8ZvcETHtEuF+3E7XVt0Ig2nvsVQXdghHVcEkIWjy9A0wKfTn97a/PSDYohKIlnP/w==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/node": { "version": "24.0.12", "resolved": "https://registry.npmjs.org/@types/node/-/node-24.0.12.tgz", @@ -6016,6 +6078,43 @@ "undici-types": "~7.8.0" } }, + "node_modules/@types/qs": { + "version": "6.9.18", + "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.18.tgz", + "integrity": "sha512-kK7dgTYDyGqS+e2Q4aK9X3D7q234CIZ1Bv0q/7Z5IwRDoADNU81xXJK/YVyLbLTZCoIwUoDoffFeF+p/eIklAA==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/range-parser": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/@types/range-parser/-/range-parser-1.2.7.tgz", + "integrity": "sha512-hKormJbkJqzQGhziax5PItDUTMAM9uE2XXQmM37dyd4hVM+5aVl7oVxMVUiVQn2oCQFN/LKCZdvSM0pFRqbSmQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/send": { + "version": "0.17.4", + "resolved": "https://registry.npmjs.org/@types/send/-/send-0.17.4.tgz", + "integrity": "sha512-x2EM6TJOybec7c52BX0ZspPodMsQUd5L6PRwOunVyVUhXiBSKf3AezDL8Dgvgt5o0UfKNfuA0eMLr2wLT4AiBA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/mime": "^1", + "@types/node": "*" + } + }, + "node_modules/@types/serve-static": { + "version": "1.15.7", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.7.tgz", + "integrity": "sha512-W8Ym+h8nhuRwaKPaDw34QUkwsGi6Rc4yYqvKFo5rm2FUEhCFbzVWrxXUxuKK8TASjWsysJY0nsmNCGhCOIsrOw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/http-errors": "*", + "@types/node": "*", + "@types/send": "*" + } + }, "node_modules/@types/simple-oauth2": { "version": "5.0.7", "resolved": "https://registry.npmjs.org/@types/simple-oauth2/-/simple-oauth2-5.0.7.tgz", diff --git a/package.json b/package.json index 53d6d2c6..4cee9b92 100644 --- a/package.json +++ b/package.json @@ -37,6 +37,7 @@ "@jest/globals": "^30.0.4", "@modelcontextprotocol/inspector": "^0.16.0", "@redocly/cli": "^1.34.4", + "@types/express": "^5.0.1", "@types/jest": "^30.0.0", "@types/node": "^24.0.12", "@types/simple-oauth2": "^5.0.7", @@ -65,6 +66,7 @@ "@mongodb-js/devtools-connect": "^3.7.2", "@mongosh/service-provider-node-driver": "^3.6.0", "bson": "^6.10.4", + "express": "^5.1.0", "lru-cache": "^11.1.0", "mongodb": "^6.17.0", "mongodb-connection-string-url": "^3.0.2", diff --git a/src/common/config.ts b/src/common/config.ts index d9aa0bbc..8eda2fba 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -17,13 +17,17 @@ export interface UserConfig { apiBaseUrl: string; apiClientId?: string; apiClientSecret?: string; - telemetry?: "enabled" | "disabled"; + telemetry: "enabled" | "disabled"; logPath: string; connectionString?: string; connectOptions: ConnectOptions; disabledTools: Array; readOnly?: boolean; indexCheck?: boolean; + transport: "stdio" | "http"; + httpPort: number; + httpHost: string; + loggers: Array<"stderr" | "disk" | "mcp">; } const defaults: UserConfig = { @@ -39,6 +43,10 @@ const defaults: UserConfig = { telemetry: "enabled", readOnly: false, indexCheck: false, + transport: "stdio", + httpPort: 3000, + httpHost: "127.0.0.1", + loggers: ["disk", "mcp"], }; export const config = { diff --git a/src/common/logger.ts b/src/common/logger.ts index b1fb78a9..0e9186d8 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -12,6 +12,7 @@ export const LogId = { serverCloseRequested: mongoLogId(1_000_003), serverClosed: mongoLogId(1_000_004), serverCloseFailure: mongoLogId(1_000_005), + serverDuplicateLoggers: mongoLogId(1_000_006), atlasCheckCredentials: mongoLogId(1_001_001), atlasDeleteDatabaseUserFailure: mongoLogId(1_001_002), @@ -37,9 +38,16 @@ export const LogId = { mongodbDisconnectFailure: mongoLogId(1_004_002), toolUpdateFailure: mongoLogId(1_005_001), + + streamableHttpTransportStarted: mongoLogId(1_006_001), + streamableHttpTransportSessionInitialized: mongoLogId(1_006_002), + streamableHttpTransportRequestFailure: mongoLogId(1_006_003), + streamableHttpTransportCloseRequested: mongoLogId(1_006_004), + streamableHttpTransportCloseSuccess: mongoLogId(1_006_005), + streamableHttpTransportCloseFailure: mongoLogId(1_006_006), } as const; -abstract class LoggerBase { +export abstract class LoggerBase { abstract log(level: LogLevel, id: MongoLogId, context: string, message: string): void; info(id: MongoLogId, context: string, message: string): void { @@ -74,14 +82,14 @@ abstract class LoggerBase { } } -class ConsoleLogger extends LoggerBase { +export class ConsoleLogger extends LoggerBase { log(level: LogLevel, id: MongoLogId, context: string, message: string): void { message = redact(message); - console.error(`[${level.toUpperCase()}] ${id.__value} - ${context}: ${message}`); + console.error(`[${level.toUpperCase()}] ${id.__value} - ${context}: ${message} (${process.pid})`); } } -class DiskLogger extends LoggerBase { +export class DiskLogger extends LoggerBase { private constructor(private logWriter: MongoLogWriter) { super(); } @@ -133,7 +141,7 @@ class DiskLogger extends LoggerBase { } } -class McpLogger extends LoggerBase { +export class McpLogger extends LoggerBase { constructor(private server: McpServer) { super(); } @@ -152,18 +160,12 @@ class McpLogger extends LoggerBase { } class CompositeLogger extends LoggerBase { - private loggers: LoggerBase[]; + private loggers: LoggerBase[] = []; constructor(...loggers: LoggerBase[]) { super(); - if (loggers.length === 0) { - // default to ConsoleLogger - this.loggers = [new ConsoleLogger()]; - return; - } - - this.loggers = [...loggers]; + this.setLoggers(...loggers); } setLoggers(...loggers: LoggerBase[]): void { @@ -180,19 +182,5 @@ class CompositeLogger extends LoggerBase { } } -const logger = new CompositeLogger(); +const logger = new CompositeLogger(new ConsoleLogger()); export default logger; - -export async function setStdioPreset(server: McpServer, logPath: string): Promise { - const diskLogger = await DiskLogger.fromPath(logPath); - const mcpLogger = new McpLogger(server); - - logger.setLoggers(mcpLogger, diskLogger); -} - -export function setContainerPreset(server: McpServer): void { - const mcpLogger = new McpLogger(server); - const consoleLogger = new ConsoleLogger(); - - logger.setLoggers(mcpLogger, consoleLogger); -} diff --git a/src/index.ts b/src/index.ts index f94c4371..f09ed604 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,7 +7,8 @@ import { Session } from "./common/session.js"; import { Server } from "./server.js"; import { packageInfo } from "./common/packageInfo.js"; import { Telemetry } from "./telemetry/telemetry.js"; -import { createEJsonTransport } from "./helpers/EJsonTransport.js"; +import { createStdioTransport } from "./transports/stdio.js"; +import { createHttpTransport } from "./transports/streamableHttp.js"; try { const session = new Session({ @@ -15,13 +16,16 @@ try { apiClientId: config.apiClientId, apiClientSecret: config.apiClientSecret, }); + + const transport = config.transport === "stdio" ? createStdioTransport() : createHttpTransport(); + + const telemetry = Telemetry.create(session, config); + const mcpServer = new McpServer({ name: packageInfo.mcpServerName, version: packageInfo.version, }); - const telemetry = Telemetry.create(session, config); - const server = new Server({ mcpServer, session, @@ -29,8 +33,6 @@ try { userConfig: config, }); - const transport = createEJsonTransport(); - const shutdown = () => { logger.info(LogId.serverCloseRequested, "server", `Server close requested`); @@ -48,6 +50,7 @@ try { }; process.once("SIGINT", shutdown); + process.once("SIGABRT", shutdown); process.once("SIGTERM", shutdown); process.once("SIGQUIT", shutdown); diff --git a/src/server.ts b/src/server.ts index 3c65d2e3..d58cca52 100644 --- a/src/server.ts +++ b/src/server.ts @@ -3,7 +3,7 @@ import { Session } from "./common/session.js"; import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; import { AtlasTools } from "./tools/atlas/tools.js"; import { MongoDbTools } from "./tools/mongodb/tools.js"; -import logger, { setStdioPreset, setContainerPreset, LogId } from "./common/logger.js"; +import logger, { LogId, LoggerBase, McpLogger, DiskLogger, ConsoleLogger } from "./common/logger.js"; import { ObjectId } from "mongodb"; import { Telemetry } from "./telemetry/telemetry.js"; import { UserConfig } from "./common/config.js"; @@ -11,7 +11,6 @@ import { type ServerEvent } from "./telemetry/types.js"; import { type ServerCommand } from "./telemetry/types.js"; import { CallToolRequestSchema, CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import assert from "assert"; -import { detectContainerEnv } from "./helpers/container.js"; import { ToolBase } from "./tools/tool.js"; export interface ServerOptions { @@ -38,6 +37,8 @@ export class Server { } async connect(transport: Transport): Promise { + await this.validateConfig(); + this.mcpServer.server.registerCapabilities({ logging: {} }); this.registerTools(); @@ -66,15 +67,17 @@ export class Server { return existingHandler(request, extra); }); - const containerEnv = await detectContainerEnv(); - - if (containerEnv) { - setContainerPreset(this.mcpServer); - } else { - await setStdioPreset(this.mcpServer, this.userConfig.logPath); + const loggers: LoggerBase[] = []; + if (this.userConfig.loggers.includes("mcp")) { + loggers.push(new McpLogger(this.mcpServer)); } - - await this.mcpServer.connect(transport); + if (this.userConfig.loggers.includes("disk")) { + loggers.push(await DiskLogger.fromPath(this.userConfig.logPath)); + } + if (this.userConfig.loggers.includes("stderr")) { + loggers.push(new ConsoleLogger()); + } + logger.setLoggers(...loggers); this.mcpServer.server.oninitialized = () => { this.session.setAgentRunner(this.mcpServer.server.getClientVersion()); @@ -99,7 +102,7 @@ export class Server { this.emitServerEvent("stop", Date.now() - closeTime, error); }; - await this.validateConfig(); + await this.mcpServer.connect(transport); } async close(): Promise { @@ -186,6 +189,35 @@ export class Server { } private async validateConfig(): Promise { + const transport = this.userConfig.transport as string; + if (transport !== "http" && transport !== "stdio") { + throw new Error(`Invalid transport: ${transport}`); + } + + const telemetry = this.userConfig.telemetry as string; + if (telemetry !== "enabled" && telemetry !== "disabled") { + throw new Error(`Invalid telemetry: ${telemetry}`); + } + + if (this.userConfig.httpPort < 1 || this.userConfig.httpPort > 65535) { + throw new Error(`Invalid httpPort: ${this.userConfig.httpPort}`); + } + + if (this.userConfig.loggers.length === 0) { + throw new Error("No loggers found in config"); + } + + const loggerTypes = new Set(this.userConfig.loggers); + if (loggerTypes.size !== this.userConfig.loggers.length) { + throw new Error("Duplicate loggers found in config"); + } + + for (const loggerType of this.userConfig.loggers as string[]) { + if (loggerType !== "mcp" && loggerType !== "disk" && loggerType !== "stderr") { + throw new Error(`Invalid logger: ${loggerType}`); + } + } + if (this.userConfig.connectionString) { try { await this.session.connectToMongoDB(this.userConfig.connectionString, this.userConfig.connectOptions); diff --git a/src/helpers/EJsonTransport.ts b/src/transports/stdio.ts similarity index 96% rename from src/helpers/EJsonTransport.ts rename to src/transports/stdio.ts index 307e90bd..0f9f4c0c 100644 --- a/src/helpers/EJsonTransport.ts +++ b/src/transports/stdio.ts @@ -39,7 +39,7 @@ export class EJsonReadBuffer { // // This function creates a StdioServerTransport and replaces the internal readBuffer with EJsonReadBuffer // that uses EJson.parse instead. -export function createEJsonTransport(): StdioServerTransport { +export function createStdioTransport(): StdioServerTransport { const server = new StdioServerTransport(); server["_readBuffer"] = new EJsonReadBuffer(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts new file mode 100644 index 00000000..f613422f --- /dev/null +++ b/src/transports/streamableHttp.ts @@ -0,0 +1,103 @@ +import express from "express"; +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; + +import { config } from "../common/config.js"; +import logger, { LogId } from "../common/logger.js"; + +const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; + +export function createHttpTransport(): StreamableHTTPServerTransport { + const app = express(); + app.enable("trust proxy"); // needed for reverse proxy support + app.use(express.urlencoded({ extended: true })); + app.use(express.json()); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + }); + + app.post("/mcp", async (req: express.Request, res: express.Response) => { + try { + await transport.handleRequest(req, res, req.body); + } catch (error) { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + } + }); + + app.get("/mcp", async (req: express.Request, res: express.Response) => { + try { + await transport.handleRequest(req, res, req.body); + } catch (error) { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + } + }); + + app.delete("/mcp", async (req: express.Request, res: express.Response) => { + try { + await transport.handleRequest(req, res, req.body); + } catch (error) { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + } + }); + + const server = app.listen(config.httpPort, config.httpHost, () => { + logger.info( + LogId.streamableHttpTransportStarted, + "streamableHttpTransport", + `Server started on http://${config.httpHost}:${config.httpPort}` + ); + }); + + transport.onclose = () => { + logger.info(LogId.streamableHttpTransportCloseRequested, "streamableHttpTransport", `Closing server`); + server.close((err?: Error) => { + if (err) { + logger.error( + LogId.streamableHttpTransportCloseFailure, + "streamableHttpTransport", + `Error closing server: ${err.message}` + ); + return; + } + logger.info(LogId.streamableHttpTransportCloseSuccess, "streamableHttpTransport", `Server closed`); + }); + }; + + return transport; +} diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 8f4e0539..84eecf14 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -25,6 +25,7 @@ export interface IntegrationTest { export const defaultTestConfig: UserConfig = { ...config, telemetry: "disabled", + loggers: ["stderr"], }; export function setupIntegrationTest(getUserConfig: () => UserConfig): IntegrationTest { diff --git a/tests/unit/apiClient.test.ts b/tests/unit/common/apiClient.test.ts similarity index 98% rename from tests/unit/apiClient.test.ts rename to tests/unit/common/apiClient.test.ts index 6b9fd427..00d26e9f 100644 --- a/tests/unit/apiClient.test.ts +++ b/tests/unit/common/apiClient.test.ts @@ -1,6 +1,6 @@ import { jest } from "@jest/globals"; -import { ApiClient } from "../../src/common/atlas/apiClient.js"; -import { CommonProperties, TelemetryEvent, TelemetryResult } from "../../src/telemetry/types.js"; +import { ApiClient } from "../../../src/common/atlas/apiClient.js"; +import { CommonProperties, TelemetryEvent, TelemetryResult } from "../../../src/telemetry/types.js"; describe("ApiClient", () => { let apiClient: ApiClient; diff --git a/tests/unit/session.test.ts b/tests/unit/common/session.test.ts similarity index 95% rename from tests/unit/session.test.ts rename to tests/unit/common/session.test.ts index fdd4296b..bb43a4a0 100644 --- a/tests/unit/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -1,7 +1,7 @@ import { jest } from "@jest/globals"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; -import { Session } from "../../src/common/session.js"; -import { config } from "../../src/common/config.js"; +import { Session } from "../../../src/common/session.js"; +import { config } from "../../../src/common/config.js"; jest.mock("@mongosh/service-provider-node-driver"); const MockNodeDriverServiceProvider = NodeDriverServiceProvider as jest.MockedClass; diff --git a/tests/unit/indexCheck.test.ts b/tests/unit/helpers/indexCheck.test.ts similarity index 98% rename from tests/unit/indexCheck.test.ts rename to tests/unit/helpers/indexCheck.test.ts index 82b67e68..aedac1cf 100644 --- a/tests/unit/indexCheck.test.ts +++ b/tests/unit/helpers/indexCheck.test.ts @@ -1,4 +1,4 @@ -import { usesIndex, getIndexCheckErrorMessage } from "../../src/helpers/indexCheck.js"; +import { usesIndex, getIndexCheckErrorMessage } from "../../../src/helpers/indexCheck.js"; import { Document } from "mongodb"; describe("indexCheck", () => { diff --git a/tests/unit/EJsonTransport.test.ts b/tests/unit/transports/stdio.test.ts similarity index 93% rename from tests/unit/EJsonTransport.test.ts rename to tests/unit/transports/stdio.test.ts index 6bbb7999..0e00968b 100644 --- a/tests/unit/EJsonTransport.test.ts +++ b/tests/unit/transports/stdio.test.ts @@ -1,15 +1,15 @@ import { Decimal128, MaxKey, MinKey, ObjectId, Timestamp, UUID } from "bson"; -import { createEJsonTransport, EJsonReadBuffer } from "../../src/helpers/EJsonTransport.js"; +import { createStdioTransport, EJsonReadBuffer } from "../../../src/transports/stdio.js"; import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { Readable } from "stream"; import { ReadBuffer } from "@modelcontextprotocol/sdk/shared/stdio.js"; -describe("EJsonTransport", () => { +describe("stdioTransport", () => { let transport: StdioServerTransport; beforeEach(async () => { - transport = createEJsonTransport(); + transport = createStdioTransport(); await transport.start(); }); From 2c21af33f47776b3a282f5ca7260270fcb5eafe4 Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Tue, 15 Jul 2025 10:21:20 +0100 Subject: [PATCH 2/7] test: add streamable http tests [MCP-60] (#362) --- README.md | 4 +- src/common/config.ts | 2 +- src/common/logger.ts | 11 +-- src/index.ts | 2 +- src/transports/streamableHttp.ts | 56 +++++++++----- tests/unit/transports/streamableHttp.test.ts | 77 ++++++++++++++++++++ 6 files changed, 125 insertions(+), 27 deletions(-) create mode 100644 tests/unit/transports/streamableHttp.test.ts diff --git a/README.md b/README.md index 006fbef4..3d729bc3 100644 --- a/README.md +++ b/README.md @@ -324,7 +324,7 @@ The `loggers` configuration option controls where logs are sent. You can specify **Default:** `disk,mcp` (logs are written to disk and sent to the MCP client). -You can combine multiple loggers, e.g. `--loggers disk,stderr` or `export MDB_MCP_LOGGERS="mcp,stderr"`. +You can combine multiple loggers, e.g. `--loggers disk stderr` or `export MDB_MCP_LOGGERS="mcp,stderr"`. ##### Example: Set logger via environment variable @@ -335,7 +335,7 @@ export MDB_MCP_LOGGERS="disk,stderr" ##### Example: Set logger via command-line argument ```shell -npx -y mongodb-mcp-server --loggers mcp,stderr +npx -y mongodb-mcp-server --loggers mcp stderr ``` ##### Log File Location diff --git a/src/common/config.ts b/src/common/config.ts index 8eda2fba..98c13cfc 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -128,6 +128,6 @@ function SNAKE_CASE_toCamelCase(str: string): string { // Reads the cli args and parses them into a UserConfig object. function getCliConfig() { return argv(process.argv.slice(2), { - array: ["disabledTools"], + array: ["disabledTools", "loggers"], }) as unknown as Partial; } diff --git a/src/common/logger.ts b/src/common/logger.ts index 0e9186d8..8f6069a0 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -40,11 +40,12 @@ export const LogId = { toolUpdateFailure: mongoLogId(1_005_001), streamableHttpTransportStarted: mongoLogId(1_006_001), - streamableHttpTransportSessionInitialized: mongoLogId(1_006_002), - streamableHttpTransportRequestFailure: mongoLogId(1_006_003), - streamableHttpTransportCloseRequested: mongoLogId(1_006_004), - streamableHttpTransportCloseSuccess: mongoLogId(1_006_005), - streamableHttpTransportCloseFailure: mongoLogId(1_006_006), + streamableHttpTransportStartFailure: mongoLogId(1_006_002), + streamableHttpTransportSessionInitialized: mongoLogId(1_006_003), + streamableHttpTransportRequestFailure: mongoLogId(1_006_004), + streamableHttpTransportCloseRequested: mongoLogId(1_006_005), + streamableHttpTransportCloseSuccess: mongoLogId(1_006_006), + streamableHttpTransportCloseFailure: mongoLogId(1_006_007), } as const; export abstract class LoggerBase { diff --git a/src/index.ts b/src/index.ts index f09ed604..c5f4ddee 100644 --- a/src/index.ts +++ b/src/index.ts @@ -17,7 +17,7 @@ try { apiClientSecret: config.apiClientSecret, }); - const transport = config.transport === "stdio" ? createStdioTransport() : createHttpTransport(); + const transport = config.transport === "stdio" ? createStdioTransport() : await createHttpTransport(); const telemetry = Telemetry.create(session, config); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index f613422f..bb4d0f06 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,4 +1,5 @@ import express from "express"; +import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { config } from "../common/config.js"; @@ -6,7 +7,7 @@ import logger, { LogId } from "../common/logger.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; -export function createHttpTransport(): StreamableHTTPServerTransport { +export async function createHttpTransport(): Promise { const app = express(); app.enable("trust proxy"); // needed for reverse proxy support app.use(express.urlencoded({ extended: true })); @@ -76,28 +77,47 @@ export function createHttpTransport(): StreamableHTTPServerTransport { } }); - const server = app.listen(config.httpPort, config.httpHost, () => { + try { + const server = await new Promise((resolve, reject) => { + const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { + if (err) { + reject(err); + return; + } + resolve(result); + }); + }); + logger.info( LogId.streamableHttpTransportStarted, "streamableHttpTransport", `Server started on http://${config.httpHost}:${config.httpPort}` ); - }); - transport.onclose = () => { - logger.info(LogId.streamableHttpTransportCloseRequested, "streamableHttpTransport", `Closing server`); - server.close((err?: Error) => { - if (err) { - logger.error( - LogId.streamableHttpTransportCloseFailure, - "streamableHttpTransport", - `Error closing server: ${err.message}` - ); - return; - } - logger.info(LogId.streamableHttpTransportCloseSuccess, "streamableHttpTransport", `Server closed`); - }); - }; + transport.onclose = () => { + logger.info(LogId.streamableHttpTransportCloseRequested, "streamableHttpTransport", `Closing server`); + server.close((err?: Error) => { + if (err) { + logger.error( + LogId.streamableHttpTransportCloseFailure, + "streamableHttpTransport", + `Error closing server: ${err.message}` + ); + return; + } + logger.info(LogId.streamableHttpTransportCloseSuccess, "streamableHttpTransport", `Server closed`); + }); + }; + + return transport; + } catch (error: unknown) { + const err = error instanceof Error ? error : new Error(String(error)); + logger.info( + LogId.streamableHttpTransportStartFailure, + "streamableHttpTransport", + `Error starting server: ${err.message}` + ); - return transport; + throw err; + } } diff --git a/tests/unit/transports/streamableHttp.test.ts b/tests/unit/transports/streamableHttp.test.ts new file mode 100644 index 00000000..1150052b --- /dev/null +++ b/tests/unit/transports/streamableHttp.test.ts @@ -0,0 +1,77 @@ +import { createHttpTransport } from "../../../src/transports/streamableHttp.js"; +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { z } from "zod"; +describe("streamableHttpTransport", () => { + let transport: StreamableHTTPServerTransport; + const mcpServer = new McpServer({ + name: "test", + version: "1.0.0", + }); + beforeAll(async () => { + transport = await createHttpTransport(); + mcpServer.registerTool( + "hello", + { + title: "Hello Tool", + description: "Say hello", + inputSchema: { name: z.string() }, + }, + ({ name }) => ({ + content: [{ type: "text", text: `Hello, ${name}!` }], + }) + ); + await mcpServer.connect(transport); + }); + + afterAll(async () => { + await mcpServer.close(); + }); + + describe("client connects successfully", () => { + let client: StreamableHTTPClientTransport; + beforeAll(async () => { + client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); + await client.start(); + }); + + afterAll(async () => { + await client.close(); + }); + + it("handles requests and sends responses", async () => { + client.onmessage = (message: JSONRPCMessage) => { + const messageResult = message as + | { + result?: { + tools: { + name: string; + description: string; + }[]; + }; + } + | undefined; + + expect(message.jsonrpc).toBe("2.0"); + expect(messageResult).toBeDefined(); + expect(messageResult?.result?.tools).toBeDefined(); + expect(messageResult?.result?.tools.length).toBe(1); + expect(messageResult?.result?.tools[0]?.name).toBe("hello"); + expect(messageResult?.result?.tools[0]?.description).toBe("Say hello"); + }; + + await client.send({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: { + _meta: { + progressToken: 1, + }, + }, + }); + }); + }); +}); From db23253149e8a856d2ae835a9d44f4378a5875b3 Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Tue, 15 Jul 2025 18:24:36 +0100 Subject: [PATCH 3/7] refactor: move http server (#377) --- package.json | 2 +- src/index.ts | 61 +++---- src/transports/base.ts | 34 ++++ src/transports/stdio.ts | 24 +++ src/transports/streamableHttp.ts | 156 +++++++++--------- tests/integration/transports/stdio.test.ts | 70 ++++++++ .../transports/streamableHttp.test.ts | 76 +++++++++ tests/unit/transports/stdio.test.ts | 1 - tests/unit/transports/streamableHttp.test.ts | 78 --------- 9 files changed, 304 insertions(+), 198 deletions(-) create mode 100644 src/transports/base.ts create mode 100644 tests/integration/transports/stdio.test.ts create mode 100644 tests/integration/transports/streamableHttp.test.ts delete mode 100644 tests/unit/transports/streamableHttp.test.ts diff --git a/package.json b/package.json index 9b6347b3..b18fa2cb 100644 --- a/package.json +++ b/package.json @@ -29,7 +29,7 @@ "check:types": "tsc --noEmit --project tsconfig.json", "reformat": "prettier --write .", "generate": "./scripts/generate.sh", - "test": "vitest --coverage" + "test": "vitest --run --coverage" }, "license": "Apache-2.0", "devDependencies": { diff --git a/src/index.ts b/src/index.ts index c5f4ddee..73457dd6 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,50 +1,23 @@ #!/usr/bin/env node import logger, { LogId } from "./common/logger.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { config } from "./common/config.js"; -import { Session } from "./common/session.js"; -import { Server } from "./server.js"; -import { packageInfo } from "./common/packageInfo.js"; -import { Telemetry } from "./telemetry/telemetry.js"; -import { createStdioTransport } from "./transports/stdio.js"; -import { createHttpTransport } from "./transports/streamableHttp.js"; +import { StdioRunner } from "./transports/stdio.js"; +import { StreamableHttpRunner } from "./transports/streamableHttp.js"; -try { - const session = new Session({ - apiBaseUrl: config.apiBaseUrl, - apiClientId: config.apiClientId, - apiClientSecret: config.apiClientSecret, - }); - - const transport = config.transport === "stdio" ? createStdioTransport() : await createHttpTransport(); - - const telemetry = Telemetry.create(session, config); - - const mcpServer = new McpServer({ - name: packageInfo.mcpServerName, - version: packageInfo.version, - }); - - const server = new Server({ - mcpServer, - session, - telemetry, - userConfig: config, - }); +async function main() { + const transportRunner = config.transport === "stdio" ? new StdioRunner() : new StreamableHttpRunner(); const shutdown = () => { logger.info(LogId.serverCloseRequested, "server", `Server close requested`); - server + transportRunner .close() .then(() => { - logger.info(LogId.serverClosed, "server", `Server closed successfully`); process.exit(0); }) - .catch((err: unknown) => { - const error = err instanceof Error ? err : new Error(String(err)); - logger.error(LogId.serverCloseFailure, "server", `Error closing server: ${error.message}`); + .catch((error: unknown) => { + logger.error(LogId.serverCloseFailure, "server", `Error closing server: ${error as string}`); process.exit(1); }); }; @@ -54,8 +27,22 @@ try { process.once("SIGTERM", shutdown); process.once("SIGQUIT", shutdown); - await server.connect(transport); -} catch (error: unknown) { + try { + await transportRunner.start(); + } catch (error: unknown) { + logger.emergency(LogId.serverStartFailure, "server", `Fatal error running server: ${error as string}`); + try { + await transportRunner.close(); + logger.error(LogId.serverClosed, "server", "Server closed"); + } catch (error: unknown) { + logger.error(LogId.serverCloseFailure, "server", `Error closing server: ${error as string}`); + } finally { + process.exit(1); + } + } +} + +main().catch((error: unknown) => { logger.emergency(LogId.serverStartFailure, "server", `Fatal error running server: ${error as string}`); process.exit(1); -} +}); diff --git a/src/transports/base.ts b/src/transports/base.ts new file mode 100644 index 00000000..442db18a --- /dev/null +++ b/src/transports/base.ts @@ -0,0 +1,34 @@ +import { config } from "../common/config.js"; +import { packageInfo } from "../common/packageInfo.js"; +import { Server } from "../server.js"; +import { Session } from "../common/session.js"; +import { Telemetry } from "../telemetry/telemetry.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; + +export abstract class TransportRunnerBase { + protected setupServer(): Server { + const session = new Session({ + apiBaseUrl: config.apiBaseUrl, + apiClientId: config.apiClientId, + apiClientSecret: config.apiClientSecret, + }); + + const telemetry = Telemetry.create(session, config); + + const mcpServer = new McpServer({ + name: packageInfo.mcpServerName, + version: packageInfo.version, + }); + + return new Server({ + mcpServer, + session, + telemetry, + userConfig: config, + }); + } + + abstract start(): Promise; + + abstract close(): Promise; +} diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 0f9f4c0c..9f18627c 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,3 +1,6 @@ +import logger, { LogId } from "../common/logger.js"; +import { Server } from "../server.js"; +import { TransportRunnerBase } from "./base.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; @@ -45,3 +48,24 @@ export function createStdioTransport(): StdioServerTransport { return server; } + +export class StdioRunner extends TransportRunnerBase { + private server: Server | undefined; + + async start() { + try { + this.server = this.setupServer(); + + const transport = createStdioTransport(); + + await this.server.connect(transport); + } catch (error: unknown) { + logger.emergency(LogId.serverStartFailure, "server", `Fatal error running server: ${error as string}`); + process.exit(1); + } + } + + async close(): Promise { + await this.server?.close(); + } +} diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index bb4d0f06..e15af8d5 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,84 +1,92 @@ import express from "express"; import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; - +import { TransportRunnerBase } from "./base.js"; import { config } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; +const JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED = -32601; -export async function createHttpTransport(): Promise { - const app = express(); - app.enable("trust proxy"); // needed for reverse proxy support - app.use(express.urlencoded({ extended: true })); - app.use(express.json()); +function promiseHandler( + fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise +) { + return (req: express.Request, res: express.Response, next: express.NextFunction) => { + fn(req, res, next).catch(next); + }; +} - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); +export class StreamableHttpRunner extends TransportRunnerBase { + private httpServer: http.Server | undefined; - app.post("/mcp", async (req: express.Request, res: express.Response) => { - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), - }, - }); - } - }); + async start() { + const app = express(); + app.enable("trust proxy"); // needed for reverse proxy support + app.use(express.urlencoded({ extended: true })); + app.use(express.json()); + + app.post( + "/mcp", + promiseHandler(async (req: express.Request, res: express.Response) => { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + }); + + const server = this.setupServer(); - app.get("/mcp", async (req: express.Request, res: express.Response) => { - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ + await server.connect(transport); + + res.on("close", () => { + Promise.all([transport.close(), server.close()]).catch((error: unknown) => { + logger.error( + LogId.streamableHttpTransportCloseFailure, + "streamableHttpTransport", + `Error closing server: ${error instanceof Error ? error.message : String(error)}` + ); + }); + }); + + try { + await transport.handleRequest(req, res, req.body); + } catch (error) { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + } + }) + ); + + app.get("/mcp", (req: express.Request, res: express.Response) => { + res.status(405).json({ jsonrpc: "2.0", error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), + code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, + message: `method not allowed`, }, }); - } - }); + }); - app.delete("/mcp", async (req: express.Request, res: express.Response) => { - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ + app.delete("/mcp", (req: express.Request, res: express.Response) => { + res.status(405).json({ jsonrpc: "2.0", error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), + code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, + message: `method not allowed`, }, }); - } - }); + }); - try { - const server = await new Promise((resolve, reject) => { + this.httpServer = await new Promise((resolve, reject) => { const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { if (err) { reject(err); @@ -93,31 +101,17 @@ export async function createHttpTransport(): Promise { - logger.info(LogId.streamableHttpTransportCloseRequested, "streamableHttpTransport", `Closing server`); - server.close((err?: Error) => { + async close(): Promise { + await new Promise((resolve, reject) => { + this.httpServer?.close((err) => { if (err) { - logger.error( - LogId.streamableHttpTransportCloseFailure, - "streamableHttpTransport", - `Error closing server: ${err.message}` - ); + reject(err); return; } - logger.info(LogId.streamableHttpTransportCloseSuccess, "streamableHttpTransport", `Server closed`); + resolve(); }); - }; - - return transport; - } catch (error: unknown) { - const err = error instanceof Error ? error : new Error(String(error)); - logger.info( - LogId.streamableHttpTransportStartFailure, - "streamableHttpTransport", - `Error starting server: ${err.message}` - ); - - throw err; + }); } } diff --git a/tests/integration/transports/stdio.test.ts b/tests/integration/transports/stdio.test.ts new file mode 100644 index 00000000..afbcce00 --- /dev/null +++ b/tests/integration/transports/stdio.test.ts @@ -0,0 +1,70 @@ +import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { describe, expect, it, beforeAll, afterAll } from "vitest"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; + +describe("StdioRunner", () => { + describe("client connects successfully", () => { + let client: StdioClientTransport; + beforeAll(async () => { + client = new StdioClientTransport({ + command: "node", + args: ["dist/index.js"], + env: { + MDB_MCP_TRANSPORT: "stdio", + }, + }); + await client.start(); + }); + + afterAll(async () => { + await client.close(); + }); + + it("handles requests and sends responses", async () => { + let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; + const messagePromise = new Promise((resolve) => { + fixedResolve = resolve; + }); + + client.onmessage = (message: JSONRPCMessage) => { + fixedResolve?.(message); + }; + + await client.send({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: { + _meta: { + progressToken: 1, + }, + }, + }); + + const message = (await messagePromise) as { + jsonrpc: string; + id: number; + result: { + tools: { + name: string; + description: string; + }[]; + }; + error?: { + code: number; + message: string; + }; + }; + + expect(message.jsonrpc).toBe("2.0"); + expect(message.id).toBe(1); + expect(message.result).toBeDefined(); + expect(message.result?.tools).toBeDefined(); + expect(message.result?.tools.length).toBeGreaterThan(0); + const tools = message.result?.tools; + tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(tools[0]?.name).toBe("aggregate"); + expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + }); + }); +}); diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts new file mode 100644 index 00000000..031e7798 --- /dev/null +++ b/tests/integration/transports/streamableHttp.test.ts @@ -0,0 +1,76 @@ +import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { describe, expect, it, beforeAll, afterAll } from "vitest"; + +describe("StreamableHttpRunner", () => { + let runner: StreamableHttpRunner; + + beforeAll(async () => { + runner = new StreamableHttpRunner(); + await runner.start(); + }); + + afterAll(async () => { + await runner.close(); + }); + + describe("client connects successfully", () => { + let client: StreamableHTTPClientTransport; + beforeAll(async () => { + client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); + await client.start(); + }); + + afterAll(async () => { + await client.close(); + }); + + it("handles requests and sends responses", async () => { + let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; + const messagePromise = new Promise((resolve) => { + fixedResolve = resolve; + }); + + client.onmessage = (message: JSONRPCMessage) => { + fixedResolve?.(message); + }; + + await client.send({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: { + _meta: { + progressToken: 1, + }, + }, + }); + + const message = (await messagePromise) as { + jsonrpc: string; + id: number; + result: { + tools: { + name: string; + description: string; + }[]; + }; + error?: { + code: number; + message: string; + }; + }; + + expect(message.jsonrpc).toBe("2.0"); + expect(message.id).toBe(1); + expect(message.result).toBeDefined(); + expect(message.result?.tools).toBeDefined(); + expect(message.result?.tools.length).toBeGreaterThan(0); + const tools = message.result?.tools; + tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(tools[0]?.name).toBe("aggregate"); + expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + }); + }); +}); diff --git a/tests/unit/transports/stdio.test.ts b/tests/unit/transports/stdio.test.ts index 2a1c62de..6a53f67b 100644 --- a/tests/unit/transports/stdio.test.ts +++ b/tests/unit/transports/stdio.test.ts @@ -6,7 +6,6 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js" import { Readable } from "stream"; import { ReadBuffer } from "@modelcontextprotocol/sdk/shared/stdio.js"; import { describe, expect, it, beforeEach, afterEach } from "vitest"; - describe("stdioTransport", () => { let transport: StdioServerTransport; beforeEach(async () => { diff --git a/tests/unit/transports/streamableHttp.test.ts b/tests/unit/transports/streamableHttp.test.ts deleted file mode 100644 index 01eeb136..00000000 --- a/tests/unit/transports/streamableHttp.test.ts +++ /dev/null @@ -1,78 +0,0 @@ -import { createHttpTransport } from "../../../src/transports/streamableHttp.js"; -import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; -import { z } from "zod"; -import { describe, expect, it, beforeAll, afterAll } from "vitest"; -describe("streamableHttpTransport", () => { - let transport: StreamableHTTPServerTransport; - const mcpServer = new McpServer({ - name: "test", - version: "1.0.0", - }); - beforeAll(async () => { - transport = await createHttpTransport(); - mcpServer.registerTool( - "hello", - { - title: "Hello Tool", - description: "Say hello", - inputSchema: { name: z.string() }, - }, - ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }], - }) - ); - await mcpServer.connect(transport); - }); - - afterAll(async () => { - await mcpServer.close(); - }); - - describe("client connects successfully", () => { - let client: StreamableHTTPClientTransport; - beforeAll(async () => { - client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); - await client.start(); - }); - - afterAll(async () => { - await client.close(); - }); - - it("handles requests and sends responses", async () => { - client.onmessage = (message: JSONRPCMessage) => { - const messageResult = message as - | { - result?: { - tools: { - name: string; - description: string; - }[]; - }; - } - | undefined; - - expect(message.jsonrpc).toBe("2.0"); - expect(messageResult).toBeDefined(); - expect(messageResult?.result?.tools).toBeDefined(); - expect(messageResult?.result?.tools.length).toBe(1); - expect(messageResult?.result?.tools[0]?.name).toBe("hello"); - expect(messageResult?.result?.tools[0]?.description).toBe("Say hello"); - }; - - await client.send({ - jsonrpc: "2.0", - id: 1, - method: "tools/list", - params: { - _meta: { - progressToken: 1, - }, - }, - }); - }); - }); -}); From ac40e8a9faedf2d2fabe3a4af703cd19688ac28b Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Fri, 18 Jul 2025 16:55:25 +0100 Subject: [PATCH 4/7] feat: add session management for streamableHttp [MCP-52] (#379) --- .vscode/launch.json | 1 + README.md | 28 +-- package.json | 1 + src/common/logger.ts | 9 +- src/common/sessionStore.ts | 48 +++++ src/index.ts | 9 +- src/transports/streamableHttp.ts | 168 ++++++++++++------ tests/integration/transports/stdio.test.ts | 64 ++----- .../transports/streamableHttp.test.ts | 76 +++----- 9 files changed, 229 insertions(+), 175 deletions(-) create mode 100644 src/common/sessionStore.ts diff --git a/.vscode/launch.json b/.vscode/launch.json index 0756e2d0..f8eaa53f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,6 +19,7 @@ "name": "Launch Program", "skipFiles": ["/**"], "program": "${workspaceFolder}/dist/index.js", + "args": ["--transport", "http", "--loggers", "stderr", "mcp"], "preLaunchTask": "tsc: build - tsconfig.build.json", "outFiles": ["${workspaceFolder}/dist/**/*.js"] } diff --git a/README.md b/README.md index fa36afc1..d845bf94 100644 --- a/README.md +++ b/README.md @@ -299,20 +299,20 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow ### Configuration Options -| Option | Default | Description | -| ------------------ | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `apiClientId` | | Atlas API client ID for authentication. Required for running Atlas tools. | -| `apiClientSecret` | | Atlas API client secret for authentication. Required for running Atlas tools. | -| `connectionString` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | -| `loggers` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | -| `logPath` | see note\* | Folder to store logs. | -| `disabledTools` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | -| `indexCheck` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | -| `telemetry` | enabled | When set to disabled, disables telemetry collection. | -| `transport` | stdio | Either 'stdio' or 'http'. | -| `httpPort` | 3000 | Port number. | -| `httpHost` | 127.0.0.1 | Host to bind the http server. | +| CLI Option | Environment Variable | Default | Description | +| ------------------ | --------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | +| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | +| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | +| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | +| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | +| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | +| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | +| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | +| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | +| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | +| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | +| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | #### Logger Options diff --git a/package.json b/package.json index 3606f7c9..6668cc87 100644 --- a/package.json +++ b/package.json @@ -16,6 +16,7 @@ }, "type": "module", "scripts": { + "start": "node dist/index.js --transport http", "prepare": "npm run build", "build:clean": "rm -rf dist", "build:compile": "tsc --project tsconfig.build.json", diff --git a/src/common/logger.ts b/src/common/logger.ts index 8f6069a0..faa5507a 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -40,12 +40,9 @@ export const LogId = { toolUpdateFailure: mongoLogId(1_005_001), streamableHttpTransportStarted: mongoLogId(1_006_001), - streamableHttpTransportStartFailure: mongoLogId(1_006_002), - streamableHttpTransportSessionInitialized: mongoLogId(1_006_003), - streamableHttpTransportRequestFailure: mongoLogId(1_006_004), - streamableHttpTransportCloseRequested: mongoLogId(1_006_005), - streamableHttpTransportCloseSuccess: mongoLogId(1_006_006), - streamableHttpTransportCloseFailure: mongoLogId(1_006_007), + streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002), + streamableHttpTransportRequestFailure: mongoLogId(1_006_003), + streamableHttpTransportCloseFailure: mongoLogId(1_006_004), } as const; export abstract class LoggerBase { diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts new file mode 100644 index 00000000..9159f633 --- /dev/null +++ b/src/common/sessionStore.ts @@ -0,0 +1,48 @@ +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import logger, { LogId } from "./logger.js"; + +export class SessionStore { + private sessions: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + + getSession(sessionId: string): StreamableHTTPServerTransport | undefined { + return this.sessions[sessionId]; + } + + setSession(sessionId: string, transport: StreamableHTTPServerTransport): void { + if (this.sessions[sessionId]) { + throw new Error(`Session ${sessionId} already exists`); + } + this.sessions[sessionId] = transport; + } + + async closeSession(sessionId: string, closeTransport: boolean = true): Promise { + if (!this.sessions[sessionId]) { + throw new Error(`Session ${sessionId} not found`); + } + if (closeTransport) { + const transport = this.sessions[sessionId]; + if (!transport) { + throw new Error(`Session ${sessionId} not found`); + } + try { + await transport.close(); + } catch (error) { + logger.error( + LogId.streamableHttpTransportSessionCloseFailure, + "streamableHttpTransport", + `Error closing transport ${sessionId}: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + delete this.sessions[sessionId]; + } + + async closeAllSessions(): Promise { + await Promise.all( + Object.values(this.sessions) + .filter((transport) => transport !== undefined) + .map((transport) => transport.close()) + ); + this.sessions = {}; + } +} diff --git a/src/index.ts b/src/index.ts index 73457dd6..4f81c0cd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,6 +14,7 @@ async function main() { transportRunner .close() .then(() => { + logger.info(LogId.serverClosed, "server", `Server closed`); process.exit(0); }) .catch((error: unknown) => { @@ -22,10 +23,10 @@ async function main() { }); }; - process.once("SIGINT", shutdown); - process.once("SIGABRT", shutdown); - process.once("SIGTERM", shutdown); - process.once("SIGQUIT", shutdown); + process.on("SIGINT", shutdown); + process.on("SIGABRT", shutdown); + process.on("SIGTERM", shutdown); + process.on("SIGQUIT", shutdown); try { await transportRunner.start(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index e15af8d5..fbe01a55 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,90 +1,143 @@ import express from "express"; import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; import { config } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; +import { randomUUID } from "crypto"; +import { SessionStore } from "../common/sessionStore.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; -const JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED = -32601; +const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; +const JSON_RPC_ERROR_CODE_SESSION_ID_INVALID = -32002; +const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32003; +const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32004; function promiseHandler( fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise ) { return (req: express.Request, res: express.Response, next: express.NextFunction) => { - fn(req, res, next).catch(next); + fn(req, res, next).catch((error) => { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + }); }; } export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; + private sessionStore: SessionStore = new SessionStore(); async start() { const app = express(); app.enable("trust proxy"); // needed for reverse proxy support - app.use(express.urlencoded({ extended: true })); app.use(express.json()); + const handleRequest = async (req: express.Request, res: express.Response) => { + const sessionId = req.headers["mcp-session-id"]; + if (!sessionId) { + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED, + message: `session id is required`, + }, + }); + return; + } + if (typeof sessionId !== "string") { + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_SESSION_ID_INVALID, + message: `session id is invalid`, + }, + }); + return; + } + const transport = this.sessionStore.getSession(sessionId); + if (!transport) { + res.status(404).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND, + message: `session not found`, + }, + }); + return; + } + await transport.handleRequest(req, res, req.body); + }; + app.post( "/mcp", promiseHandler(async (req: express.Request, res: express.Response) => { - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); + const sessionId = req.headers["mcp-session-id"]; + if (sessionId) { + await handleRequest(req, res); + return; + } - const server = this.setupServer(); + if (!isInitializeRequest(req.body)) { + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_INVALID_REQUEST, + message: `invalid request`, + }, + }); + return; + } - await server.connect(transport); + const server = this.setupServer(); + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID().toString(), + onsessioninitialized: (sessionId) => { + this.sessionStore.setSession(sessionId, transport); + }, + onsessionclosed: async (sessionId) => { + try { + await this.sessionStore.closeSession(sessionId, false); + } catch (error) { + logger.error( + LogId.streamableHttpTransportSessionCloseFailure, + "streamableHttpTransport", + `Error closing session: ${error instanceof Error ? error.message : String(error)}` + ); + } + }, + }); - res.on("close", () => { - Promise.all([transport.close(), server.close()]).catch((error: unknown) => { + transport.onclose = () => { + server.close().catch((error) => { logger.error( LogId.streamableHttpTransportCloseFailure, "streamableHttpTransport", `Error closing server: ${error instanceof Error ? error.message : String(error)}` ); }); - }); + }; - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), - }, - }); - } + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); }) ); - app.get("/mcp", (req: express.Request, res: express.Response) => { - res.status(405).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, - message: `method not allowed`, - }, - }); - }); - - app.delete("/mcp", (req: express.Request, res: express.Response) => { - res.status(405).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, - message: `method not allowed`, - }, - }); - }); + app.get("/mcp", promiseHandler(handleRequest)); + app.delete("/mcp", promiseHandler(handleRequest)); this.httpServer = await new Promise((resolve, reject) => { const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { @@ -104,14 +157,17 @@ export class StreamableHttpRunner extends TransportRunnerBase { } async close(): Promise { - await new Promise((resolve, reject) => { - this.httpServer?.close((err) => { - if (err) { - reject(err); - return; - } - resolve(); - }); - }); + await Promise.all([ + this.sessionStore.closeAllSessions(), + new Promise((resolve, reject) => { + this.httpServer?.close((err) => { + if (err) { + reject(err); + return; + } + resolve(); + }); + }), + ]); } } diff --git a/tests/integration/transports/stdio.test.ts b/tests/integration/transports/stdio.test.ts index afbcce00..2bc03b5b 100644 --- a/tests/integration/transports/stdio.test.ts +++ b/tests/integration/transports/stdio.test.ts @@ -1,70 +1,40 @@ -import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { describe, expect, it, beforeAll, afterAll } from "vitest"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; describe("StdioRunner", () => { describe("client connects successfully", () => { - let client: StdioClientTransport; + let client: Client; + let transport: StdioClientTransport; beforeAll(async () => { - client = new StdioClientTransport({ + transport = new StdioClientTransport({ command: "node", args: ["dist/index.js"], env: { MDB_MCP_TRANSPORT: "stdio", }, }); - await client.start(); + client = new Client({ + name: "test", + version: "0.0.0", + }); + await client.connect(transport); }); afterAll(async () => { await client.close(); + await transport.close(); }); it("handles requests and sends responses", async () => { - let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; - const messagePromise = new Promise((resolve) => { - fixedResolve = resolve; - }); - - client.onmessage = (message: JSONRPCMessage) => { - fixedResolve?.(message); - }; - - await client.send({ - jsonrpc: "2.0", - id: 1, - method: "tools/list", - params: { - _meta: { - progressToken: 1, - }, - }, - }); - - const message = (await messagePromise) as { - jsonrpc: string; - id: number; - result: { - tools: { - name: string; - description: string; - }[]; - }; - error?: { - code: number; - message: string; - }; - }; + const response = await client.listTools(); + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + expect(response.tools).toHaveLength(20); - expect(message.jsonrpc).toBe("2.0"); - expect(message.id).toBe(1); - expect(message.result).toBeDefined(); - expect(message.result?.tools).toBeDefined(); - expect(message.result?.tools.length).toBeGreaterThan(0); - const tools = message.result?.tools; - tools.sort((a, b) => a.name.localeCompare(b.name)); - expect(tools[0]?.name).toBe("aggregate"); - expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(sortedTools[0]?.name).toBe("aggregate"); + expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); }); }); }); diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index 031e7798..c295705e 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -1,76 +1,56 @@ import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { describe, expect, it, beforeAll, afterAll } from "vitest"; +import { config } from "../../../src/common/config.js"; describe("StreamableHttpRunner", () => { let runner: StreamableHttpRunner; + let oldTelemetry: "enabled" | "disabled"; + let oldLoggers: ("stderr" | "disk" | "mcp")[]; beforeAll(async () => { + oldTelemetry = config.telemetry; + oldLoggers = config.loggers; + config.telemetry = "disabled"; + config.loggers = ["stderr"]; runner = new StreamableHttpRunner(); await runner.start(); }); afterAll(async () => { await runner.close(); + config.telemetry = oldTelemetry; + config.loggers = oldLoggers; }); describe("client connects successfully", () => { - let client: StreamableHTTPClientTransport; + let client: Client; + let transport: StreamableHTTPClientTransport; beforeAll(async () => { - client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); - await client.start(); + transport = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); + + client = new Client({ + name: "test", + version: "0.0.0", + }); + await client.connect(transport); }); afterAll(async () => { await client.close(); + await transport.close(); }); it("handles requests and sends responses", async () => { - let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; - const messagePromise = new Promise((resolve) => { - fixedResolve = resolve; - }); - - client.onmessage = (message: JSONRPCMessage) => { - fixedResolve?.(message); - }; - - await client.send({ - jsonrpc: "2.0", - id: 1, - method: "tools/list", - params: { - _meta: { - progressToken: 1, - }, - }, - }); - - const message = (await messagePromise) as { - jsonrpc: string; - id: number; - result: { - tools: { - name: string; - description: string; - }[]; - }; - error?: { - code: number; - message: string; - }; - }; - - expect(message.jsonrpc).toBe("2.0"); - expect(message.id).toBe(1); - expect(message.result).toBeDefined(); - expect(message.result?.tools).toBeDefined(); - expect(message.result?.tools.length).toBeGreaterThan(0); - const tools = message.result?.tools; - tools.sort((a, b) => a.name.localeCompare(b.name)); - expect(tools[0]?.name).toBe("aggregate"); - expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + const response = await client.listTools(); + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + expect(response.tools.length).toBeGreaterThan(0); + + const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(sortedTools[0]?.name).toBe("aggregate"); + expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); }); }); }); From 427663f8fd5dde437e57cbab68c6f6deae56c825 Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Fri, 18 Jul 2025 17:11:12 +0100 Subject: [PATCH 5/7] chore: add transport to telemetry [MCP-62] (#382) --- src/telemetry/telemetry.ts | 1 + src/telemetry/types.ts | 1 + 2 files changed, 2 insertions(+) diff --git a/src/telemetry/telemetry.ts b/src/telemetry/telemetry.ts index 80385843..eb759edc 100644 --- a/src/telemetry/telemetry.ts +++ b/src/telemetry/telemetry.ts @@ -116,6 +116,7 @@ export class Telemetry { public getCommonProperties(): CommonProperties { return { ...this.commonProperties, + transport: this.userConfig.transport, mcp_client_version: this.session.agentRunner?.version, mcp_client_name: this.session.agentRunner?.name, session_id: this.session.sessionId, diff --git a/src/telemetry/types.ts b/src/telemetry/types.ts index 862441fd..f919ab88 100644 --- a/src/telemetry/types.ts +++ b/src/telemetry/types.ts @@ -69,6 +69,7 @@ export type CommonProperties = { is_container_env?: boolean; mcp_client_version?: string; mcp_client_name?: string; + transport?: "stdio" | "http"; config_atlas_auth?: TelemetryBoolSet; config_connection_string?: TelemetryBoolSet; session_id?: string; From 1297ed957ea0ee950b330979547b0f6a7b195e4d Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Mon, 21 Jul 2025 11:58:21 +0100 Subject: [PATCH 6/7] feat: add client idle timeout [MCP-57] (#383) --- .vscode/launch.json | 4 +- README.md | 30 ++++--- package.json | 2 +- src/common/config.ts | 4 + src/common/logger.ts | 5 +- src/common/sessionStore.ts | 88 +++++++++++++++---- src/common/timeoutManager.ts | 63 +++++++++++++ src/index.ts | 2 +- src/transports/base.ts | 14 +-- src/transports/stdio.ts | 7 +- src/transports/streamableHttp.ts | 17 ++-- .../transports/streamableHttp.test.ts | 2 +- tests/unit/common/timeoutManager.test.ts | 79 +++++++++++++++++ 13 files changed, 266 insertions(+), 51 deletions(-) create mode 100644 src/common/timeoutManager.ts create mode 100644 tests/unit/common/timeoutManager.test.ts diff --git a/.vscode/launch.json b/.vscode/launch.json index f8eaa53f..8eec7d6e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -18,8 +18,8 @@ "request": "launch", "name": "Launch Program", "skipFiles": ["/**"], - "program": "${workspaceFolder}/dist/index.js", - "args": ["--transport", "http", "--loggers", "stderr", "mcp"], + "runtimeExecutable": "npm", + "runtimeArgs": ["start"], "preLaunchTask": "tsc: build - tsconfig.build.json", "outFiles": ["${workspaceFolder}/dist/**/*.js"] } diff --git a/README.md b/README.md index 78169a00..6a91e158 100644 --- a/README.md +++ b/README.md @@ -302,20 +302,22 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow ### Configuration Options -| CLI Option | Environment Variable | Default | Description | -| ------------------ | --------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | -| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | -| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | -| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | -| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | -| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | -| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | -| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | -| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | -| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | -| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | +| CLI Option | Environment Variable | Default | Description | +| ----------------------- | --------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | +| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | +| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | +| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | +| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | +| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | +| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | +| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | +| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | +| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | +| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | +| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | +| `idleTimeoutMs` | `MDB_MCP_IDLE_TIMEOUT_MS` | 600000 | Idle timeout for a client to disconnect (only applies to http transport). | +| `notificationTimeoutMs` | `MDB_MCP_NOTIFICATION_TIMEOUT_MS` | 540000 | Notification timeout for a client to be aware of diconnect (only applies to http transport). | #### Logger Options diff --git a/package.json b/package.json index cafa5e9b..b58c1191 100644 --- a/package.json +++ b/package.json @@ -16,7 +16,7 @@ }, "type": "module", "scripts": { - "start": "node dist/index.js --transport http", + "start": "node dist/index.js --transport http --loggers stderr mcp", "prepare": "npm run build", "build:clean": "rm -rf dist", "build:compile": "tsc --project tsconfig.build.json", diff --git a/src/common/config.ts b/src/common/config.ts index 98c13cfc..3406a440 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -28,6 +28,8 @@ export interface UserConfig { httpPort: number; httpHost: string; loggers: Array<"stderr" | "disk" | "mcp">; + idleTimeoutMs: number; + notificationTimeoutMs: number; } const defaults: UserConfig = { @@ -47,6 +49,8 @@ const defaults: UserConfig = { httpPort: 3000, httpHost: "127.0.0.1", loggers: ["disk", "mcp"], + idleTimeoutMs: 600000, // 10 minutes + notificationTimeoutMs: 540000, // 9 minutes }; export const config = { diff --git a/src/common/logger.ts b/src/common/logger.ts index 259d173e..7ed1a9ac 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -43,8 +43,9 @@ export const LogId = { streamableHttpTransportStarted: mongoLogId(1_006_001), streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002), - streamableHttpTransportRequestFailure: mongoLogId(1_006_003), - streamableHttpTransportCloseFailure: mongoLogId(1_006_004), + streamableHttpTransportSessionCloseNotification: mongoLogId(1_006_003), + streamableHttpTransportRequestFailure: mongoLogId(1_006_004), + streamableHttpTransportCloseFailure: mongoLogId(1_006_005), } as const; export abstract class LoggerBase { diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts index 9159f633..9ad9d9bb 100644 --- a/src/common/sessionStore.ts +++ b/src/common/sessionStore.ts @@ -1,31 +1,92 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import logger, { LogId } from "./logger.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import logger, { LogId, McpLogger } from "./logger.js"; +import { TimeoutManager } from "./timeoutManager.js"; export class SessionStore { - private sessions: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + private sessions: { + [sessionId: string]: { + mcpServer: McpServer; + transport: StreamableHTTPServerTransport; + abortTimeout: TimeoutManager; + notificationTimeout: TimeoutManager; + }; + } = {}; + + constructor( + private readonly idleTimeoutMS: number, + private readonly notificationTimeoutMS: number + ) { + if (idleTimeoutMS <= 0) { + throw new Error("idleTimeoutMS must be greater than 0"); + } + if (notificationTimeoutMS <= 0) { + throw new Error("notificationTimeoutMS must be greater than 0"); + } + if (idleTimeoutMS <= notificationTimeoutMS) { + throw new Error("idleTimeoutMS must be greater than notificationTimeoutMS"); + } + } getSession(sessionId: string): StreamableHTTPServerTransport | undefined { - return this.sessions[sessionId]; + this.resetTimeout(sessionId); + return this.sessions[sessionId]?.transport; + } + + private resetTimeout(sessionId: string): void { + const session = this.sessions[sessionId]; + if (!session) { + return; + } + + session.abortTimeout.reset(); + + session.notificationTimeout.reset(); + } + + private sendNotification(sessionId: string): void { + const session = this.sessions[sessionId]; + if (!session) { + return; + } + const logger = new McpLogger(session.mcpServer); + logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session is about to be closed due to inactivity" + ); } - setSession(sessionId: string, transport: StreamableHTTPServerTransport): void { + setSession(sessionId: string, transport: StreamableHTTPServerTransport, mcpServer: McpServer): void { if (this.sessions[sessionId]) { throw new Error(`Session ${sessionId} already exists`); } - this.sessions[sessionId] = transport; + const abortTimeout = new TimeoutManager(async () => { + const logger = new McpLogger(mcpServer); + logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session closed due to inactivity" + ); + + await this.closeSession(sessionId); + }, this.idleTimeoutMS); + const notificationTimeout = new TimeoutManager( + () => this.sendNotification(sessionId), + this.notificationTimeoutMS + ); + this.sessions[sessionId] = { mcpServer, transport, abortTimeout, notificationTimeout }; } async closeSession(sessionId: string, closeTransport: boolean = true): Promise { if (!this.sessions[sessionId]) { throw new Error(`Session ${sessionId} not found`); } + this.sessions[sessionId].abortTimeout.clear(); + this.sessions[sessionId].notificationTimeout.clear(); if (closeTransport) { - const transport = this.sessions[sessionId]; - if (!transport) { - throw new Error(`Session ${sessionId} not found`); - } try { - await transport.close(); + await this.sessions[sessionId].transport.close(); } catch (error) { logger.error( LogId.streamableHttpTransportSessionCloseFailure, @@ -38,11 +99,6 @@ export class SessionStore { } async closeAllSessions(): Promise { - await Promise.all( - Object.values(this.sessions) - .filter((transport) => transport !== undefined) - .map((transport) => transport.close()) - ); - this.sessions = {}; + await Promise.all(Object.keys(this.sessions).map((sessionId) => this.closeSession(sessionId))); } } diff --git a/src/common/timeoutManager.ts b/src/common/timeoutManager.ts new file mode 100644 index 00000000..03161dfc --- /dev/null +++ b/src/common/timeoutManager.ts @@ -0,0 +1,63 @@ +/** + * A class that manages timeouts for a callback function. + * It is used to ensure that a callback function is called after a certain amount of time. + * If the callback function is not called after the timeout, it will be called with an error. + */ +export class TimeoutManager { + private timeoutId?: NodeJS.Timeout; + + /** + * A callback function that is called when the timeout is reached. + */ + public onerror?: (error: unknown) => void; + + /** + * Creates a new TimeoutManager. + * @param callback - A callback function that is called when the timeout is reached. + * @param timeoutMS - The timeout in milliseconds. + */ + constructor( + private readonly callback: () => Promise | void, + private readonly timeoutMS: number + ) { + if (timeoutMS <= 0) { + throw new Error("timeoutMS must be greater than 0"); + } + this.reset(); + } + + /** + * Clears the timeout. + */ + clear() { + if (this.timeoutId) { + clearTimeout(this.timeoutId); + this.timeoutId = undefined; + } + } + + /** + * Runs the callback function. + */ + private async runCallback() { + if (this.callback) { + try { + await this.callback(); + } catch (error: unknown) { + this.onerror?.(error); + } + } + } + + /** + * Resets the timeout. + */ + reset() { + this.clear(); + this.timeoutId = setTimeout(() => { + void this.runCallback().finally(() => { + this.timeoutId = undefined; + }); + }, this.timeoutMS); + } +} diff --git a/src/index.ts b/src/index.ts index 4f81c0cd..fca9b83f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,7 +6,7 @@ import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; async function main() { - const transportRunner = config.transport === "stdio" ? new StdioRunner() : new StreamableHttpRunner(); + const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); const shutdown = () => { logger.info(LogId.serverCloseRequested, "server", `Server close requested`); diff --git a/src/transports/base.ts b/src/transports/base.ts index 442db18a..cc58f750 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -1,4 +1,4 @@ -import { config } from "../common/config.js"; +import { UserConfig } from "../common/config.js"; import { packageInfo } from "../common/packageInfo.js"; import { Server } from "../server.js"; import { Session } from "../common/session.js"; @@ -6,14 +6,14 @@ import { Telemetry } from "../telemetry/telemetry.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; export abstract class TransportRunnerBase { - protected setupServer(): Server { + protected setupServer(userConfig: UserConfig): Server { const session = new Session({ - apiBaseUrl: config.apiBaseUrl, - apiClientId: config.apiClientId, - apiClientSecret: config.apiClientSecret, + apiBaseUrl: userConfig.apiBaseUrl, + apiClientId: userConfig.apiClientId, + apiClientSecret: userConfig.apiClientSecret, }); - const telemetry = Telemetry.create(session, config); + const telemetry = Telemetry.create(session, userConfig); const mcpServer = new McpServer({ name: packageInfo.mcpServerName, @@ -24,7 +24,7 @@ export abstract class TransportRunnerBase { mcpServer, session, telemetry, - userConfig: config, + userConfig, }); } diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 9f18627c..870ec73c 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -4,6 +4,7 @@ import { TransportRunnerBase } from "./base.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { UserConfig } from "../common/config.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -52,9 +53,13 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; + constructor(private userConfig: UserConfig) { + super(); + } + async start() { try { - this.server = this.setupServer(); + this.server = this.setupServer(this.userConfig); const transport = createStdioTransport(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index fbe01a55..282cd7bc 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -3,7 +3,7 @@ import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; -import { config } from "../common/config.js"; +import { UserConfig } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; import { randomUUID } from "crypto"; import { SessionStore } from "../common/sessionStore.js"; @@ -38,7 +38,12 @@ function promiseHandler( export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; - private sessionStore: SessionStore = new SessionStore(); + private sessionStore: SessionStore; + + constructor(private userConfig: UserConfig) { + super(); + this.sessionStore = new SessionStore(this.userConfig.idleTimeoutMs, this.userConfig.notificationTimeoutMs); + } async start() { const app = express(); @@ -101,11 +106,11 @@ export class StreamableHttpRunner extends TransportRunnerBase { return; } - const server = this.setupServer(); + const server = this.setupServer(this.userConfig); const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID().toString(), onsessioninitialized: (sessionId) => { - this.sessionStore.setSession(sessionId, transport); + this.sessionStore.setSession(sessionId, transport, server.mcpServer); }, onsessionclosed: async (sessionId) => { try { @@ -140,7 +145,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { app.delete("/mcp", promiseHandler(handleRequest)); this.httpServer = await new Promise((resolve, reject) => { - const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { + const result = app.listen(this.userConfig.httpPort, this.userConfig.httpHost, (err?: Error) => { if (err) { reject(err); return; @@ -152,7 +157,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { logger.info( LogId.streamableHttpTransportStarted, "streamableHttpTransport", - `Server started on http://${config.httpHost}:${config.httpPort}` + `Server started on http://${this.userConfig.httpHost}:${this.userConfig.httpPort}` ); } diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index c295705e..d5b6e0be 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -14,7 +14,7 @@ describe("StreamableHttpRunner", () => { oldLoggers = config.loggers; config.telemetry = "disabled"; config.loggers = ["stderr"]; - runner = new StreamableHttpRunner(); + runner = new StreamableHttpRunner(config); await runner.start(); }); diff --git a/tests/unit/common/timeoutManager.test.ts b/tests/unit/common/timeoutManager.test.ts new file mode 100644 index 00000000..a0cc5b30 --- /dev/null +++ b/tests/unit/common/timeoutManager.test.ts @@ -0,0 +1,79 @@ +import { afterAll, beforeAll, describe, expect, it, vi } from "vitest"; +import { TimeoutManager } from "../../../src/common/timeoutManager.js"; + +describe("TimeoutManager", () => { + beforeAll(() => { + vi.useFakeTimers(); + }); + + afterAll(() => { + vi.useRealTimers(); + }); + + it("calls the timeout callback", () => { + const callback = vi.fn(); + + new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(1000); + expect(callback).toHaveBeenCalled(); + }); + + it("does not call the timeout callback if the timeout is cleared", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.clear(); + vi.advanceTimersByTime(500); + + expect(callback).not.toHaveBeenCalled(); + }); + + it("does not call the timeout callback if the timeout is reset", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.reset(); + vi.advanceTimersByTime(500); + expect(callback).not.toHaveBeenCalled(); + }); + + it("calls the onerror callback", () => { + const onerrorCallback = vi.fn(); + + const tm = new TimeoutManager(() => { + throw new Error("test"); + }, 1000); + tm.onerror = onerrorCallback; + + vi.advanceTimersByTime(1000); + expect(onerrorCallback).toHaveBeenCalled(); + }); + + describe("if timeout is reset", () => { + it("does not call the timeout callback within the timeout period", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.reset(); + vi.advanceTimersByTime(500); + expect(callback).not.toHaveBeenCalled(); + }); + it("calls the timeout callback after the timeout period", () => { + const callback = vi.fn(); + + const timeoutManager = new TimeoutManager(callback, 1000); + + vi.advanceTimersByTime(500); + timeoutManager.reset(); + vi.advanceTimersByTime(1000); + expect(callback).toHaveBeenCalled(); + }); + }); +}); From 3ba8a4a71da6a1137b33404920d711aed2982214 Mon Sep 17 00:00:00 2001 From: Filipe Constantinov Menezes Date: Mon, 21 Jul 2025 17:40:54 +0100 Subject: [PATCH 7/7] chore: address comments from #361 (#386) --- src/common/logger.ts | 5 +- src/common/managedTimeout.ts | 27 ++++++++ src/common/sessionStore.ts | 55 +++++++++------- src/common/timeoutManager.ts | 63 ------------------- ...Manager.test.ts => managedTimeout.test.ts} | 34 ++++------ 5 files changed, 72 insertions(+), 112 deletions(-) create mode 100644 src/common/managedTimeout.ts delete mode 100644 src/common/timeoutManager.ts rename tests/unit/common/{timeoutManager.test.ts => managedTimeout.test.ts} (61%) diff --git a/src/common/logger.ts b/src/common/logger.ts index 7ed1a9ac..0c2fd726 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -44,8 +44,9 @@ export const LogId = { streamableHttpTransportStarted: mongoLogId(1_006_001), streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002), streamableHttpTransportSessionCloseNotification: mongoLogId(1_006_003), - streamableHttpTransportRequestFailure: mongoLogId(1_006_004), - streamableHttpTransportCloseFailure: mongoLogId(1_006_005), + streamableHttpTransportSessionCloseNotificationFailure: mongoLogId(1_006_004), + streamableHttpTransportRequestFailure: mongoLogId(1_006_005), + streamableHttpTransportCloseFailure: mongoLogId(1_006_006), } as const; export abstract class LoggerBase { diff --git a/src/common/managedTimeout.ts b/src/common/managedTimeout.ts new file mode 100644 index 00000000..9309947e --- /dev/null +++ b/src/common/managedTimeout.ts @@ -0,0 +1,27 @@ +export interface ManagedTimeout { + cancel: () => void; + restart: () => void; +} + +export function setManagedTimeout(callback: () => Promise | void, timeoutMS: number): ManagedTimeout { + let timeoutId: NodeJS.Timeout | undefined = setTimeout(() => { + void callback(); + }, timeoutMS); + + function cancel() { + clearTimeout(timeoutId); + timeoutId = undefined; + } + + function restart() { + cancel(); + timeoutId = setTimeout(() => { + void callback(); + }, timeoutMS); + } + + return { + cancel, + restart, + }; +} diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts index 9ad9d9bb..e37358fc 100644 --- a/src/common/sessionStore.ts +++ b/src/common/sessionStore.ts @@ -1,15 +1,15 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import logger, { LogId, McpLogger } from "./logger.js"; -import { TimeoutManager } from "./timeoutManager.js"; +import logger, { LogId, LoggerBase, McpLogger } from "./logger.js"; +import { ManagedTimeout, setManagedTimeout } from "./managedTimeout.js"; export class SessionStore { private sessions: { [sessionId: string]: { - mcpServer: McpServer; + logger: LoggerBase; transport: StreamableHTTPServerTransport; - abortTimeout: TimeoutManager; - notificationTimeout: TimeoutManager; + abortTimeout: ManagedTimeout; + notificationTimeout: ManagedTimeout; }; } = {}; @@ -39,18 +39,22 @@ export class SessionStore { return; } - session.abortTimeout.reset(); + session.abortTimeout.restart(); - session.notificationTimeout.reset(); + session.notificationTimeout.restart(); } private sendNotification(sessionId: string): void { const session = this.sessions[sessionId]; if (!session) { + logger.warning( + LogId.streamableHttpTransportSessionCloseNotificationFailure, + "sessionStore", + `session ${sessionId} not found, no notification delivered` + ); return; } - const logger = new McpLogger(session.mcpServer); - logger.info( + session.logger.info( LogId.streamableHttpTransportSessionCloseNotification, "sessionStore", "Session is about to be closed due to inactivity" @@ -58,35 +62,38 @@ export class SessionStore { } setSession(sessionId: string, transport: StreamableHTTPServerTransport, mcpServer: McpServer): void { - if (this.sessions[sessionId]) { + const session = this.sessions[sessionId]; + if (session) { throw new Error(`Session ${sessionId} already exists`); } - const abortTimeout = new TimeoutManager(async () => { - const logger = new McpLogger(mcpServer); - logger.info( - LogId.streamableHttpTransportSessionCloseNotification, - "sessionStore", - "Session closed due to inactivity" - ); + const abortTimeout = setManagedTimeout(async () => { + if (this.sessions[sessionId]) { + this.sessions[sessionId].logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session closed due to inactivity" + ); - await this.closeSession(sessionId); + await this.closeSession(sessionId); + } }, this.idleTimeoutMS); - const notificationTimeout = new TimeoutManager( + const notificationTimeout = setManagedTimeout( () => this.sendNotification(sessionId), this.notificationTimeoutMS ); - this.sessions[sessionId] = { mcpServer, transport, abortTimeout, notificationTimeout }; + this.sessions[sessionId] = { logger: new McpLogger(mcpServer), transport, abortTimeout, notificationTimeout }; } async closeSession(sessionId: string, closeTransport: boolean = true): Promise { - if (!this.sessions[sessionId]) { + const session = this.sessions[sessionId]; + if (!session) { throw new Error(`Session ${sessionId} not found`); } - this.sessions[sessionId].abortTimeout.clear(); - this.sessions[sessionId].notificationTimeout.clear(); + session.abortTimeout.cancel(); + session.notificationTimeout.cancel(); if (closeTransport) { try { - await this.sessions[sessionId].transport.close(); + await session.transport.close(); } catch (error) { logger.error( LogId.streamableHttpTransportSessionCloseFailure, diff --git a/src/common/timeoutManager.ts b/src/common/timeoutManager.ts deleted file mode 100644 index 03161dfc..00000000 --- a/src/common/timeoutManager.ts +++ /dev/null @@ -1,63 +0,0 @@ -/** - * A class that manages timeouts for a callback function. - * It is used to ensure that a callback function is called after a certain amount of time. - * If the callback function is not called after the timeout, it will be called with an error. - */ -export class TimeoutManager { - private timeoutId?: NodeJS.Timeout; - - /** - * A callback function that is called when the timeout is reached. - */ - public onerror?: (error: unknown) => void; - - /** - * Creates a new TimeoutManager. - * @param callback - A callback function that is called when the timeout is reached. - * @param timeoutMS - The timeout in milliseconds. - */ - constructor( - private readonly callback: () => Promise | void, - private readonly timeoutMS: number - ) { - if (timeoutMS <= 0) { - throw new Error("timeoutMS must be greater than 0"); - } - this.reset(); - } - - /** - * Clears the timeout. - */ - clear() { - if (this.timeoutId) { - clearTimeout(this.timeoutId); - this.timeoutId = undefined; - } - } - - /** - * Runs the callback function. - */ - private async runCallback() { - if (this.callback) { - try { - await this.callback(); - } catch (error: unknown) { - this.onerror?.(error); - } - } - } - - /** - * Resets the timeout. - */ - reset() { - this.clear(); - this.timeoutId = setTimeout(() => { - void this.runCallback().finally(() => { - this.timeoutId = undefined; - }); - }, this.timeoutMS); - } -} diff --git a/tests/unit/common/timeoutManager.test.ts b/tests/unit/common/managedTimeout.test.ts similarity index 61% rename from tests/unit/common/timeoutManager.test.ts rename to tests/unit/common/managedTimeout.test.ts index a0cc5b30..d51c4b13 100644 --- a/tests/unit/common/timeoutManager.test.ts +++ b/tests/unit/common/managedTimeout.test.ts @@ -1,7 +1,7 @@ import { afterAll, beforeAll, describe, expect, it, vi } from "vitest"; -import { TimeoutManager } from "../../../src/common/timeoutManager.js"; +import { setManagedTimeout } from "../../../src/common/managedTimeout.js"; -describe("TimeoutManager", () => { +describe("setManagedTimeout", () => { beforeAll(() => { vi.useFakeTimers(); }); @@ -13,7 +13,7 @@ describe("TimeoutManager", () => { it("calls the timeout callback", () => { const callback = vi.fn(); - new TimeoutManager(callback, 1000); + setManagedTimeout(callback, 1000); vi.advanceTimersByTime(1000); expect(callback).toHaveBeenCalled(); @@ -22,10 +22,10 @@ describe("TimeoutManager", () => { it("does not call the timeout callback if the timeout is cleared", () => { const callback = vi.fn(); - const timeoutManager = new TimeoutManager(callback, 1000); + const timeout = setManagedTimeout(callback, 1000); vi.advanceTimersByTime(500); - timeoutManager.clear(); + timeout.cancel(); vi.advanceTimersByTime(500); expect(callback).not.toHaveBeenCalled(); @@ -34,44 +34,32 @@ describe("TimeoutManager", () => { it("does not call the timeout callback if the timeout is reset", () => { const callback = vi.fn(); - const timeoutManager = new TimeoutManager(callback, 1000); + const timeout = setManagedTimeout(callback, 1000); vi.advanceTimersByTime(500); - timeoutManager.reset(); + timeout.restart(); vi.advanceTimersByTime(500); expect(callback).not.toHaveBeenCalled(); }); - it("calls the onerror callback", () => { - const onerrorCallback = vi.fn(); - - const tm = new TimeoutManager(() => { - throw new Error("test"); - }, 1000); - tm.onerror = onerrorCallback; - - vi.advanceTimersByTime(1000); - expect(onerrorCallback).toHaveBeenCalled(); - }); - describe("if timeout is reset", () => { it("does not call the timeout callback within the timeout period", () => { const callback = vi.fn(); - const timeoutManager = new TimeoutManager(callback, 1000); + const timeout = setManagedTimeout(callback, 1000); vi.advanceTimersByTime(500); - timeoutManager.reset(); + timeout.restart(); vi.advanceTimersByTime(500); expect(callback).not.toHaveBeenCalled(); }); it("calls the timeout callback after the timeout period", () => { const callback = vi.fn(); - const timeoutManager = new TimeoutManager(callback, 1000); + const timeout = setManagedTimeout(callback, 1000); vi.advanceTimersByTime(500); - timeoutManager.reset(); + timeout.restart(); vi.advanceTimersByTime(1000); expect(callback).toHaveBeenCalled(); });