diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 00000000..8f8d4544 --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,64 @@ +package auth + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/AlecAivazis/survey/v2" + "github.com/gptscript-ai/gptscript/pkg/builtin" + "github.com/gptscript-ai/gptscript/pkg/context" + "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/runner" +) + +func Authorize(ctx engine.Context, input string) (runner.AuthorizerResponse, error) { + defer context.GetPauseFuncFromCtx(ctx.Ctx)()() + + if !ctx.Tool.IsCommand() { + return runner.AuthorizerResponse{ + Accept: true, + }, nil + } + + var ( + result bool + loc = ctx.Tool.Source.Location + interpreter = strings.Split(ctx.Tool.Instructions, "\n")[0][2:] + ) + + if _, ok := builtin.SafeTools[interpreter]; ok { + return runner.AuthorizerResponse{ + Accept: true, + }, nil + } + + if ctx.Tool.Source.Repo != nil { + loc = ctx.Tool.Source.Repo.Root + loc = strings.TrimPrefix(loc, "https://") + loc = strings.TrimSuffix(loc, ".git") + loc = filepath.Join(loc, ctx.Tool.Source.Repo.Path, ctx.Tool.Source.Repo.Name) + } + + if ctx.Tool.BuiltinFunc != nil { + loc = "Builtin" + } + + err := survey.AskOne(&survey.Confirm{ + Help: fmt.Sprintf("The full source of the tools is as follows:\n\n%s", ctx.Tool.String()), + Default: true, + Message: fmt.Sprintf(`Description: %s + Interpreter: %s + Source: %s + Input: %s +Allow the above tool to execute?`, ctx.Tool.Description, interpreter, loc, strings.TrimSpace(input)), + }, &result) + if err != nil { + return runner.AuthorizerResponse{}, err + } + + return runner.AuthorizerResponse{ + Accept: result, + Message: "Request denied, blocking execution.", + }, nil +} diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 592f873f..5eb7be3c 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -20,11 +20,17 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/BurntSushi/locker" "github.com/google/shlex" - "github.com/gptscript-ai/gptscript/pkg/confirm" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/jaytaylor/html2text" ) +var SafeTools = map[string]struct{}{ + "sys.echo": {}, + "sys.time.now": {}, + "sys.prompt": {}, + "sys.chat.finish": {}, +} + var tools = map[string]types.Tool{ "sys.time.now": { Parameters: types.Parameters{ @@ -278,10 +284,6 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) { log.Debugf("Running %s in %s", params.Command, params.Directory) - if err := confirm.Promptf(ctx, "Run command: %s", params.Command); err != nil { - return "", err - } - var cmd *exec.Cmd if runtime.GOOS == "windows" { @@ -404,12 +406,6 @@ func SysWrite(ctx context.Context, _ []string, input string) (string, error) { } } - if _, err := os.Stat(file); err == nil { - if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil { - return "", err - } - } - data := []byte(params.Content) log.Debugf("Wrote %d bytes to file %s", len(data), file) @@ -429,12 +425,6 @@ func SysAppend(ctx context.Context, env []string, input string) (string, error) locker.Lock(params.Filename) defer locker.Unlock(params.Filename) - if _, err := os.Stat(params.Filename); err == nil { - if err := confirm.Promptf(ctx, "Write to existing file: %s.", params.Filename); err != nil { - return "", err - } - } - f, err := os.OpenFile(params.Filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return "", err @@ -609,10 +599,6 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error) return "", err } - if err := confirm.Promptf(ctx, "Remove: %s", params.Location); err != nil { - return "", err - } - // Lock the file to prevent concurrent writes from other tool calls. locker.Lock(params.Location) defer locker.Unlock(params.Location) diff --git a/pkg/cli/eval.go b/pkg/cli/eval.go index a51b7f3b..addab89d 100644 --- a/pkg/cli/eval.go +++ b/pkg/cli/eval.go @@ -72,12 +72,12 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error { } if e.Chat { - return chat.Start(e.gptscript.NewRunContext(cmd), nil, runner, func() (types.Program, error) { + return chat.Start(cmd.Context(), nil, runner, func() (types.Program, error) { return prg, nil }, os.Environ(), toolInput) } - toolOutput, err := runner.Run(e.gptscript.NewRunContext(cmd), prg, os.Environ(), toolInput) + toolOutput, err := runner.Run(cmd.Context(), prg, os.Environ(), toolInput) if err != nil { return err } diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 6fbc0bc2..38502ceb 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -13,10 +13,10 @@ import ( "github.com/acorn-io/cmd" "github.com/fatih/color" "github.com/gptscript-ai/gptscript/pkg/assemble" + "github.com/gptscript-ai/gptscript/pkg/auth" "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/chat" - "github.com/gptscript-ai/gptscript/pkg/confirm" "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/input" "github.com/gptscript-ai/gptscript/pkg/loader" @@ -117,14 +117,6 @@ func New() *cobra.Command { return command } -func (r *GPTScript) NewRunContext(cmd *cobra.Command) context.Context { - ctx := cmd.Context() - if r.Confirm { - ctx = confirm.WithConfirm(ctx, confirm.TextPrompt{}) - } - return ctx -} - func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) { opts := gptscript.Options{ Cache: cache.Options(r.CacheOptions), @@ -140,6 +132,10 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) { Workspace: r.Workspace, } + if r.Confirm { + opts.Runner.Authorizer = auth.Authorize + } + if r.Ports != "" { start, end, _ := strings.Cut(r.Ports, "-") startNum, err := strconv.ParseInt(strings.TrimSpace(start), 10, 64) @@ -388,7 +384,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { } if r.ChatState != "" { - resp, err := gptScript.Chat(r.NewRunContext(cmd), r.ChatState, prg, os.Environ(), toolInput) + resp, err := gptScript.Chat(cmd.Context(), r.ChatState, prg, os.Environ(), toolInput) if err != nil { return err } @@ -400,12 +396,12 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { } if prg.IsChat() || r.ForceChat { - return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) { + return chat.Start(cmd.Context(), nil, gptScript, func() (types.Program, error) { return r.readProgram(ctx, gptScript, args) }, os.Environ(), toolInput) } - s, err := gptScript.Run(r.NewRunContext(cmd), prg, os.Environ(), toolInput) + s, err := gptScript.Run(cmd.Context(), prg, os.Environ(), toolInput) if err != nil { return err } diff --git a/pkg/confirm/confirm.go b/pkg/confirm/confirm.go deleted file mode 100644 index a494aded..00000000 --- a/pkg/confirm/confirm.go +++ /dev/null @@ -1,45 +0,0 @@ -package confirm - -import ( - "context" - "errors" - "fmt" - - "github.com/AlecAivazis/survey/v2" -) - -type Confirm interface { - Confirm(ctx context.Context, prompt string) error -} - -type confirmer struct{} - -func WithConfirm(ctx context.Context, c Confirm) context.Context { - return context.WithValue(ctx, confirmer{}, c) -} - -func Promptf(ctx context.Context, fmtString string, args ...any) error { - c, ok := ctx.Value(confirmer{}).(Confirm) - if !ok { - return nil - } - return c.Confirm(ctx, fmt.Sprintf(fmtString, args...)) -} - -type TextPrompt struct { -} - -func (t TextPrompt) Confirm(_ context.Context, prompt string) error { - var result bool - err := survey.AskOne(&survey.Confirm{ - Message: prompt, - Default: false, - }, &result) - if err != nil { - return err - } - if !result { - return errors.New("abort") - } - return nil -} diff --git a/pkg/mvl/log.go b/pkg/mvl/log.go index f35e023e..1523982f 100644 --- a/pkg/mvl/log.go +++ b/pkg/mvl/log.go @@ -43,7 +43,7 @@ func (f formatter) Format(entry *logrus.Entry) ([]byte, error) { } d, _ := json.Marshal(i) i = string(d) - i = strings.TrimSpace(i[1 : len(i)-2]) + i = strings.TrimSpace(i[1 : len(i)-1]) if addDot { i += "..." } diff --git a/pkg/openai/client.go b/pkg/openai/client.go index c1787497..486d58c3 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -435,7 +435,9 @@ func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionSt tc.ToolCall.Index = tool.Index } tc.ToolCall.ID = override(tc.ToolCall.ID, tool.ID) - tc.ToolCall.Function.Name += tool.Function.Name + if tc.ToolCall.Function.Name != tool.Function.Name { + tc.ToolCall.Function.Name += tool.Function.Name + } tc.ToolCall.Function.Arguments += tool.Function.Arguments msg.Content[idx] = tc diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index d4a5516b..42966493 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -36,6 +36,20 @@ type Options struct { EndPort int64 `usage:"-"` CredentialOverride string `usage:"-"` Sequential bool `usage:"-"` + Authorizer AuthorizerFunc `usage:"-"` +} + +type AuthorizerResponse struct { + Accept bool + Message string +} + +type AuthorizerFunc func(ctx engine.Context, input string) (AuthorizerResponse, error) + +func DefaultAuthorizer(_ engine.Context, _ string) (AuthorizerResponse, error) { + return AuthorizerResponse{ + Accept: true, + }, nil } func complete(opts ...Options) (result Options) { @@ -46,6 +60,9 @@ func complete(opts ...Options) (result Options) { result.EndPort = types.FirstSet(opt.EndPort, result.EndPort) result.CredentialOverride = types.FirstSet(opt.CredentialOverride, result.CredentialOverride) result.Sequential = types.FirstSet(opt.Sequential, result.Sequential) + if opt.Authorizer != nil { + result.Authorizer = opt.Authorizer + } } if result.MonitorFactory == nil { result.MonitorFactory = noopFactory{} @@ -56,11 +73,15 @@ func complete(opts ...Options) (result Options) { if result.StartPort == 0 { result.StartPort = result.EndPort } + if result.Authorizer == nil { + result.Authorizer = DefaultAuthorizer + } return } type Runner struct { c engine.Model + auth AuthorizerFunc factory MonitorFactory runtimeManager engine.RuntimeManager ports engine.Ports @@ -81,6 +102,7 @@ func New(client engine.Model, credCtx string, opts ...Options) (*Runner, error) credMutex: sync.Mutex{}, credOverrides: opt.CredentialOverride, sequential: opt.Sequential, + auth: opt.Authorizer, } if opt.StartPort != 0 { @@ -405,6 +427,20 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause) + authResp, err := r.auth(callCtx, input) + if err != nil { + return nil, err + } + + if !authResp.Accept { + msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message) + return &State{ + Continuation: &engine.Return{ + Result: &msg, + }, + }, nil + } + ret, err := e.Start(callCtx, input) if err != nil { return nil, err