Skip to content

chore: add authorization hook #380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package auth

import (
"fmt"
"path/filepath"
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/runner"
)

func Authorize(ctx engine.Context, input string) (runner.AuthorizerResponse, error) {
defer context.GetPauseFuncFromCtx(ctx.Ctx)()()

if !ctx.Tool.IsCommand() {
return runner.AuthorizerResponse{
Accept: true,
}, nil
}

var (
result bool
loc = ctx.Tool.Source.Location
interpreter = strings.Split(ctx.Tool.Instructions, "\n")[0][2:]
)

if _, ok := builtin.SafeTools[interpreter]; ok {
return runner.AuthorizerResponse{
Accept: true,
}, nil
}

if ctx.Tool.Source.Repo != nil {
loc = ctx.Tool.Source.Repo.Root
loc = strings.TrimPrefix(loc, "https://")
loc = strings.TrimSuffix(loc, ".git")
loc = filepath.Join(loc, ctx.Tool.Source.Repo.Path, ctx.Tool.Source.Repo.Name)
}

if ctx.Tool.BuiltinFunc != nil {
loc = "Builtin"
}

err := survey.AskOne(&survey.Confirm{
Help: fmt.Sprintf("The full source of the tools is as follows:\n\n%s", ctx.Tool.String()),
Default: true,
Message: fmt.Sprintf(`Description: %s
Interpreter: %s
Source: %s
Input: %s
Allow the above tool to execute?`, ctx.Tool.Description, interpreter, loc, strings.TrimSpace(input)),
}, &result)
if err != nil {
return runner.AuthorizerResponse{}, err
}

return runner.AuthorizerResponse{
Accept: result,
Message: "Request denied, blocking execution.",
}, nil
}
28 changes: 7 additions & 21 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ import (
"github.com/AlecAivazis/survey/v2"
"github.com/BurntSushi/locker"
"github.com/google/shlex"
"github.com/gptscript-ai/gptscript/pkg/confirm"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/jaytaylor/html2text"
)

var SafeTools = map[string]struct{}{
"sys.echo": {},
"sys.time.now": {},
"sys.prompt": {},
"sys.chat.finish": {},
}

