From 7c87df07ef9725bf0d0631f0a7eb844b79b16318 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Wed, 17 Apr 2024 15:02:11 -0400 Subject: [PATCH 1/2] fix: credentials: block output while running credential tools Signed-off-by: Grant Linville --- pkg/builtin/builtin.go | 11 +---------- pkg/engine/cmd.go | 6 +++++- pkg/engine/engine.go | 4 ++-- pkg/runner/runner.go | 9 +++++---- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 1ca10067..668faa9e 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -21,7 +21,6 @@ import ( "github.com/BurntSushi/locker" "github.com/google/shlex" "github.com/gptscript-ai/gptscript/pkg/confirm" - "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/jaytaylor/html2text" ) @@ -647,15 +646,7 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err return params.Location, nil } -func SysPrompt(ctx context.Context, _ []string, input string) (_ string, err error) { - monitor := ctx.Value(runner.MonitorKey{}) - if monitor == nil { - return "", errors.New("no monitor in context") - } - - unpause := monitor.(runner.Monitor).Pause() - defer unpause() - +func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error) { var params struct { Message string `json:"message,omitempty"` Fields string `json:"fields,omitempty"` diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index da7b3bca..a1a60dbd 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -20,7 +20,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/version" ) -func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string) (cmdOut string, cmdErr error) { +func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, pauseF func() func()) (cmdOut string, cmdErr error) { id := fmt.Sprint(atomic.AddInt64(&completionID, 1)) defer func() { @@ -64,6 +64,10 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string) cmd.Stderr = io.MultiWriter(all, os.Stderr) cmd.Stdout = io.MultiWriter(all, output) + if pauseF != nil { + unpauseF := pauseF() + defer unpauseF() + } if err := cmd.Run(); err != nil { _, _ = os.Stderr.Write(output.Bytes()) log.Errorf("failed to run tool [%s] cmd %v: %v", tool.Parameters.Name, cmd.Args, err) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 09c037a1..03543cb9 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -137,7 +137,7 @@ func (c *Context) WrappedContext() context.Context { return context.WithValue(c.Ctx, engineContext{}, c) } -func (e *Engine) Start(ctx Context, input string) (*Return, error) { +func (e *Engine) Start(ctx Context, input string, pauseF func() func()) (*Return, error) { tool := ctx.Tool if tool.IsCommand() { @@ -148,7 +148,7 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) { } else if tool.IsOpenAPI() { return e.runOpenAPI(tool, input) } - s, err := e.runCommand(ctx.WrappedContext(), tool, input) + s, err := e.runCommand(ctx.WrappedContext(), tool, input, pauseF) if err != nil { return nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index bb665c30..6c841ccc 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -178,12 +178,13 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp Content: input, }) - // The sys.prompt tool is a special case where we need to pass the monitor to the builtin function. - if callCtx.Tool.BuiltinFunc != nil && callCtx.Tool.ID == "sys.prompt" { - callCtx.Ctx = context.WithValue(callCtx.Ctx, MonitorKey{}, monitor) + var result *engine.Return + if callCtx.IsCredential { + result, err = e.Start(callCtx, input, monitor.Pause) + } else { + result, err = e.Start(callCtx, input, nil) } - result, err := e.Start(callCtx, input) if err != nil { return "", err } From 6030b323c19d6d71fa924b204eff802b8da3d890 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Wed, 17 Apr 2024 17:06:45 -0400 Subject: [PATCH 2/2] do this but better Signed-off-by: Grant Linville --- pkg/context/context.go | 19 +++++++++++++++++++ pkg/engine/cmd.go | 11 +++++++---- pkg/engine/engine.go | 4 ++-- pkg/runner/runner.go | 11 +++-------- 4 files changed, 31 insertions(+), 14 deletions(-) create mode 100644 pkg/context/context.go diff --git a/pkg/context/context.go b/pkg/context/context.go new file mode 100644 index 00000000..3ec88c21 --- /dev/null +++ b/pkg/context/context.go @@ -0,0 +1,19 @@ +package context + +import ( + "context" +) + +type pauseKey struct{} + +func AddPauseFuncToCtx(ctx context.Context, pauseF func() func()) context.Context { + return context.WithValue(ctx, pauseKey{}, pauseF) +} + +func GetPauseFuncFromCtx(ctx context.Context) func() func() { + pauseF, ok := ctx.Value(pauseKey{}).(func() func()) + if !ok { + return func() func() { return func() {} } + } + return pauseF +} diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index a1a60dbd..8ceab4a3 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -15,12 +15,13 @@ import ( "sync/atomic" "github.com/google/shlex" + context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) -func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, pauseF func() func()) (cmdOut string, cmdErr error) { +func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, isCredential bool) (cmdOut string, cmdErr error) { id := fmt.Sprint(atomic.AddInt64(&completionID, 1)) defer func() { @@ -64,10 +65,12 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, cmd.Stderr = io.MultiWriter(all, os.Stderr) cmd.Stdout = io.MultiWriter(all, output) - if pauseF != nil { - unpauseF := pauseF() - defer unpauseF() + if isCredential { + pause := context2.GetPauseFuncFromCtx(ctx) + unpause := pause() + defer unpause() } + if err := cmd.Run(); err != nil { _, _ = os.Stderr.Write(output.Bytes()) log.Errorf("failed to run tool [%s] cmd %v: %v", tool.Parameters.Name, cmd.Args, err) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 03543cb9..db585e2d 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -137,7 +137,7 @@ func (c *Context) WrappedContext() context.Context { return context.WithValue(c.Ctx, engineContext{}, c) } -func (e *Engine) Start(ctx Context, input string, pauseF func() func()) (*Return, error) { +func (e *Engine) Start(ctx Context, input string) (*Return, error) { tool := ctx.Tool if tool.IsCommand() { @@ -148,7 +148,7 @@ func (e *Engine) Start(ctx Context, input string, pauseF func() func()) (*Return } else if tool.IsOpenAPI() { return e.runOpenAPI(tool, input) } - s, err := e.runCommand(ctx.WrappedContext(), tool, input, pauseF) + s, err := e.runCommand(ctx.WrappedContext(), tool, input, ctx.IsCredential) if err != nil { return nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 6c841ccc..7869515c 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gptscript-ai/gptscript/pkg/config" + context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/types" @@ -27,8 +28,6 @@ type Monitor interface { Stop(output string, err error) } -type MonitorKey struct{} - type Options struct { MonitorFactory MonitorFactory `usage:"-"` RuntimeManager engine.RuntimeManager `usage:"-"` @@ -178,13 +177,9 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp Content: input, }) - var result *engine.Return - if callCtx.IsCredential { - result, err = e.Start(callCtx, input, monitor.Pause) - } else { - result, err = e.Start(callCtx, input, nil) - } + callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause) + result, err := e.Start(callCtx, input) if err != nil { return "", err }