diff --git a/README.md b/README.md index aa8f9304c..8dc0ef535 100644 --- a/README.md +++ b/README.md @@ -268,6 +268,7 @@ server.registerPrompt( { title: "Code Review", description: "Review code for best practices and potential issues", + // can use string or enum argsSchema: { code: z.string() } }, ({ code }) => ({ diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 50df25b53..be2b3272c 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -4010,6 +4010,69 @@ describe("Tool title precedence", () => { }); }); + test("registerPrompt schema should support strings and enums", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.registerPrompt( + "test-prompt", + { + title: "Team Greeting", + description: "Generate a greeting for team members", + argsSchema: { + name: z.string(), + visibility: z.enum(["public", "private"]).optional().default("private"), + } + }, + async ({ name, visibility }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Creating a new project named ${name} with visibility ${visibility}.`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result1 = await client.getPrompt({ + name: "test-prompt", + arguments: { + name: "Test Project", + visibility: "public" + } + }); + + expect(result1.messages[0].content.text).toBe("Creating a new project named Test Project with visibility public."); + + await expect( + client.getPrompt({ + name: "test-prompt", + arguments: { + name: "Test Project", + visibility: "foo" + } + }), + ).rejects.toThrow(); +}); + describe("elicitInput()", () => { const checkAvailability = jest.fn().mockResolvedValue(false); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 3d9673da7..c4e29aba1 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -10,6 +10,8 @@ import { ZodType, ZodTypeDef, ZodOptional, + ZodDefault, + ZodEnum, } from "zod"; import { Implementation, @@ -1246,7 +1248,11 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = { [k: string]: | ZodType - | ZodOptional>; + | ZodEnum<[string, ...string[]]> + | ZodDefault> + | ZodDefault>> + | ZodOptional> + | ZodDefault>>; }; export type PromptCallback< diff --git a/src/types.ts b/src/types.ts index 3606a6be7..b47739dfe 100644 --- a/src/types.ts +++ b/src/types.ts @@ -685,7 +685,10 @@ export const GetPromptRequestSchema = RequestSchema.extend({ /** * Arguments to use for templating the prompt. */ - arguments: z.optional(z.record(z.string())), + arguments: z.optional(z.record(z.union([ + z.string(), + z.object({ type: z.literal("enum"), enum: z.array(z.string()) }) + ]))), }), });