diff --git a/package.json b/package.json index 9b6347b3..b18fa2cb 100644 --- a/package.json +++ b/package.json @@ -29,7 +29,7 @@ "check:types": "tsc --noEmit --project tsconfig.json", "reformat": "prettier --write .", "generate": "./scripts/generate.sh", - "test": "vitest --coverage" + "test": "vitest --run --coverage" }, "license": "Apache-2.0", "devDependencies": { diff --git a/src/index.ts b/src/index.ts index c5f4ddee..73457dd6 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,50 +1,23 @@ #!/usr/bin/env node import logger, { LogId } from "./common/logger.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { config } from "./common/config.js"; -import { Session } from "./common/session.js"; -import { Server } from "./server.js"; -import { packageInfo } from "./common/packageInfo.js"; -import { Telemetry } from "./telemetry/telemetry.js"; -import { createStdioTransport } from "./transports/stdio.js"; -import { createHttpTransport } from "./transports/streamableHttp.js"; +import { StdioRunner } from "./transports/stdio.js"; +import { StreamableHttpRunner } from "./transports/streamableHttp.js"; -try { - const session = new Session({ - apiBaseUrl: config.apiBaseUrl, - apiClientId: config.apiClientId, - apiClientSecret: config.apiClientSecret, - }); - - const transport = config.transport === "stdio" ? createStdioTransport() : await createHttpTransport(); - - const telemetry = Telemetry.create(session, config); - - const mcpServer = new McpServer({ - name: packageInfo.mcpServerName, - version: packageInfo.version, - }); - - const server = new Server({ - mcpServer, - session, - telemetry, - userConfig: config, - }); +async function main() { + const transportRunner = config.transport === "stdio" ? new StdioRunner() : new StreamableHttpRunner(); const shutdown = () => { logger.info(LogId.serverCloseRequested, "server", `Server close requested`); - server + transportRunner .close() .then(() => { - logger.info(LogId.serverClosed, "server", `Server closed successfully`); process.exit(0); }) - .catch((err: unknown) => { - const error = err instanceof Error ? err : new Error(String(err)); - logger.error(LogId.serverCloseFailure, "server", `Error closing server: ${error.message}`); + .catch((error: unknown) => { + logger.error(LogId.serverCloseFailure, "server", `Error closing server: ${error as string}`); process.exit(1); }); }; @@ -54,8 +27,22 @@ try { process.once("SIGTERM", shutdown); process.once("SIGQUIT", shutdown); - await server.connect(transport); -} catch (error: unknown) { + try { + await transportRunner.start(); + } catch (error: unknown) { + logger.emergency(LogId.serverStartFailure, "server", `Fatal error running server: ${error as string}`); + try { + await transportRunner.close(); + logger.error(LogId.serverClosed, "server", "Server closed"); + } catch (error: unknown) { + logger.error(LogId.serverCloseFailure, "server", `Error closing server: ${error as string}`); + } finally { + process.exit(1); + } + } +} + +main().catch((error: unknown) => { logger.emergency(LogId.serverStartFailure, "server", `Fatal error running server: ${error as string}`); process.exit(1); -} +}); diff --git a/src/transports/base.ts b/src/transports/base.ts new file mode 100644 index 00000000..442db18a --- /dev/null +++ b/src/transports/base.ts @@ -0,0 +1,34 @@ +import { config } from "../common/config.js"; +import { packageInfo } from "../common/packageInfo.js"; +import { Server } from "../server.js"; +import { Session } from "../common/session.js"; +import { Telemetry } from "../telemetry/telemetry.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; + +export abstract class TransportRunnerBase { + protected setupServer(): Server { + const session = new Session({ + apiBaseUrl: config.apiBaseUrl, + apiClientId: config.apiClientId, + apiClientSecret: config.apiClientSecret, + }); + + const telemetry = Telemetry.create(session, config); + + const mcpServer = new McpServer({ + name: packageInfo.mcpServerName, + version: packageInfo.version, + }); + + return new Server({ + mcpServer, + session, + telemetry, + userConfig: config, + }); + } + + abstract start(): Promise; + + abstract close(): Promise; +} diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 0f9f4c0c..9f18627c 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,3 +1,6 @@ +import logger, { LogId } from "../common/logger.js"; +import { Server } from "../server.js"; +import { TransportRunnerBase } from "./base.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { EJSON } from "bson"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; @@ -45,3 +48,24 @@ export function createStdioTransport(): StdioServerTransport { return server; } + +export class StdioRunner extends TransportRunnerBase { + private server: Server | undefined; + + async start() { + try { + this.server = this.setupServer(); + + const transport = createStdioTransport(); + + await this.server.connect(transport); + } catch (error: unknown) { + logger.emergency(LogId.serverStartFailure, "server", `Fatal error running server: ${error as string}`); + process.exit(1); + } + } + + async close(): Promise { + await this.server?.close(); + } +} diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index bb4d0f06..e15af8d5 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -1,84 +1,92 @@ import express from "express"; import http from "http"; import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; - +import { TransportRunnerBase } from "./base.js"; import { config } from "../common/config.js"; import logger, { LogId } from "../common/logger.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; +const JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED = -32601; -export async function createHttpTransport(): Promise { - const app = express(); - app.enable("trust proxy"); // needed for reverse proxy support - app.use(express.urlencoded({ extended: true })); - app.use(express.json()); +function promiseHandler( + fn: (req: express.Request, res: express.Response, next: express.NextFunction) => Promise +) { + return (req: express.Request, res: express.Response, next: express.NextFunction) => { + fn(req, res, next).catch(next); + }; +} - const transport = new StreamableHTTPServerTransport({ - sessionIdGenerator: undefined, - }); +export class StreamableHttpRunner extends TransportRunnerBase { + private httpServer: http.Server | undefined; - app.post("/mcp", async (req: express.Request, res: express.Response) => { - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ - jsonrpc: "2.0", - error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), - }, - }); - } - }); + async start() { + const app = express(); + app.enable("trust proxy"); // needed for reverse proxy support + app.use(express.urlencoded({ extended: true })); + app.use(express.json()); + + app.post( + "/mcp", + promiseHandler(async (req: express.Request, res: express.Response) => { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: undefined, + }); + + const server = this.setupServer(); - app.get("/mcp", async (req: express.Request, res: express.Response) => { - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ + await server.connect(transport); + + res.on("close", () => { + Promise.all([transport.close(), server.close()]).catch((error: unknown) => { + logger.error( + LogId.streamableHttpTransportCloseFailure, + "streamableHttpTransport", + `Error closing server: ${error instanceof Error ? error.message : String(error)}` + ); + }); + }); + + try { + await transport.handleRequest(req, res, req.body); + } catch (error) { + logger.error( + LogId.streamableHttpTransportRequestFailure, + "streamableHttpTransport", + `Error handling request: ${error instanceof Error ? error.message : String(error)}` + ); + res.status(400).json({ + jsonrpc: "2.0", + error: { + code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, + message: `failed to handle request`, + data: error instanceof Error ? error.message : String(error), + }, + }); + } + }) + ); + + app.get("/mcp", (req: express.Request, res: express.Response) => { + res.status(405).json({ jsonrpc: "2.0", error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), + code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, + message: `method not allowed`, }, }); - } - }); + }); - app.delete("/mcp", async (req: express.Request, res: express.Response) => { - try { - await transport.handleRequest(req, res, req.body); - } catch (error) { - logger.error( - LogId.streamableHttpTransportRequestFailure, - "streamableHttpTransport", - `Error handling request: ${error instanceof Error ? error.message : String(error)}` - ); - res.status(400).json({ + app.delete("/mcp", (req: express.Request, res: express.Response) => { + res.status(405).json({ jsonrpc: "2.0", error: { - code: JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED, - message: `failed to handle request`, - data: error instanceof Error ? error.message : String(error), + code: JSON_RPC_ERROR_CODE_METHOD_NOT_ALLOWED, + message: `method not allowed`, }, }); - } - }); + }); - try { - const server = await new Promise((resolve, reject) => { + this.httpServer = await new Promise((resolve, reject) => { const result = app.listen(config.httpPort, config.httpHost, (err?: Error) => { if (err) { reject(err); @@ -93,31 +101,17 @@ export async function createHttpTransport(): Promise { - logger.info(LogId.streamableHttpTransportCloseRequested, "streamableHttpTransport", `Closing server`); - server.close((err?: Error) => { + async close(): Promise { + await new Promise((resolve, reject) => { + this.httpServer?.close((err) => { if (err) { - logger.error( - LogId.streamableHttpTransportCloseFailure, - "streamableHttpTransport", - `Error closing server: ${err.message}` - ); + reject(err); return; } - logger.info(LogId.streamableHttpTransportCloseSuccess, "streamableHttpTransport", `Server closed`); + resolve(); }); - }; - - return transport; - } catch (error: unknown) { - const err = error instanceof Error ? error : new Error(String(error)); - logger.info( - LogId.streamableHttpTransportStartFailure, - "streamableHttpTransport", - `Error starting server: ${err.message}` - ); - - throw err; + }); } } diff --git a/tests/integration/transports/stdio.test.ts b/tests/integration/transports/stdio.test.ts new file mode 100644 index 00000000..afbcce00 --- /dev/null +++ b/tests/integration/transports/stdio.test.ts @@ -0,0 +1,70 @@ +import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { describe, expect, it, beforeAll, afterAll } from "vitest"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; + +describe("StdioRunner", () => { + describe("client connects successfully", () => { + let client: StdioClientTransport; + beforeAll(async () => { + client = new StdioClientTransport({ + command: "node", + args: ["dist/index.js"], + env: { + MDB_MCP_TRANSPORT: "stdio", + }, + }); + await client.start(); + }); + + afterAll(async () => { + await client.close(); + }); + + it("handles requests and sends responses", async () => { + let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; + const messagePromise = new Promise((resolve) => { + fixedResolve = resolve; + }); + + client.onmessage = (message: JSONRPCMessage) => { + fixedResolve?.(message); + }; + + await client.send({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: { + _meta: { + progressToken: 1, + }, + }, + }); + + const message = (await messagePromise) as { + jsonrpc: string; + id: number; + result: { + tools: { + name: string; + description: string; + }[]; + }; + error?: { + code: number; + message: string; + }; + }; + + expect(message.jsonrpc).toBe("2.0"); + expect(message.id).toBe(1); + expect(message.result).toBeDefined(); + expect(message.result?.tools).toBeDefined(); + expect(message.result?.tools.length).toBeGreaterThan(0); + const tools = message.result?.tools; + tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(tools[0]?.name).toBe("aggregate"); + expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + }); + }); +}); diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts new file mode 100644 index 00000000..031e7798 --- /dev/null +++ b/tests/integration/transports/streamableHttp.test.ts @@ -0,0 +1,76 @@ +import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import { describe, expect, it, beforeAll, afterAll } from "vitest"; + +describe("StreamableHttpRunner", () => { + let runner: StreamableHttpRunner; + + beforeAll(async () => { + runner = new StreamableHttpRunner(); + await runner.start(); + }); + + afterAll(async () => { + await runner.close(); + }); + + describe("client connects successfully", () => { + let client: StreamableHTTPClientTransport; + beforeAll(async () => { + client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); + await client.start(); + }); + + afterAll(async () => { + await client.close(); + }); + + it("handles requests and sends responses", async () => { + let fixedResolve: ((value: JSONRPCMessage) => void) | undefined = undefined; + const messagePromise = new Promise((resolve) => { + fixedResolve = resolve; + }); + + client.onmessage = (message: JSONRPCMessage) => { + fixedResolve?.(message); + }; + + await client.send({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: { + _meta: { + progressToken: 1, + }, + }, + }); + + const message = (await messagePromise) as { + jsonrpc: string; + id: number; + result: { + tools: { + name: string; + description: string; + }[]; + }; + error?: { + code: number; + message: string; + }; + }; + + expect(message.jsonrpc).toBe("2.0"); + expect(message.id).toBe(1); + expect(message.result).toBeDefined(); + expect(message.result?.tools).toBeDefined(); + expect(message.result?.tools.length).toBeGreaterThan(0); + const tools = message.result?.tools; + tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(tools[0]?.name).toBe("aggregate"); + expect(tools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + }); + }); +}); diff --git a/tests/unit/transports/stdio.test.ts b/tests/unit/transports/stdio.test.ts index 2a1c62de..6a53f67b 100644 --- a/tests/unit/transports/stdio.test.ts +++ b/tests/unit/transports/stdio.test.ts @@ -6,7 +6,6 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js" import { Readable } from "stream"; import { ReadBuffer } from "@modelcontextprotocol/sdk/shared/stdio.js"; import { describe, expect, it, beforeEach, afterEach } from "vitest"; - describe("stdioTransport", () => { let transport: StdioServerTransport; beforeEach(async () => { diff --git a/tests/unit/transports/streamableHttp.test.ts b/tests/unit/transports/streamableHttp.test.ts deleted file mode 100644 index 01eeb136..00000000 --- a/tests/unit/transports/streamableHttp.test.ts +++ /dev/null @@ -1,78 +0,0 @@ -import { createHttpTransport } from "../../../src/transports/streamableHttp.js"; -import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; -import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; -import { z } from "zod"; -import { describe, expect, it, beforeAll, afterAll } from "vitest"; -describe("streamableHttpTransport", () => { - let transport: StreamableHTTPServerTransport; - const mcpServer = new McpServer({ - name: "test", - version: "1.0.0", - }); - beforeAll(async () => { - transport = await createHttpTransport(); - mcpServer.registerTool( - "hello", - { - title: "Hello Tool", - description: "Say hello", - inputSchema: { name: z.string() }, - }, - ({ name }) => ({ - content: [{ type: "text", text: `Hello, ${name}!` }], - }) - ); - await mcpServer.connect(transport); - }); - - afterAll(async () => { - await mcpServer.close(); - }); - - describe("client connects successfully", () => { - let client: StreamableHTTPClientTransport; - beforeAll(async () => { - client = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); - await client.start(); - }); - - afterAll(async () => { - await client.close(); - }); - - it("handles requests and sends responses", async () => { - client.onmessage = (message: JSONRPCMessage) => { - const messageResult = message as - | { - result?: { - tools: { - name: string; - description: string; - }[]; - }; - } - | undefined; - - expect(message.jsonrpc).toBe("2.0"); - expect(messageResult).toBeDefined(); - expect(messageResult?.result?.tools).toBeDefined(); - expect(messageResult?.result?.tools.length).toBe(1); - expect(messageResult?.result?.tools[0]?.name).toBe("hello"); - expect(messageResult?.result?.tools[0]?.description).toBe("Say hello"); - }; - - await client.send({ - jsonrpc: "2.0", - id: 1, - method: "tools/list", - params: { - _meta: { - progressToken: 1, - }, - }, - }); - }); - }); -});