diff --git a/go.mod b/go.mod index 96deab9a..ea6f8799 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 - github.com/mark3labs/mcp-go v0.25.0 + github.com/mark3labs/mcp-go v0.30.0 github.com/mholt/archives v0.1.0 github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 diff --git a/go.sum b/go.sum index 02c068cc..e0f89743 100644 --- a/go.sum +++ b/go.sum @@ -270,8 +270,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.25.0 h1:UUpcMT3L5hIhuDy7aifj4Bphw4Pfx1Rf8mzMXDe8RQw= -github.com/mark3labs/mcp-go v0.25.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= +github.com/mark3labs/mcp-go v0.30.0 h1:Taz7fiefkxY/l8jz1nA90V+WdM2eoMtlvwfWforVYbo= +github.com/mark3labs/mcp-go v0.30.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go index d31c6503..f3dffcf6 100644 --- a/pkg/mcp/loader.go +++ b/pkg/mcp/loader.go @@ -15,7 +15,8 @@ import ( "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" - "github.com/mark3labs/mcp-go/client" + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) @@ -36,7 +37,7 @@ type Local struct { type Session struct { ID string InitResult *mcp.InitializeResult - Client client.MCPClient + Client mcpclient.MCPClient Config ServerConfig } @@ -117,7 +118,7 @@ func (l *Local) LoadTools(ctx context.Context, server ServerConfig, toolName str // Reset so we don't start a new MCP server, no reason to if one is already running and the allowed tools change. server.AllowedTools = nil - session, err := l.loadSession(server) + session, err := l.loadSession(server, true) if err != nil { return nil, err } @@ -279,7 +280,7 @@ func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName s return toolDefs, nil } -func (l *Local) loadSession(server ServerConfig) (*Session, error) { +func (l *Local) loadSession(server ServerConfig, tryHTTPStreaming bool) (*Session, error) { id := hash.Digest(server) l.lock.Lock() existing, ok := l.sessions[id] @@ -294,11 +295,11 @@ func (l *Local) loadSession(server ServerConfig) (*Session, error) { } var ( - c *client.Client + c *mcpclient.Client err error ) if server.Command != "" { - c, err = client.NewStdioMCPClient(server.Command, server.Env, server.Args...) + c, err = mcpclient.NewStdioMCPClient(server.Command, server.Env, server.Args...) if err != nil { return nil, fmt.Errorf("failed to create MCP stdio client: %w", err) } @@ -314,7 +315,11 @@ func (l *Local) loadSession(server ServerConfig) (*Session, error) { headers[k] = v } - c, err = client.NewSSEMCPClient(url, client.WithHeaders(headers)) + if tryHTTPStreaming { + c, err = mcpclient.NewStreamableHttpClient(url, transport.WithHTTPHeaders(headers)) + } else { + c, err = mcpclient.NewSSEMCPClient(url, mcpclient.WithHeaders(headers)) + } if err != nil { return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err) } @@ -333,6 +338,13 @@ func (l *Local) loadSession(server ServerConfig) (*Session, error) { initResult, err := c.Initialize(ctx, initRequest) if err != nil { + if server.Command == "" && tryHTTPStreaming { + // The MCP spec indicates that trying to initialize the client for HTTP streaming and checking for an error + // is the recommended way to determine if the server supports HTTP streaming, falling back to SEE. + // Ideally, we can check for a 400-level error, but our client implementation doesn't expose that information. + // Retrying on any error is harmless. + return l.loadSession(server, false) + } return nil, fmt.Errorf("failed to initialize MCP client: %w", err) }