diff --git a/.vscode/launch.json b/.vscode/launch.json index 0756e2d0..f8eaa53f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -19,6 +19,7 @@ "name": "Launch Program", "skipFiles": ["/**"], "program": "${workspaceFolder}/dist/index.js", + "args": ["--transport", "http", "--loggers", "stderr", "mcp"], "preLaunchTask": "tsc: build - tsconfig.build.json", "outFiles": ["${workspaceFolder}/dist/**/*.js"] } diff --git a/README.md b/README.md index fa36afc1..d845bf94 100644 --- a/README.md +++ b/README.md @@ -299,20 +299,20 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow ### Configuration Options -| Option | Default | Description | -| ------------------ | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `apiClientId` | | Atlas API client ID for authentication. Required for running Atlas tools. | -| `apiClientSecret` | | Atlas API client secret for authentication. Required for running Atlas tools. | -| `connectionString` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | -| `loggers` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | -| `logPath` | see note\* | Folder to store logs. | -| `disabledTools` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | -| `indexCheck` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | -| `telemetry` | enabled | When set to disabled, disables telemetry collection. | -| `transport` | stdio | Either 'stdio' or 'http'. | -| `httpPort` | 3000 | Port number. | -| `httpHost` | 127.0.0.1 | Host to bind the http server. | +| CLI Option | Environment Variable | Default | Description | +| ------------------ | --------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `apiClientId` | `MDB_MCP_API_CLIENT_ID` | | Atlas API client ID for authentication. Required for running Atlas tools. | +| `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | | Atlas API client secret for authentication. Required for running Atlas tools. | +| `connectionString` | `MDB_MCP_CONNECTION_STRING` | | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | +| `loggers` | `MDB_MCP_LOGGERS` | disk,mcp | Comma separated values, possible values are `mcp`, `disk` and `stderr`. See [Logger Options](#logger-options) for details. | +| `logPath` | `MDB_MCP_LOG_PATH` | see note\* | Folder to store logs. | +| `disabledTools` | `MDB_MCP_DISABLED_TOOLS` | | An array of tool names, operation types, and/or categories of tools that will be disabled. | +| `readOnly` | `MDB_MCP_READ_ONLY` | false | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | +| `indexCheck` | `MDB_MCP_INDEX_CHECK` | false | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | +| `telemetry` | `MDB_MCP_TELEMETRY` | enabled | When set to disabled, disables telemetry collection. | +| `transport` | `MDB_MCP_TRANSPORT` | stdio | Either 'stdio' or 'http'. | +| `httpPort` | `MDB_MCP_HTTP_PORT` | 3000 | Port number. | +| `httpHost` | `MDB_MCP_HTTP_HOST` | 127.0.0.1 | Host to bind the http server. | #### Logger Options diff --git a/package.json b/package.json index 3606f7c9..6668cc87 100644 --- a/package.json +++ b/package.json @@ -16,6 +16,7 @@ }, "type": "module", "scripts": { + "start": "node dist/index.js --transport http", "prepare": "npm run build", "build:clean": "rm -rf dist", "build:compile": "tsc --project tsconfig.build.json", diff --git a/src/common/logger.ts b/src/common/logger.ts index 8f6069a0..faa5507a 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -40,12 +40,9 @@ export const LogId = { toolUpdateFailure: mongoLogId(1_005_001), streamableHttpTransportStarted: mongoLogId(1_006_001), - streamableHttpTransportStartFailure: mongoLogId(1_006_002), - streamableHttpTransportSessionInitialized: mongoLogId(1_006_003), - streamableHttpTransportRequestFailure: mongoLogId(1_006_004), - streamableHttpTransportCloseRequested: mongoLogId(1_006_005), - streamableHttpTransportCloseSuccess: mongoLogId(1_006_006), - streamableHttpTransportCloseFailure: mongoLogId(1_006_007), + streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002), + streamableHttpTransportRequestFailure: mongoLogId(1_006_003), + streamableHttpTransportCloseFailure: mongoLogId(1_006_004), } as const; export abstract class LoggerBase { diff --git a/src/common/sessionStore.ts b/src/common/sessionStore.ts new file mode 100644 index 00000000..9159f633 --- /dev/null +++ b/src/common/sessionStore.ts @@ -0,0 +1,48 @@ +import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import logger, { LogId } from "./logger.js"; + +export class SessionStore { + private sessions: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + + getSession(sessionId: string): StreamableHTTPServerTransport | undefined { + return this.sessions[sessionId]; + } + + setSession(sessionId: string, transport: StreamableHTTPServerTransport): void { + if (this.sessions[sessionId]) { + throw new Error(`Session ${sessionId} already exists`); + } + this.sessions[sessionId] = transport; + } + + async closeSession(sessionId: string, closeTransport: boolean = true): Promise { + if (!this.sessions[sessionId]) { + throw new Error(`Session ${sessionId} not found`); + } + if (closeTransport) { + const transport = this.sessions[sessionId]; + if (!transport) { + throw new Error(`Session ${sessionId} not found`); + } + try { + await transport.close(); + } catch (error) { + logger.error( + LogId.streamableHttpTransportSessionCloseFailure, + "streamableHttpTransport", + `Error closing transport ${sessionId}: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + delete this.sessions[sessionId]; + } + + async closeAllSessions(): Promise { + await Promise.all( + Object.values(this.sessions) + .filter((transport) => transport !== undefined) + .map((transport) => transport.close()) + ); + this.sessions = {}; + } +} diff --git a/src/index.ts b/src/index.ts index 73457dd6..4f81c0cd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,6 +14,7 @@ async function main() { transportRunner .close() .then(() => { + logger.info(LogId.serverClosed, "server", `Server closed`); process.exit(0); }) .catch((error: unknown) => { @@ -22,10 +23,10 @@ async function main() { }); }; - process.once("SIGINT", shutdown); - process.once("SIGABRT", shutdown); - process.once("SIGTERM", shutdown); - process.once("SIGQUIT", shutdown); + process.on("SIGINT", shutdown); + process.on("SIGABRT", shutdown); + process.on("SIGTERM", shutdown); + process.on("SIGQUIT", shutdown); try { await transportRunner.start(); diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index e15af8d5..fbe01a55 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,90 +1,143 @@ import express from "express"; import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; +import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; import { config } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; +import { randomUUID } from "crypto"; +import { SessionStore } from "../common/sessionStore.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; -const JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED = -32601; +const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; +const JSON_RPC_ERROR_CODE_SESSION_ID_INVALID = -32002; +const JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND = -32003; +const JSON_RPC_ERROR_CODE_INVALID_REQUEST = -32004; function promiseHandler( fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise ) { return (req: express.Request, res: express.Response, next: express.NextFunction) => { - fn(req, res, next).catch(next); + fn(req, res, next).catch((error) => { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + }); }; } export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; + private sessionStore: SessionStore = new SessionStore(); async start() { const app = express(); app.enable("trust proxy"); // needed for reverse proxy support - app.use(express.urlencoded({ extended: true })); app.use(express.json()); + const handleRequest = async (req: express.Request, res: express.Response) => { + const sessionId = req.headers["mcp-session-id"]; + if (!sessionId) { + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED, + message: `session id is required`, + }, + }); + return; + } + if (typeof sessionId !== "string") { + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_SESSION_ID_INVALID, + message: `session id is invalid`, + }, + }); + return; + } + const transport = this.sessionStore.getSession(sessionId); + if (!transport) { + res.status(404).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_SESSION_NOT_FOUND, + message: `session not found`, + }, + }); + return; + } + await transport.handleRequest(req, res, req.body); + }; + app.post( "/mcp", promiseHandler(async (req: express.Request, res: express.Response) => { - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); + const sessionId = req.headers["mcp-session-id"]; + if (sessionId) { + await handleRequest(req, res); + return; + } - const server = this.setupServer(); + if (!isInitializeRequest(req.body)) { + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_INVALID_REQUEST, + message: `invalid request`, + }, + }); + return; + } - await server.connect(transport); + const server = this.setupServer(); + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID().toString(), + onsessioninitialized: (sessionId) => { + this.sessionStore.setSession(sessionId, transport); + }, + onsessionclosed: async (sessionId) => { + try { + await this.sessionStore.closeSession(sessionId, false); + } catch (error) { + logger.error( + LogId.streamableHttpTransportSessionCloseFailure, + "streamableHttpTransport", + `Error closing session: ${error instanceof Error ? error.message : String(error)}` + ); + } + }, + }); - res.on("close", () => { - Promise.all([transport.close(), server.close()]).catch((error: unknown) => { + transport.onclose = () => { + server.close().catch((error) => { logger.error( LogId.streamableHttpTransportCloseFailure, "streamableHttpTransport", `Error closing server: ${error instanceof Error ? error.message : String(error)}` ); }); - }); + }; - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), - }, - }); - } + await server.connect(transport); + + await transport.handleRequest(req, res, req.body); }) ); - app.get("/mcp", (req: express.Request, res: express.Response) => { - res.status(405).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, - message: `method not allowed`, - }, - }); - }); - - app.delete("/mcp", (req: express.Request, res: express.Response) => { - res.status(405).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, - message: `method not allowed`, - }, - }); - }); + app.get("/mcp", promiseHandler(handleRequest)); + app.delete("/mcp", promiseHandler(handleRequest)); this.httpServer = await new Promise((resolve, reject) => { const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { @@ -104,14 +157,17 @@ export class StreamableHttpRunner extends TransportRunnerBase { } async close(): Promise { - await new Promise((resolve, reject) => { - this.httpServer?.close((err) => { - if (err) { - reject(err); - return; - } - resolve(); - }); - }); + await Promise.all([ + this.sessionStore.closeAllSessions(), + new Promise((resolve, reject) => { + this.httpServer?.close((err) => { + if (err) { + reject(err); + return; + } + resolve(); + }); + }), + ]); } } diff --git a/tests/integration/transports/stdio.test.ts b/tests/integration/transports/stdio.test.ts index afbcce00..2bc03b5b 100644 --- a/tests/integration/transports/stdio.test.ts +++ b/tests/integration/transports/stdio.test.ts @@ -1,70 +1,40 @@ -import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { describe, expect, it, beforeAll, afterAll } from "vitest"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; describe("StdioRunner", () => { describe("client connects successfully", () => { - let client: StdioClientTransport; + let client: Client; + let transport: StdioClientTransport; beforeAll(async () => { - client = new StdioClientTransport({ + transport = new StdioClientTransport({ command: "node", args: ["dist/index.js"], env: { MDB_MCP_TRANSPORT: "stdio", }, }); - await client.start(); + client = new Client({ + name: "test", + version: "0.0.0", + }); + await client.connect(transport); }); afterAll(async () => { await client.close(); + await transport.close(); }); it("handles requests and sends responses", async () => { - let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; - const messagePromise = new Promise((resolve) => { - fixedResolve = resolve; - }); - - client.onmessage = (message: JSONRPCMessage) => { - fixedResolve?.(message); - }; - - await client.send({ - jsonrpc: "2.0", - id: 1, - method: "tools/list", - params: { - _meta: { - progressToken: 1, - }, - }, - }); - - const message = (await messagePromise) as { - jsonrpc: string; - id: number; - result: { - tools: { - name: string; - description: string; - }[]; - }; - error?: { - code: number; - message: string; - }; - }; + const response = await client.listTools(); + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + expect(response.tools).toHaveLength(20); - expect(message.jsonrpc).toBe("2.0"); - expect(message.id).toBe(1); - expect(message.result).toBeDefined(); - expect(message.result?.tools).toBeDefined(); - expect(message.result?.tools.length).toBeGreaterThan(0); - const tools = message.result?.tools; - tools.sort((a, b) => a.name.localeCompare(b.name)); - expect(tools[0]?.name).toBe("aggregate"); - expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(sortedTools[0]?.name).toBe("aggregate"); + expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); }); }); }); diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index 031e7798..c295705e 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -1,76 +1,56 @@ import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { describe, expect, it, beforeAll, afterAll } from "vitest"; +import { config } from "../../../src/common/config.js"; describe("StreamableHttpRunner", () => { let runner: StreamableHttpRunner; + let oldTelemetry: "enabled" | "disabled"; + let oldLoggers: ("stderr" | "disk" | "mcp")[]; beforeAll(async () => { + oldTelemetry = config.telemetry; + oldLoggers = config.loggers; + config.telemetry = "disabled"; + config.loggers = ["stderr"]; runner = new StreamableHttpRunner(); await runner.start(); }); afterAll(async () => { await runner.close(); + config.telemetry = oldTelemetry; + config.loggers = oldLoggers; }); describe("client connects successfully", () => { - let client: StreamableHTTPClientTransport; + let client: Client; + let transport: StreamableHTTPClientTransport; beforeAll(async () => { - client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); - await client.start(); + transport = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); + + client = new Client({ + name: "test", + version: "0.0.0", + }); + await client.connect(transport); }); afterAll(async () => { await client.close(); + await transport.close(); }); it("handles requests and sends responses", async () => { - let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; - const messagePromise = new Promise((resolve) => { - fixedResolve = resolve; - }); - - client.onmessage = (message: JSONRPCMessage) => { - fixedResolve?.(message); - }; - - await client.send({ - jsonrpc: "2.0", - id: 1, - method: "tools/list", - params: { - _meta: { - progressToken: 1, - }, - }, - }); - - const message = (await messagePromise) as { - jsonrpc: string; - id: number; - result: { - tools: { - name: string; - description: string; - }[]; - }; - error?: { - code: number; - message: string; - }; - }; - - expect(message.jsonrpc).toBe("2.0"); - expect(message.id).toBe(1); - expect(message.result).toBeDefined(); - expect(message.result?.tools).toBeDefined(); - expect(message.result?.tools.length).toBeGreaterThan(0); - const tools = message.result?.tools; - tools.sort((a, b) => a.name.localeCompare(b.name)); - expect(tools[0]?.name).toBe("aggregate"); - expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + const response = await client.listTools(); + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + expect(response.tools.length).toBeGreaterThan(0); + + const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(sortedTools[0]?.name).toBe("aggregate"); + expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); }); }); });