Skip to content

feat: add display text to callframe to make it easier on the sdk clients #395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ import (

"github.com/AlecAivazis/survey/v2"
"github.com/BurntSushi/locker"
"github.com/google/shlex"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/jaytaylor/html2text"
)

var SafeTools = map[string]struct{}{
"sys.echo": {},
"sys.time.now": {},
"sys.prompt": {},
"sys.abort": {},
"sys.chat.finish": {},
"sys.chat.history": {},
"sys.echo": {},
"sys.prompt": {},
"sys.time.now": {},
}

var tools = map[string]types.Tool{
Expand Down Expand Up @@ -333,11 +333,7 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
var cmd *exec.Cmd

if runtime.GOOS == "windows" {
args, err := shlex.Split(params.Command)
if err != nil {
return "", fmt.Errorf("parsing command: %w", err)
}
cmd = exec.Command(args[0], args[1:]...)
cmd = exec.Command("cmd.exe", "/c", params.Command)
} else {
cmd = exec.Command("/bin/sh", "-c", params.Command)
}
Expand All @@ -346,7 +342,7 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
cmd.Dir = params.Directory
out, err := cmd.CombinedOutput()
if err != nil {
return string(out), fmt.Errorf("OUTPUT: %s, ERROR: %w", out, err)
return fmt.Sprintf("ERROR: %s\nOUTPUT: %s", err, out), nil
}
return string(out), nil
}
Expand All @@ -362,10 +358,6 @@ func getWorkspaceDir(envs []string) (string, error) {
}

func SysLs(_ context.Context, _ []string, input string) (string, error) {
return sysLs("", input)
}

func sysLs(base, input string) (string, error) {
var params struct {
Dir string `json:"dir,omitempty"`
}
Expand All @@ -378,10 +370,6 @@ func sysLs(base, input string) (string, error) {
dir = "."
}

if base != "" {
dir = filepath.Join(base, dir)
}

entries, err := os.ReadDir(dir)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Sprintf("directory does not exist: %s", params.Dir), nil
Expand Down Expand Up @@ -772,7 +760,7 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
return "", fmt.Errorf("failed copying data from [%s] to [%s]: %w", params.URL, params.Location, err)
}

return params.Location, nil
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
}

func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) {
Expand Down
8 changes: 8 additions & 0 deletions pkg/builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"testing"

"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/require"
)
Expand All @@ -21,3 +22,10 @@ func TestSysGetenv(t *testing.T) {
require.NoError(t, err)
autogold.Expect("").Equal(t, v)
}

func TestDisplayCoverage(t *testing.T) {
for _, tool := range ListTools() {
_, err := types.ToSysDisplayString(tool.ID, nil)
require.NoError(t, err)
}
}
10 changes: 8 additions & 2 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type CallContext struct {
commonContext `json:",inline"`
ToolName string `json:"toolName,omitempty"`
ParentID string `json:"parentID,omitempty"`
DisplayText string `json:"displayText,omitempty"`
}

type Context struct {
Expand All @@ -72,6 +73,8 @@ type Context struct {
Parent *Context
LastReturn *Return
Program *types.Program
// Input is saved only so that we can render display text, don't use otherwise
Input string
}

type ChatHistory struct {
Expand Down Expand Up @@ -123,6 +126,7 @@ func (c *Context) GetCallContext() *CallContext {
commonContext: c.commonContext,
ParentID: c.ParentID(),
ToolName: toolName,
DisplayText: types.ToDisplayText(c.Tool, c.Input),
}
}

Expand All @@ -140,7 +144,7 @@ func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Co
return context.WithValue(ctx, toolCategoryKey{}, toolCategory)
}

func NewContext(ctx context.Context, prg *types.Program) Context {
func NewContext(ctx context.Context, prg *types.Program, input string) Context {
category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory)

callCtx := Context{
Expand All @@ -151,11 +155,12 @@ func NewContext(ctx context.Context, prg *types.Program) Context {
},
Ctx: ctx,
Program: prg,
Input: input,
}
return callCtx
}

func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCategory ToolCategory) (Context, error) {
func (c *Context) SubCall(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
tool, ok := c.Program.ToolSet[toolID]
if !ok {
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)
Expand All @@ -174,6 +179,7 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCatego
Ctx: ctx,
Parent: c,
Program: c.Program,
Input: input,
}, nil
}

Expand Down
33 changes: 18 additions & 15 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
monitor.Stop(resp.Content, err)
}()

