From e404b11e5b2340a1baa45d0717086c6f632540a9 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Tue, 20 May 2025 15:26:15 -0400 Subject: [PATCH] enhance: send MCP errors back to the LLM so it can correct if possible Signed-off-by: Donnie Adams --- pkg/engine/cmd.go | 6 +++--- pkg/engine/engine.go | 6 +++--- pkg/mcp/runner.go | 14 +++++++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 5fb340c5..e7671436 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -65,7 +65,7 @@ func compressEnv(envs []string) (result []string) { return } -func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) { +func (e *Engine) runCommand(ctx Context, tool types.Tool, input string) (cmdOut string, cmdErr error) { id := counter.Next() var combinedOutput string @@ -128,7 +128,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate cmd, stop, err := e.newCommand(commandCtx, extraEnv, tool, input, true) if err != nil { - if toolCategory == NoCategory && ctx.Parent != nil { + if ctx.ToolCategory == NoCategory && ctx.Parent != nil { return fmt.Sprintf("ERROR: got (%v) while parsing command", err), nil } return "", fmt.Errorf("got (%v) while parsing command", err) @@ -167,7 +167,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate if err := cmd.Run(); err != nil && (commandCtx.Err() == nil || ctx.Ctx.Err() != nil) { // If the command failed and the context hasn't been canceled, then return the error. - if toolCategory == NoCategory && ctx.Parent != nil { + if ctx.ToolCategory == NoCategory && ctx.Parent != nil { // If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry. return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, stdoutAndErr), nil } diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index e3ff930a..c509af9c 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -45,7 +45,7 @@ type Engine struct { } type MCPRunner interface { - Run(ctx context.Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) + Run(ctx Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) } type State struct { @@ -313,7 +313,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too } func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) { - output, err := e.MCPRunner.Run(ctx.Ctx, e.Progress, tool, input) + output, err := e.MCPRunner.Run(ctx, e.Progress, tool, input) if err != nil { return nil, fmt.Errorf("failed to run MCP invoke: %w", err) } @@ -335,7 +335,7 @@ func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*R } else if tool.IsCall() { return e.runCall(ctx, tool, input) } - s, err := e.runCommand(ctx, tool, input, ctx.ToolCategory) + s, err := e.runCommand(ctx, tool, input) return &Return{ Result: &s, }, err diff --git a/pkg/mcp/runner.go b/pkg/mcp/runner.go index 448d58a7..1a275a0c 100644 --- a/pkg/mcp/runner.go +++ b/pkg/mcp/runner.go @@ -1,16 +1,16 @@ package mcp import ( - "context" "encoding/json" "fmt" "strings" + "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/mark3labs/mcp-go/mcp" ) -func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) { +func (l *Local) Run(ctx engine.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) { fields := strings.Fields(tool.Instructions) if len(fields) < 2 { return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions) @@ -41,8 +41,16 @@ func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool t request.Params.Name = toolName request.Params.Arguments = arguments - result, err := session.Client.CallTool(ctx, request) + result, err := session.Client.CallTool(ctx.Ctx, request) if err != nil { + if ctx.ToolCategory == engine.NoCategory && ctx.Parent != nil { + var output []byte + if result != nil { + output, _ = json.Marshal(result) + } + // If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry. + return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, string(output)), nil + } return "", fmt.Errorf("failed to call tool %s: %w", toolName, err) }