var tools = map[string]types.Tool{
"sys.time.now": {
Parameters: types.Parameters{
Expand Down Expand Up @@ -278,10 +284,6 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {

log.Debugf("Running %s in %s", params.Command, params.Directory)

if err := confirm.Promptf(ctx, "Run command: %s", params.Command); err != nil {
return "", err
}

var cmd *exec.Cmd

if runtime.GOOS == "windows" {
Expand Down Expand Up @@ -404,12 +406,6 @@ func SysWrite(ctx context.Context, _ []string, input string) (string, error) {
}
}

if _, err := os.Stat(file); err == nil {
if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil {
return "", err
}
}

data := []byte(params.Content)
log.Debugf("Wrote %d bytes to file %s", len(data), file)

Expand All @@ -429,12 +425,6 @@ func SysAppend(ctx context.Context, env []string, input string) (string, error)
locker.Lock(params.Filename)
defer locker.Unlock(params.Filename)

if _, err := os.Stat(params.Filename); err == nil {
if err := confirm.Promptf(ctx, "Write to existing file: %s.", params.Filename); err != nil {
return "", err
}
}

f, err := os.OpenFile(params.Filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return "", err
Expand Down Expand Up @@ -609,10 +599,6 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error)
return "", err
}

if err := confirm.Promptf(ctx, "Remove: %s", params.Location); err != nil {
return "", err
}

// Lock the file to prevent concurrent writes from other tool calls.
locker.Lock(params.Location)
defer locker.Unlock(params.Location)
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error {
}

if e.Chat {
return chat.Start(e.gptscript.NewRunContext(cmd), nil, runner, func() (types.Program, error) {
return chat.Start(cmd.Context(), nil, runner, func() (types.Program, error) {
return prg, nil
}, os.Environ(), toolInput)
}

toolOutput, err := runner.Run(e.gptscript.NewRunContext(cmd), prg, os.Environ(), toolInput)
toolOutput, err := runner.Run(cmd.Context(), prg, os.Environ(), toolInput)
if err != nil {
return err
}
Expand Down
20 changes: 8 additions & 12 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import (
"github.com/acorn-io/cmd"
"github.com/fatih/color"
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/auth"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/chat"
"github.com/gptscript-ai/gptscript/pkg/confirm"
"github.com/gptscript-ai/gptscript/pkg/gptscript"
"github.com/gptscript-ai/gptscript/pkg/input"
"github.com/gptscript-ai/gptscript/pkg/loader"
Expand Down Expand Up @@ -117,14 +117,6 @@ func New() *cobra.Command {
return command
}

func (r *GPTScript) NewRunContext(cmd *cobra.Command) context.Context {
ctx := cmd.Context()
if r.Confirm {
ctx = confirm.WithConfirm(ctx, confirm.TextPrompt{})
}
return ctx
}

func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
opts := gptscript.Options{
Cache: cache.Options(r.CacheOptions),
Expand All @@ -140,6 +132,10 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
Workspace: r.Workspace,
}

if r.Confirm {
opts.Runner.Authorizer = auth.Authorize
}

if r.Ports != "" {
start, end, _ := strings.Cut(r.Ports, "-")
startNum, err := strconv.ParseInt(strings.TrimSpace(start), 10, 64)
Expand Down Expand Up @@ -388,7 +384,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
}

if r.ChatState != "" {
resp, err := gptScript.Chat(r.NewRunContext(cmd), r.ChatState, prg, os.Environ(), toolInput)
resp, err := gptScript.Chat(cmd.Context(), r.ChatState, prg, os.Environ(), toolInput)
if err != nil {
return err
}
Expand All @@ -400,12 +396,12 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
}

if prg.IsChat() || r.ForceChat {
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
return chat.Start(cmd.Context(), nil, gptScript, func() (types.Program, error) {
return r.readProgram(ctx, gptScript, args)
}, os.Environ(), toolInput)
}

s, err := gptScript.Run(r.NewRunContext(cmd), prg, os.Environ(), toolInput)
s, err := gptScript.Run(cmd.Context(), prg, os.Environ(), toolInput)
if err != nil {
return err
}
Expand Down
45 changes: 0 additions & 45 deletions pkg/confirm/confirm.go

This file was deleted.

2 changes: 1 addition & 1 deletion pkg/mvl/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (f formatter) Format(entry *logrus.Entry) ([]byte, error) {
}
d, _ := json.Marshal(i)
i = string(d)
i = strings.TrimSpace(i[1 : len(i)-2])
i = strings.TrimSpace(i[1 : len(i)-1])
if addDot {
i += "..."
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,9 @@ func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionSt
tc.ToolCall.Index = tool.Index
}
tc.ToolCall.ID = override(tc.ToolCall.ID, tool.ID)
tc.ToolCall.Function.Name += tool.Function.Name
if tc.ToolCall.Function.Name != tool.Function.Name {
tc.ToolCall.Function.Name += tool.Function.Name
}
tc.ToolCall.Function.Arguments += tool.Function.Arguments

msg.Content[idx] = tc
Expand Down
36 changes: 36 additions & 0 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ type Options struct {
EndPort int64 `usage:"-"`
CredentialOverride string `usage:"-"`
Sequential bool `usage:"-"`
Authorizer AuthorizerFunc `usage:"-"`
}

type AuthorizerResponse struct {
Accept bool
Message string
}

type AuthorizerFunc func(ctx engine.Context, input string) (AuthorizerResponse, error)

func DefaultAuthorizer(_ engine.Context, _ string) (AuthorizerResponse, error) {
return AuthorizerResponse{
Accept: true,
}, nil
}

func complete(opts ...Options) (result Options) {
Expand All @@ -46,6 +60,9 @@ func complete(opts ...Options) (result Options) {
result.EndPort = types.FirstSet(opt.EndPort, result.EndPort)
result.CredentialOverride = types.FirstSet(opt.CredentialOverride, result.CredentialOverride)
result.Sequential = types.FirstSet(opt.Sequential, result.Sequential)
if opt.Authorizer != nil {
result.Authorizer = opt.Authorizer
}
}
if result.MonitorFactory == nil {
result.MonitorFactory = noopFactory{}
Expand All @@ -56,11 +73,15 @@ func complete(opts ...Options) (result Options) {
if result.StartPort == 0 {
result.StartPort = result.EndPort
}
if result.Authorizer == nil {
result.Authorizer = DefaultAuthorizer
}
return
}

type Runner struct {
c engine.Model
auth AuthorizerFunc
factory MonitorFactory
runtimeManager engine.RuntimeManager
ports engine.Ports
Expand All @@ -81,6 +102,7 @@ func New(client engine.Model, credCtx string, opts ...Options) (*Runner, error)
credMutex: sync.Mutex{},
credOverrides: opt.CredentialOverride,
sequential: opt.Sequential,
auth: opt.Authorizer,
}

if opt.StartPort != 0 {
Expand Down Expand Up @@ -405,6 +427,20 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en

callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)

authResp, err := r.auth(callCtx, input)
if err != nil {
return nil, err
}

if !authResp.Accept {
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
return &State{
Continuation: &engine.Return{
Result: &msg,
},
}, nil
}

ret, err := e.Start(callCtx, input)
if err != nil {
return nil, err
Expand Down