diff --git a/.vscode/launch.json b/.vscode/launch.json index f8eaa53f..8eec7d6e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -18,8 +18,8 @@ "request": "launch", "name": "Launch Program", "skipFiles": ["/**"], - "program": "${workspaceFolder}/dist/index.js", - "args": ["--transport", "http", "--loggers", "stderr", "mcp"], + "runtimeExecutable": "npm", + "runtimeArgs": ["start"], "preLaunchTask": "tsc: build - tsconfig.build.json", "outFiles": ["${workspaceFolder}/dist/**/*.js"] } diff --git a/README.md b/README.md index 78169a00..6a91e158 100644 --- a/README.md +++ b/README.md @@ -302,20 +302,22 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow ### Configuration Options -| CLI Option | Environment Variable | Default | Description | -| ------------------ | --------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | -| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | -| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | -| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | -| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | -| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | -| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | -| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | -| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | -| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | -| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | +| CLI Option | Environment Variable | Default | Description | +| ----------------------- | --------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | +| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | +| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | +| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | +| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | +| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | +| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | +| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | +| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | +| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | +| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | +| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | +| `idleTimeoutMs` | `MDB_MCP_IDLE_TIMEOUT_MS` | 600000 | Idle timeout for a client to disconnect (only applies to http transport). | +| `notificationTimeoutMs` | `MDB_MCP_NOTIFICATION_TIMEOUT_MS` | 540000 | Notification timeout for a client to be aware of diconnect (only applies to http transport). | #### Logger Options diff --git a/package.json b/package.json index cafa5e9b..b58c1191 100644 --- a/package.json +++ b/package.json @@ -16,7 +16,7 @@ }, "type": "module", "scripts": { - "start": "node dist/index.js --transport http", + "start": "node dist/index.js --transport http --loggers stderr mcp", "prepare": "npm run build", "build:clean": "rm -rf dist", "build:compile": "tsc --project tsconfig.build.json", diff --git a/src/common/config.ts b/src/common/config.ts index 98c13cfc..3406a440 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -28,6 +28,8 @@ export interface UserConfig { httpPort: number; httpHost: string; loggers: Array<"stderr" | "disk" | "mcp">; + idleTimeoutMs: number; + notificationTimeoutMs: number; } const defaults: UserConfig = { @@ -47,6 +49,8 @@ const defaults: UserConfig = { httpPort: 3000, httpHost: "127.0.0.1", loggers: ["disk", "mcp"], + idleTimeoutMs: 600000, // 10 minutes + notificationTimeoutMs: 540000, // 9 minutes }; export const config = { diff --git a/src/common/logger.ts b/src/common/logger.ts index 259d173e..7ed1a9ac 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -43,8 +43,9 @@ export const LogId = { streamableHttpTransportStarted: mongoLogId(1_006_001), streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002), - streamableHttpTransportRequestFailure: mongoLogId(1_006_003), - streamableHttpTransportCloseFailure: mongoLogId(1_006_004), + streamableHttpTransportSessionCloseNotification: mongoLogId(1_006_003), + streamableHttpTransportRequestFailure: mongoLogId(1_006_004), + streamableHttpTransportCloseFailure: mongoLogId(1_006_005), } as const; export abstract class LoggerBase { diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts index 9159f633..643543a8 100644 --- a/src/common/sessionStore.ts +++ b/src/common/sessionStore.ts @@ -1,31 +1,132 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import logger, { LogId } from "./logger.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import logger, { LogId, McpLogger } from "./logger.js"; + +class TimeoutManager { + private timeoutId?: NodeJS.Timeout; + public onerror?: (error: unknown) => void; + + constructor( + private readonly callback: () => Promise | void, + private readonly timeoutMS: number + ) { + if (timeoutMS <= 0) { + throw new Error("timeoutMS must be greater than 0"); + } + this.reset(); + } + + clear() { + if (this.timeoutId) { + clearTimeout(this.timeoutId); + this.timeoutId = undefined; + } + } + + private async runCallback() { + if (this.callback) { + try { + await this.callback(); + } catch (error: unknown) { + this.onerror?.(error); + } + } + } + + reset() { + this.clear(); + this.timeoutId = setTimeout(() => { + void this.runCallback().finally(() => { + this.timeoutId = undefined; + }); + }, this.timeoutMS); + } +} export class SessionStore { - private sessions: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + private sessions: { + [sessionId: string]: { + mcpServer: McpServer; + transport: StreamableHTTPServerTransport; + abortTimeout: TimeoutManager; + notificationTimeout: TimeoutManager; + }; + } = {}; + + constructor( + private readonly idleTimeoutMS: number, + private readonly notificationTimeoutMS: number + ) { + if (idleTimeoutMS <= 0) { + throw new Error("idleTimeoutMS must be greater than 0"); + } + if (notificationTimeoutMS <= 0) { + throw new Error("notificationTimeoutMS must be greater than 0"); + } + if (idleTimeoutMS <= notificationTimeoutMS) { + throw new Error("idleTimeoutMS must be greater than notificationTimeoutMS"); + } + } getSession(sessionId: string): StreamableHTTPServerTransport | undefined { - return this.sessions[sessionId]; + this.resetTimeout(sessionId); + return this.sessions[sessionId]?.transport; + } + + private resetTimeout(sessionId: string): void { + const session = this.sessions[sessionId]; + if (!session) { + return; + } + + session.abortTimeout.reset(); + + session.notificationTimeout.reset(); } - setSession(sessionId: string, transport: StreamableHTTPServerTransport): void { + private sendNotification(sessionId: string): void { + const session = this.sessions[sessionId]; + if (!session) { + return; + } + const logger = new McpLogger(session.mcpServer); + logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session is about to be closed due to inactivity" + ); + } + + setSession(sessionId: string, transport: StreamableHTTPServerTransport, mcpServer: McpServer): void { if (this.sessions[sessionId]) { throw new Error(`Session ${sessionId} already exists`); } - this.sessions[sessionId] = transport; + const abortTimeout = new TimeoutManager(async () => { + const logger = new McpLogger(mcpServer); + logger.info( + LogId.streamableHttpTransportSessionCloseNotification, + "sessionStore", + "Session closed due to inactivity" + ); + + await this.closeSession(sessionId); + }, this.idleTimeoutMS); + const notificationTimeout = new TimeoutManager( + () => this.sendNotification(sessionId), + this.notificationTimeoutMS + ); + this.sessions[sessionId] = { mcpServer, transport, abortTimeout, notificationTimeout }; } async closeSession(sessionId: string, closeTransport: boolean = true): Promise { if (!this.sessions[sessionId]) { throw new Error(`Session ${sessionId} not found`); } + this.sessions[sessionId].abortTimeout.clear(); + this.sessions[sessionId].notificationTimeout.clear(); if (closeTransport) { - const transport = this.sessions[sessionId]; - if (!transport) { - throw new Error(`Session ${sessionId} not found`); - } try { - await transport.close(); + await this.sessions[sessionId].transport.close(); } catch (error) { logger.error( LogId.streamableHttpTransportSessionCloseFailure, @@ -38,11 +139,6 @@ export class SessionStore { } async closeAllSessions(): Promise { - await Promise.all( - Object.values(this.sessions) - .filter((transport) => transport !== undefined) - .map((transport) => transport.close()) - ); - this.sessions = {}; + await Promise.all(Object.keys(this.sessions).map((sessionId) => this.closeSession(sessionId))); } } diff --git a/src/index.ts b/src/index.ts index 4f81c0cd..fca9b83f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,7 +6,7 @@ import { StdioRunner } from "./transports/stdio.js"; import { StreamableHttpRunner } from "./transports/streamableHttp.js"; async function main() { - const transportRunner = config.transport === "stdio" ? new StdioRunner() : new StreamableHttpRunner(); + const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config); const shutdown = () => { logger.info(LogId.serverCloseRequested, "server", `Server close requested`); diff --git a/src/transports/base.ts b/src/transports/base.ts index 442db18a..cc58f750 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -1,4 +1,4 @@ -import { config } from "../common/config.js"; +import { UserConfig } from "../common/config.js"; import { packageInfo } from "../common/packageInfo.js"; import { Server } from "../server.js"; import { Session } from "../common/session.js"; @@ -6,14 +6,14 @@ import { Telemetry } from "../telemetry/telemetry.js"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; export abstract class TransportRunnerBase { - protected setupServer(): Server { + protected setupServer(userConfig: UserConfig): Server { const session = new Session({ - apiBaseUrl: config.apiBaseUrl, - apiClientId: config.apiClientId, - apiClientSecret: config.apiClientSecret, + apiBaseUrl: userConfig.apiBaseUrl, + apiClientId: userConfig.apiClientId, + apiClientSecret: userConfig.apiClientSecret, }); - const telemetry = Telemetry.create(session, config); + const telemetry = Telemetry.create(session, userConfig); const mcpServer = new McpServer({ name: packageInfo.mcpServerName, @@ -24,7 +24,7 @@ export abstract class TransportRunnerBase { mcpServer, session, telemetry, - userConfig: config, + userConfig, }); } diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 9f18627c..870ec73c 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -4,6 +4,7 @@ import { TransportRunnerBase } from "./base.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { UserConfig } from "../common/config.js"; // This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk // but it uses EJSON.parse instead of JSON.parse to handle BSON types @@ -52,9 +53,13 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; + constructor(private userConfig: UserConfig) { + super(); + } + async start() { try { - this.server = this.setupServer(); + this.server = this.setupServer(this.userConfig); const transport = createStdioTransport(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index fbe01a55..282cd7bc 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -3,7 +3,7 @@ import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; -import { config } from "../common/config.js"; +import { UserConfig } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; import { randomUUID } from "crypto"; import { SessionStore } from "../common/sessionStore.js"; @@ -38,7 +38,12 @@ function promiseHandler( export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; - private sessionStore: SessionStore = new SessionStore(); + private sessionStore: SessionStore; + + constructor(private userConfig: UserConfig) { + super(); + this.sessionStore = new SessionStore(this.userConfig.idleTimeoutMs, this.userConfig.notificationTimeoutMs); + } async start() { const app = express(); @@ -101,11 +106,11 @@ export class StreamableHttpRunner extends TransportRunnerBase { return; } - const server = this.setupServer(); + const server = this.setupServer(this.userConfig); const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID().toString(), onsessioninitialized: (sessionId) => { - this.sessionStore.setSession(sessionId, transport); + this.sessionStore.setSession(sessionId, transport, server.mcpServer); }, onsessionclosed: async (sessionId) => { try { @@ -140,7 +145,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { app.delete("/mcp", promiseHandler(handleRequest)); this.httpServer = await new Promise((resolve, reject) => { - const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { + const result = app.listen(this.userConfig.httpPort, this.userConfig.httpHost, (err?: Error) => { if (err) { reject(err); return; @@ -152,7 +157,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { logger.info( LogId.streamableHttpTransportStarted, "streamableHttpTransport", - `Server started on http://${config.httpHost}:${config.httpPort}` + `Server started on http://${this.userConfig.httpHost}:${this.userConfig.httpPort}` ); } diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index c295705e..d5b6e0be 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -14,7 +14,7 @@ describe("StreamableHttpRunner", () => { oldLoggers = config.loggers; config.telemetry = "disabled"; config.loggers = ["stderr"]; - runner = new StreamableHttpRunner(); + runner = new StreamableHttpRunner(config); await runner.start(); });