callCtx := engine.NewContext(ctx, &prg)
callCtx := engine.NewContext(ctx, &prg, input)
if state == nil || state.StartContinuation {
if state != nil {
state = state.WithResumeInput(&input)
Expand Down Expand Up @@ -423,18 +423,21 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en

callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)

authResp, err := r.auth(callCtx, input)
if err != nil {
return nil, err
}
_, safe := builtin.SafeTools[callCtx.Tool.ID]
if callCtx.Tool.IsCommand() && !safe {
authResp, err := r.auth(callCtx, input)
if err != nil {
return nil, err
}

if !authResp.Accept {
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
return &State{
Continuation: &engine.Return{
Result: &msg,
},
}, nil
if !authResp.Accept {
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
return &State{
Continuation: &engine.Return{
Result: &msg,
},
}, nil
}
}

ret, err := e.Start(callCtx, input)
Expand Down Expand Up @@ -671,7 +674,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
}

func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory)
callCtx, err := parentContext.SubCall(ctx, input, toolID, callID, toolCategory)
if err != nil {
return nil, err
}
Expand All @@ -680,7 +683,7 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni
}

func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory)
callCtx, err := parentContext.SubCall(ctx, "", toolID, callID, toolCategory)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -834,7 +837,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
}

subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
subCtx, err := callCtx.SubCall(callCtx.Ctx, "", credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
if err != nil {
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package types
import (
"context"
"fmt"
"path/filepath"
"slices"
"sort"
"strings"
Expand Down Expand Up @@ -453,6 +454,20 @@ func (t ToolSource) String() string {
return fmt.Sprintf("%s:%d", t.Location, t.LineNo)
}

func (t Tool) GetInterpreter() string {
if !strings.HasPrefix(t.Instructions, CommandPrefix) {
return ""
}
fields := strings.Fields(strings.TrimPrefix(t.Instructions, CommandPrefix))
for _, field := range fields {
name := filepath.Base(field)
if name != "env" {
return name
}
}
return fields[0]
}

func (t Tool) IsCommand() bool {
return strings.HasPrefix(t.Instructions, CommandPrefix)
}
Expand Down
82 changes: 82 additions & 0 deletions pkg/types/toolstring.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package types

import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
)

func ToDisplayText(tool Tool, input string) string {
interpreter := tool.GetInterpreter()
if interpreter == "" {
return ""
}

if strings.HasPrefix(interpreter, "sys.") {
data := map[string]string{}
_ = json.Unmarshal([]byte(input), &data)
out, err := ToSysDisplayString(interpreter, data)
if err != nil {
return fmt.Sprintf("Running %s", interpreter)
}
return out
}

if tool.Source.Repo != nil {
repo := tool.Source.Repo
root := strings.TrimPrefix(repo.Root, "https://")
root = strings.TrimSuffix(root, ".git")
name := repo.Name
if name == "tool.gpt" {
name = ""
}

return fmt.Sprintf("Running %s from %s", tool.Name, filepath.Join(root, repo.Path, name))
}

if tool.Source.Location != "" {
return fmt.Sprintf("Running %s from %s", tool.Name, tool.Source.Location)
}

return ""
}

func ToSysDisplayString(id string, args map[string]string) (string, error) {
switch id {
case "sys.append":
return fmt.Sprintf("Appending to file `%s`", args["filename"]), nil
case "sys.download":
if location := args["location"]; location != "" {
return fmt.Sprintf("Downloading `%s` to `%s`", args["url"], location), nil
} else {
return fmt.Sprintf("Downloading `%s` to workspace", args["url"]), nil
}
case "sys.exec":
return fmt.Sprintf("Running `%s`", args["command"]), nil
case "sys.find":
dir := args["directory"]
if dir == "" {
dir = "."
}
return fmt.Sprintf("Finding `%s` in `%s`", args["pattern"], dir), nil
case "sys.http.get":
return fmt.Sprintf("Downloading `%s`", args["url"]), nil
case "sys.http.post":
return fmt.Sprintf("Sending to `%s`", args["url"]), nil
case "sys.http.html2text":
return fmt.Sprintf("Downloading `%s`", args["url"]), nil
case "sys.ls":
return fmt.Sprintf("Listing `%s`", args["dir"]), nil
case "sys.read":
return fmt.Sprintf("Reading `%s`", args["filename"]), nil
case "sys.remove":
return fmt.Sprintf("Removing `%s`", args["location"]), nil
case "sys.write":
return fmt.Sprintf("Writing `%s`", args["filename"]), nil
case "sys.stat", "sys.getenv", "sys.abort", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now":
return "", nil
default:
return "", fmt.Errorf("unknown tool for display string: %s", id)
}
}