diff --git a/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs b/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs index 05d8a49a..3a9926e2 100644 --- a/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs +++ b/src/ModelContextProtocol.Core/Protocol/ElicitRequestParams.cs @@ -1,4 +1,7 @@ +using System.ComponentModel; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Text.Json; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; @@ -54,39 +57,273 @@ public IDictionary Properties public IList? Required { get; set; } } - /// /// Represents restricted subset of JSON Schema: /// , , , or . /// - [JsonDerivedType(typeof(BooleanSchema))] - [JsonDerivedType(typeof(EnumSchema))] - [JsonDerivedType(typeof(NumberSchema))] - [JsonDerivedType(typeof(StringSchema))] + [JsonConverter(typeof(Converter))] // TODO: This converter exists due to the lack of downlevel support for AllowOutOfOrderMetadataProperties. public abstract class PrimitiveSchemaDefinition { /// Prevent external derivations. protected private PrimitiveSchemaDefinition() { } - } - /// Represents a schema for a string type. - public sealed class StringSchema : PrimitiveSchemaDefinition - { /// Gets the type of the schema. - /// This is always "string". [JsonPropertyName("type")] - public string Type => "string"; + public abstract string Type { get; set; } - /// Gets or sets a title for the string. + /// Gets or sets a title for the schema. [JsonPropertyName("title")] public string? Title { get; set; } - /// Gets or sets a description for the string. + /// Gets or sets a description for the schema. [JsonPropertyName("description")] public string? Description { get; set; } + /// + /// Provides a for . + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public class Converter : JsonConverter + { + /// + public override PrimitiveSchemaDefinition? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException(); + } + + string? type = null; + string? title = null; + string? description = null; + int? minLength = null; + int? maxLength = null; + string? format = null; + double? minimum = null; + double? maximum = null; + bool? defaultBool = null; + IList? enumValues = null; + IList? enumNames = null; + + while (reader.Read() && reader.TokenType != JsonTokenType.EndObject) + { + if (reader.TokenType != JsonTokenType.PropertyName) + { + continue; + } + + string? propertyName = reader.GetString(); + bool success = reader.Read(); + Debug.Assert(success, "STJ must have buffered the entire object for us."); + + switch (propertyName) + { + case "type": + type = reader.GetString(); + break; + + case "title": + title = reader.GetString(); + break; + + case "description": + description = reader.GetString(); + break; + + case "minLength": + minLength = reader.GetInt32(); + break; + + case "maxLength": + maxLength = reader.GetInt32(); + break; + + case "format": + format = reader.GetString(); + break; + + case "minimum": + minimum = reader.GetDouble(); + break; + + case "maximum": + maximum = reader.GetDouble(); + break; + + case "default": + defaultBool = reader.GetBoolean(); + break; + + case "enum": + enumValues = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.IListString); + break; + + case "enumNames": + enumNames = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.JsonContext.Default.IListString); + break; + + default: + break; + } + } + + if (type is null) + { + throw new JsonException("The 'type' property is required."); + } + + PrimitiveSchemaDefinition? psd = null; + switch (type) + { + case "string": + if (enumValues is not null) + { + psd = new EnumSchema + { + Enum = enumValues, + EnumNames = enumNames + }; + } + else + { + psd = new StringSchema + { + MinLength = minLength, + MaxLength = maxLength, + Format = format, + }; + } + break; + + case "integer": + case "number": + psd = new NumberSchema + { + Minimum = minimum, + Maximum = maximum, + }; + break; + + case "boolean": + psd = new BooleanSchema + { + Default = defaultBool, + }; + break; + } + + if (psd is not null) + { + psd.Type = type; + psd.Title = title; + psd.Description = description; + } + + return psd; + } + + /// + public override void Write(Utf8JsonWriter writer, PrimitiveSchemaDefinition value, JsonSerializerOptions options) + { + if (value is null) + { + writer.WriteNullValue(); + return; + } + + writer.WriteStartObject(); + + writer.WriteString("type", value.Type); + if (value.Title is not null) + { + writer.WriteString("title", value.Title); + } + if (value.Description is not null) + { + writer.WriteString("description", value.Description); + } + + switch (value) + { + case StringSchema stringSchema: + if (stringSchema.MinLength.HasValue) + { + writer.WriteNumber("minLength", stringSchema.MinLength.Value); + } + if (stringSchema.MaxLength.HasValue) + { + writer.WriteNumber("maxLength", stringSchema.MaxLength.Value); + } + if (stringSchema.Format is not null) + { + writer.WriteString("format", stringSchema.Format); + } + break; + + case NumberSchema numberSchema: + if (numberSchema.Minimum.HasValue) + { + writer.WriteNumber("minimum", numberSchema.Minimum.Value); + } + if (numberSchema.Maximum.HasValue) + { + writer.WriteNumber("maximum", numberSchema.Maximum.Value); + } + break; + + case BooleanSchema booleanSchema: + if (booleanSchema.Default.HasValue) + { + writer.WriteBoolean("default", booleanSchema.Default.Value); + } + break; + + case EnumSchema enumSchema: + if (enumSchema.Enum is not null) + { + writer.WritePropertyName("enum"); + JsonSerializer.Serialize(writer, enumSchema.Enum, McpJsonUtilities.JsonContext.Default.IListString); + } + if (enumSchema.EnumNames is not null) + { + writer.WritePropertyName("enumNames"); + JsonSerializer.Serialize(writer, enumSchema.EnumNames, McpJsonUtilities.JsonContext.Default.IListString); + } + break; + + default: + throw new JsonException($"Unexpected schema type: {value.GetType().Name}"); + } + + writer.WriteEndObject(); + } + } + } + + /// Represents a schema for a string type. + public sealed class StringSchema : PrimitiveSchemaDefinition + { + /// + [JsonPropertyName("type")] + public override string Type + { + get => "string"; + set + { + if (value is not "string") + { + throw new ArgumentException("Type must be 'string'.", nameof(value)); + } + } + } + /// Gets or sets the minimum length for the string. [JsonPropertyName("minLength")] public int? MinLength @@ -139,11 +376,9 @@ public string? Format /// Represents a schema for a number or integer type. public sealed class NumberSchema : PrimitiveSchemaDefinition { - /// Gets the type of the schema. - /// This should be "number" or "integer". - [JsonPropertyName("type")] + /// [field: MaybeNull] - public string Type + public override string Type { get => field ??= "number"; set @@ -157,14 +392,6 @@ public string Type } } - /// Gets or sets a title for the number input. - [JsonPropertyName("title")] - public string? Title { get; set; } - - /// Gets or sets a description for the number input. - [JsonPropertyName("description")] - public string? Description { get; set; } - /// Gets or sets the minimum allowed value. [JsonPropertyName("minimum")] public double? Minimum { get; set; } @@ -177,18 +404,19 @@ public string Type /// Represents a schema for a Boolean type. public sealed class BooleanSchema : PrimitiveSchemaDefinition { - /// Gets the type of the schema. - /// This is always "boolean". + /// [JsonPropertyName("type")] - public string Type => "boolean"; - - /// Gets or sets a title for the Boolean. - [JsonPropertyName("title")] - public string? Title { get; set; } - - /// Gets or sets a description for the Boolean. - [JsonPropertyName("description")] - public string? Description { get; set; } + public override string Type + { + get => "boolean"; + set + { + if (value is not "boolean") + { + throw new ArgumentException("Type must be 'boolean'.", nameof(value)); + } + } + } /// Gets or sets the default value for the Boolean. [JsonPropertyName("default")] @@ -198,18 +426,19 @@ public sealed class BooleanSchema : PrimitiveSchemaDefinition /// Represents a schema for an enum type. public sealed class EnumSchema : PrimitiveSchemaDefinition { - /// Gets the type of the schema. - /// This is always "string". + /// [JsonPropertyName("type")] - public string Type => "string"; - - /// Gets or sets a title for the enum. - [JsonPropertyName("title")] - public string? Title { get; set; } - - /// Gets or sets a description for the enum. - [JsonPropertyName("description")] - public string? Description { get; set; } + public override string Type + { + get => "string"; + set + { + if (value is not "string") + { + throw new ArgumentException("Type must be 'string'.", nameof(value)); + } + } + } /// Gets or sets the list of allowed string values for the enum. [JsonPropertyName("enum")] diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index 326b235f..ec1c8510 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -62,13 +62,14 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer() + protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) { return await McpClientFactory.CreateAsync( new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), LoggerFactory), + clientOptions: clientOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs new file mode 100644 index 00000000..a5128230 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs @@ -0,0 +1,147 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Configuration; + +public partial class ElicitationTests : ClientServerTestBase +{ + public ElicitationTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder.WithCallToolHandler(async (request, cancellationToken) => + { + Assert.Equal("TestElicitation", request.Params?.Name); + + var result = await request.Server.ElicitAsync( + new() + { + Message = "Please provide more information.", + RequestedSchema = new() + { + Properties = new Dictionary() + { + ["prop1"] = new ElicitRequestParams.StringSchema() + { + Title = "title1", + MinLength = 1, + MaxLength = 100, + }, + ["prop2"] = new ElicitRequestParams.NumberSchema() + { + Description = "description2", + Minimum = 0, + Maximum = 1000, + }, + ["prop3"] = new ElicitRequestParams.BooleanSchema() + { + Title = "title3", + Description = "description4", + Default = true, + }, + ["prop4"] = new ElicitRequestParams.EnumSchema() + { + Enum = ["option1", "option2", "option3"], + EnumNames = ["Name1", "Name2", "Name3"], + }, + }, + }, + }, + CancellationToken.None); + + Assert.Equal("accept", result.Action); + + return new CallToolResult() + { + Content = [new TextContentBlock() { Text = "success" }], + }; + }); + } + + [Fact] + public async Task Can_Elicit_Information() + { + await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions() + { + Capabilities = new() + { + Elicitation = new() + { + ElicitationHandler = async (request, cancellationtoken) => + { + Assert.NotNull(request); + Assert.Equal("Please provide more information.", request.Message); + Assert.Equal(4, request.RequestedSchema.Properties.Count); + + foreach (var entry in request.RequestedSchema.Properties) + { + switch (entry.Key) + { + case "prop1": + var primitiveString = Assert.IsType(entry.Value); + Assert.Equal("title1", primitiveString.Title); + Assert.Equal(1, primitiveString.MinLength); + Assert.Equal(100, primitiveString.MaxLength); + break; + + case "prop2": + var primitiveNumber = Assert.IsType(entry.Value); + Assert.Equal("description2", primitiveNumber.Description); + Assert.Equal(0, primitiveNumber.Minimum); + Assert.Equal(1000, primitiveNumber.Maximum); + break; + + case "prop3": + var primitiveBool = Assert.IsType(entry.Value); + Assert.Equal("title3", primitiveBool.Title); + Assert.Equal("description4", primitiveBool.Description); + Assert.True(primitiveBool.Default); + break; + + case "prop4": + var primitiveEnum = Assert.IsType(entry.Value); + Assert.Equal(["option1", "option2", "option3"], primitiveEnum.Enum); + Assert.Equal(["Name1", "Name2", "Name3"], primitiveEnum.EnumNames); + break; + + default: + Assert.Fail($"Unknown property: {entry.Key}"); + break; + } + } + + return new ElicitResult + { + Action = "accept", + Content = new Dictionary + { + ["prop1"] = (JsonElement)JsonSerializer.Deserialize(""" + "string result" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop2"] = (JsonElement)JsonSerializer.Deserialize(""" + 42 + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop3"] = (JsonElement)JsonSerializer.Deserialize(""" + true + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + ["prop4"] = (JsonElement)JsonSerializer.Deserialize(""" + "option2" + """, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)))!, + }, + }; + }, + }, + }, + }); + + var result = await client.CallToolAsync("TestElicitation", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("success", (result.Content[0] as TextContentBlock)?.Text); + } +} \ No newline at end of file