From eca61d282f95f625a8aea358a97d80acee7d0bf2 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Wed, 29 May 2024 13:13:11 -0700 Subject: [PATCH] feat: add display text to callframe to make it easier on the sdk clients --- pkg/builtin/builtin.go | 26 ++++-------- pkg/builtin/builtin_test.go | 8 ++++ pkg/engine/engine.go | 10 ++++- pkg/runner/runner.go | 33 ++++++++------- pkg/types/tool.go | 15 +++++++ pkg/types/toolstring.go | 82 +++++++++++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 36 deletions(-) create mode 100644 pkg/types/toolstring.go diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 7c749c82..85c0bf33 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -20,18 +20,18 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/BurntSushi/locker" - "github.com/google/shlex" "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/jaytaylor/html2text" ) var SafeTools = map[string]struct{}{ - "sys.echo": {}, - "sys.time.now": {}, - "sys.prompt": {}, + "sys.abort": {}, "sys.chat.finish": {}, "sys.chat.history": {}, + "sys.echo": {}, + "sys.prompt": {}, + "sys.time.now": {}, } var tools = map[string]types.Tool{ @@ -333,11 +333,7 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) { var cmd *exec.Cmd if runtime.GOOS == "windows" { - args, err := shlex.Split(params.Command) - if err != nil { - return "", fmt.Errorf("parsing command: %w", err) - } - cmd = exec.Command(args[0], args[1:]...) + cmd = exec.Command("cmd.exe", "/c", params.Command) } else { cmd = exec.Command("/bin/sh", "-c", params.Command) } @@ -346,7 +342,7 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) { cmd.Dir = params.Directory out, err := cmd.CombinedOutput() if err != nil { - return string(out), fmt.Errorf("OUTPUT: %s, ERROR: %w", out, err) + return fmt.Sprintf("ERROR: %s\nOUTPUT: %s", err, out), nil } return string(out), nil } @@ -362,10 +358,6 @@ func getWorkspaceDir(envs []string) (string, error) { } func SysLs(_ context.Context, _ []string, input string) (string, error) { - return sysLs("", input) -} - -func sysLs(base, input string) (string, error) { var params struct { Dir string `json:"dir,omitempty"` } @@ -378,10 +370,6 @@ func sysLs(base, input string) (string, error) { dir = "." } - if base != "" { - dir = filepath.Join(base, dir) - } - entries, err := os.ReadDir(dir) if errors.Is(err, fs.ErrNotExist) { return fmt.Sprintf("directory does not exist: %s", params.Dir), nil @@ -772,7 +760,7 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err return "", fmt.Errorf("failed copying data from [%s] to [%s]: %w", params.URL, params.Location, err) } - return params.Location, nil + return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil } func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) { diff --git a/pkg/builtin/builtin_test.go b/pkg/builtin/builtin_test.go index f72d5bea..e417ee04 100644 --- a/pkg/builtin/builtin_test.go +++ b/pkg/builtin/builtin_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/gptscript-ai/gptscript/pkg/types" "github.com/hexops/autogold/v2" "github.com/stretchr/testify/require" ) @@ -21,3 +22,10 @@ func TestSysGetenv(t *testing.T) { require.NoError(t, err) autogold.Expect("").Equal(t, v) } + +func TestDisplayCoverage(t *testing.T) { + for _, tool := range ListTools() { + _, err := types.ToSysDisplayString(tool.ID, nil) + require.NoError(t, err) + } +} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 501cbe65..14647c0b 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -64,6 +64,7 @@ type CallContext struct { commonContext `json:",inline"` ToolName string `json:"toolName,omitempty"` ParentID string `json:"parentID,omitempty"` + DisplayText string `json:"displayText,omitempty"` } type Context struct { @@ -72,6 +73,8 @@ type Context struct { Parent *Context LastReturn *Return Program *types.Program + // Input is saved only so that we can render display text, don't use otherwise + Input string } type ChatHistory struct { @@ -123,6 +126,7 @@ func (c *Context) GetCallContext() *CallContext { commonContext: c.commonContext, ParentID: c.ParentID(), ToolName: toolName, + DisplayText: types.ToDisplayText(c.Tool, c.Input), } } @@ -140,7 +144,7 @@ func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Co return context.WithValue(ctx, toolCategoryKey{}, toolCategory) } -func NewContext(ctx context.Context, prg *types.Program) Context { +func NewContext(ctx context.Context, prg *types.Program, input string) Context { category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory) callCtx := Context{ @@ -151,11 +155,12 @@ func NewContext(ctx context.Context, prg *types.Program) Context { }, Ctx: ctx, Program: prg, + Input: input, } return callCtx } -func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCategory ToolCategory) (Context, error) { +func (c *Context) SubCall(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) { tool, ok := c.Program.ToolSet[toolID] if !ok { return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID) @@ -174,6 +179,7 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCatego Ctx: ctx, Parent: c, Program: c.Program, + Input: input, }, nil } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index a4877472..fbcbdfa5 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -150,7 +150,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra monitor.Stop(resp.Content, err) }() - callCtx := engine.NewContext(ctx, &prg) + callCtx := engine.NewContext(ctx, &prg, input) if state == nil || state.StartContinuation { if state != nil { state = state.WithResumeInput(&input) @@ -423,18 +423,21 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause) - authResp, err := r.auth(callCtx, input) - if err != nil { - return nil, err - } + _, safe := builtin.SafeTools[callCtx.Tool.ID] + if callCtx.Tool.IsCommand() && !safe { + authResp, err := r.auth(callCtx, input) + if err != nil { + return nil, err + } - if !authResp.Accept { - msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message) - return &State{ - Continuation: &engine.Return{ - Result: &msg, - }, - }, nil + if !authResp.Accept { + msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message) + return &State{ + Continuation: &engine.Return{ + Result: &msg, + }, + }, nil + } } ret, err := e.Start(callCtx, input) @@ -671,7 +674,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp } func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, toolCategory engine.ToolCategory) (*State, error) { - callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory) + callCtx, err := parentContext.SubCall(ctx, input, toolID, callID, toolCategory) if err != nil { return nil, err } @@ -680,7 +683,7 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni } func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) { - callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory) + callCtx, err := parentContext.SubCall(ctx, "", toolID, callID, toolCategory) if err != nil { return nil, err } @@ -834,7 +837,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env return nil, fmt.Errorf("failed to find ID for tool %s", credToolName) } - subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine + subCtx, err := callCtx.SubCall(callCtx.Ctx, "", credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine if err != nil { return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err) } diff --git a/pkg/types/tool.go b/pkg/types/tool.go index c4d20303..612543fe 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -3,6 +3,7 @@ package types import ( "context" "fmt" + "path/filepath" "slices" "sort" "strings" @@ -453,6 +454,20 @@ func (t ToolSource) String() string { return fmt.Sprintf("%s:%d", t.Location, t.LineNo) } +func (t Tool) GetInterpreter() string { + if !strings.HasPrefix(t.Instructions, CommandPrefix) { + return "" + } + fields := strings.Fields(strings.TrimPrefix(t.Instructions, CommandPrefix)) + for _, field := range fields { + name := filepath.Base(field) + if name != "env" { + return name + } + } + return fields[0] +} + func (t Tool) IsCommand() bool { return strings.HasPrefix(t.Instructions, CommandPrefix) } diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go new file mode 100644 index 00000000..3bd6fd57 --- /dev/null +++ b/pkg/types/toolstring.go @@ -0,0 +1,82 @@ +package types + +import ( + "encoding/json" + "fmt" + "path/filepath" + "strings" +) + +func ToDisplayText(tool Tool, input string) string { + interpreter := tool.GetInterpreter() + if interpreter == "" { + return "" + } + + if strings.HasPrefix(interpreter, "sys.") { + data := map[string]string{} + _ = json.Unmarshal([]byte(input), &data) + out, err := ToSysDisplayString(interpreter, data) + if err != nil { + return fmt.Sprintf("Running %s", interpreter) + } + return out + } + + if tool.Source.Repo != nil { + repo := tool.Source.Repo + root := strings.TrimPrefix(repo.Root, "https://") + root = strings.TrimSuffix(root, ".git") + name := repo.Name + if name == "tool.gpt" { + name = "" + } + + return fmt.Sprintf("Running %s from %s", tool.Name, filepath.Join(root, repo.Path, name)) + } + + if tool.Source.Location != "" { + return fmt.Sprintf("Running %s from %s", tool.Name, tool.Source.Location) + } + + return "" +} + +func ToSysDisplayString(id string, args map[string]string) (string, error) { + switch id { + case "sys.append": + return fmt.Sprintf("Appending to file `%s`", args["filename"]), nil + case "sys.download": + if location := args["location"]; location != "" { + return fmt.Sprintf("Downloading `%s` to `%s`", args["url"], location), nil + } else { + return fmt.Sprintf("Downloading `%s` to workspace", args["url"]), nil + } + case "sys.exec": + return fmt.Sprintf("Running `%s`", args["command"]), nil + case "sys.find": + dir := args["directory"] + if dir == "" { + dir = "." + } + return fmt.Sprintf("Finding `%s` in `%s`", args["pattern"], dir), nil + case "sys.http.get": + return fmt.Sprintf("Downloading `%s`", args["url"]), nil + case "sys.http.post": + return fmt.Sprintf("Sending to `%s`", args["url"]), nil + case "sys.http.html2text": + return fmt.Sprintf("Downloading `%s`", args["url"]), nil + case "sys.ls": + return fmt.Sprintf("Listing `%s`", args["dir"]), nil + case "sys.read": + return fmt.Sprintf("Reading `%s`", args["filename"]), nil + case "sys.remove": + return fmt.Sprintf("Removing `%s`", args["location"]), nil + case "sys.write": + return fmt.Sprintf("Writing `%s`", args["filename"]), nil + case "sys.stat", "sys.getenv", "sys.abort", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now": + return "", nil + default: + return "", fmt.Errorf("unknown tool for display string: %s", id) + } +}