From 0f71912de2a2a682c3809f1f03935edafc4fca8d Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Mon, 3 Jun 2024 17:35:04 -0700 Subject: [PATCH] chore: move prompt to always http based --- pkg/builtin/builtin.go | 78 +--------------------------- pkg/engine/cmd.go | 7 --- pkg/gptscript/gptscript.go | 20 +++++++- pkg/monitor/display.go | 11 +++- pkg/monitor/fd.go | 4 ++ pkg/prompt/prompt.go | 102 +++++++++++++++++++++++++++++++++++++ pkg/prompt/server.go | 76 +++++++++++++++++++++++++++ pkg/runner/monitor.go | 4 ++ pkg/runner/runner.go | 1 + pkg/sdkserver/monitor.go | 4 ++ pkg/server/server.go | 4 ++ pkg/types/prompt.go | 5 +- 12 files changed, 229 insertions(+), 87 deletions(-) create mode 100644 pkg/prompt/prompt.go create mode 100644 pkg/prompt/server.go diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 8946c76b..a4aeb74e 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -1,7 +1,6 @@ package builtin import ( - "bytes" "context" "encoding/json" "errors" @@ -18,9 +17,9 @@ import ( "strings" "time" - "github.com/AlecAivazis/survey/v2" "github.com/BurntSushi/locker" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/prompt" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/jaytaylor/html2text" ) @@ -216,7 +215,7 @@ var tools = map[string]types.Tool{ "sensitive", "(true or false) Whether the input should be hidden", ), }, - BuiltinFunc: SysPrompt, + BuiltinFunc: prompt.SysPrompt, }, }, "sys.chat.history": { @@ -772,79 +771,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil } -func sysPromptHTTP(ctx context.Context, url string, prompt types.Prompt) (_ string, err error) { - data, err := json.Marshal(prompt) - if err != nil { - return "", err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) - if err != nil { - return "", err - } - req.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode) - } - - data, err = io.ReadAll(resp.Body) - return string(data), err -} - -func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) { - var params struct { - Message string `json:"message,omitempty"` - Fields string `json:"fields,omitempty"` - Sensitive string `json:"sensitive,omitempty"` - } - if err := json.Unmarshal([]byte(input), ¶ms); err != nil { - return "", err - } - - for _, env := range envs { - if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok { - httpPrompt := types.Prompt{ - Message: params.Message, - Fields: strings.Split(params.Fields, ","), - Sensitive: params.Sensitive == "true", - } - return sysPromptHTTP(ctx, url, httpPrompt) - } - } - - if params.Message != "" { - _, _ = fmt.Fprintln(os.Stderr, params.Message) - } - - results := map[string]string{} - for _, f := range strings.Split(params.Fields, ",") { - var value string - if params.Sensitive == "true" { - err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) - } else { - err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) - } - if err != nil { - return "", err - } - results[f] = value - } - - resultsStr, err := json.Marshal(results) - if err != nil { - return "", err - } - - return string(resultsStr), nil -} - func SysTimeNow(context.Context, []string, string) (string, error) { return time.Now().Format(time.RFC3339), nil } diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 680fdffd..1f775b74 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -14,7 +14,6 @@ import ( "strings" "github.com/google/shlex" - context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/types" @@ -73,12 +72,6 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate cmd.Stderr = io.MultiWriter(all, os.Stderr) cmd.Stdout = io.MultiWriter(all, output) - if toolCategory == CredentialToolCategory { - pause := context2.GetPauseFuncFromCtx(ctx.Ctx) - unpause := pause() - defer unpause() - } - if err := cmd.Run(); err != nil { if toolCategory == NoCategory { return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, all), nil diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index d079e845..a92de044 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -5,15 +5,18 @@ import ( "fmt" "os" "path/filepath" + "slices" "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" + context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/hash" "github.com/gptscript-ai/gptscript/pkg/llm" "github.com/gptscript-ai/gptscript/pkg/monitor" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/openai" + "github.com/gptscript-ai/gptscript/pkg/prompt" "github.com/gptscript-ai/gptscript/pkg/remote" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" "github.com/gptscript-ai/gptscript/pkg/runner" @@ -28,6 +31,8 @@ type GPTScript struct { Cache *cache.Client WorkspacePath string DeleteWorkspaceOnClose bool + extraEnv []string + close func() } type Options struct { @@ -96,12 +101,21 @@ func New(opts *Options) (*GPTScript, error) { return nil, err } + ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause)) + extraEnv, err := prompt.NewServer(ctx, opts.Env) + if err != nil { + closeServer() + return nil, err + } + return &GPTScript{ Registry: registry, Runner: runner, Cache: cacheClient, WorkspacePath: opts.Workspace, DeleteWorkspaceOnClose: opts.Workspace == "", + extraEnv: extraEnv, + close: closeServer, }, nil } @@ -122,10 +136,10 @@ func (g *GPTScript) getEnv(env []string) ([]string, error) { if err := os.MkdirAll(g.WorkspacePath, 0700); err != nil { return nil, err } - return append([]string{ + return slices.Concat(g.extraEnv, []string{ fmt.Sprintf("GPTSCRIPT_WORKSPACE_DIR=%s", g.WorkspacePath), fmt.Sprintf("GPTSCRIPT_WORKSPACE_ID=%s", hash.ID(g.WorkspacePath)), - }, env...), nil + }, env), nil } func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) { @@ -153,6 +167,8 @@ func (g *GPTScript) Close(closeDaemons bool) { } } + g.close() + if closeDaemons { engine.CloseDaemons() } diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index 2ddb236e..bbdfc8ee 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -38,6 +38,7 @@ type Console struct { dumpState string displayProgress bool printMessages bool + callLock sync.Mutex } var ( @@ -47,6 +48,7 @@ var ( func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) { id := counter.Next() mon := newDisplay(c.dumpState, c.displayProgress, c.printMessages) + mon.callLock = &c.callLock mon.dump.ID = fmt.Sprint(id) mon.dump.Program = prg mon.dump.Input = input @@ -55,13 +57,20 @@ func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input return mon, nil } +func (c *Console) Pause() func() { + c.callLock.Lock() + return func() { + c.callLock.Unlock() + } +} + type display struct { dump dump printMessages bool livePrinter *livePrinter dumpState string callIDMap map[string]string - callLock sync.Mutex + callLock *sync.Mutex usage types.Usage } diff --git a/pkg/monitor/fd.go b/pkg/monitor/fd.go index 08b73ed3..8cfeeede 100644 --- a/pkg/monitor/fd.go +++ b/pkg/monitor/fd.go @@ -70,6 +70,10 @@ func (s *fileFactory) Start(_ context.Context, prg *types.Program, env []string, return fd, nil } +func (s *fileFactory) Pause() func() { + return func() {} +} + func (s *fileFactory) close() { s.lock.Lock() defer s.lock.Unlock() diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go new file mode 100644 index 00000000..047a6abc --- /dev/null +++ b/pkg/prompt/prompt.go @@ -0,0 +1,102 @@ +package prompt + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + + "github.com/AlecAivazis/survey/v2" + context2 "github.com/gptscript-ai/gptscript/pkg/context" + "github.com/gptscript-ai/gptscript/pkg/types" +) + +func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.Prompt) (_ string, err error) { + data, err := json.Marshal(prompt) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + for _, env := range envs { + if _, v, ok := strings.Cut(env, types.PromptTokenEnvVar+"="); ok && v != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", v)) + break + } + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode) + } + + data, err = io.ReadAll(resp.Body) + return string(data), err +} + +func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) { + var params struct { + Message string `json:"message,omitempty"` + Fields string `json:"fields,omitempty"` + Sensitive string `json:"sensitive,omitempty"` + } + if err := json.Unmarshal([]byte(input), ¶ms); err != nil { + return "", err + } + + for _, env := range envs { + if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok { + httpPrompt := types.Prompt{ + Message: params.Message, + Fields: strings.Split(params.Fields, ","), + Sensitive: params.Sensitive == "true", + } + return sysPromptHTTP(ctx, envs, url, httpPrompt) + } + } + + return "", fmt.Errorf("no prompt server found, can not continue") +} + +func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { + defer context2.GetPauseFuncFromCtx(ctx)()() + + if req.Message != "" { + _, _ = fmt.Fprintln(os.Stderr, req.Message) + } + + results := map[string]string{} + for _, f := range req.Fields { + var value string + if req.Sensitive { + err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + } else { + err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + } + if err != nil { + return "", err + } + results[f] = value + } + + resultsStr, err := json.Marshal(results) + if err != nil { + return "", err + } + + return string(resultsStr), nil +} diff --git a/pkg/prompt/server.go b/pkg/prompt/server.go new file mode 100644 index 00000000..36fa44e3 --- /dev/null +++ b/pkg/prompt/server.go @@ -0,0 +1,76 @@ +package prompt + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net" + "net/http" + "strings" + + "github.com/google/uuid" + "github.com/gptscript-ai/gptscript/pkg/types" +) + +func NewServer(ctx context.Context, envs []string) ([]string, error) { + for _, env := range envs { + _, v, ok := strings.Cut(env, types.PromptTokenEnvVar+"=") + if ok && v != "" { + return nil, nil + } + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + + token := uuid.NewString() + s := http.Server{ + BaseContext: func(_ net.Listener) context.Context { + return ctx + }, + Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+token { + rw.WriteHeader(http.StatusUnauthorized) + _, _ = rw.Write([]byte("Unauthorized")) + return + } + + var req types.Prompt + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + rw.WriteHeader(http.StatusBadRequest) + _, _ = rw.Write([]byte(err.Error())) + return + } + + resp, err := sysPrompt(r.Context(), req) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + _, _ = rw.Write([]byte(err.Error())) + return + } + + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte(resp)) + }), + } + + context.AfterFunc(ctx, func() { + _ = s.Shutdown(context.Background()) + }) + + go func() { + if err := s.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("failed to run prompt server: %v", err) + } + }() + + return []string{ + fmt.Sprintf("%s=http://%s", types.PromptURLEnvVar, l.Addr().String()), + fmt.Sprintf("%s=%s", types.PromptTokenEnvVar, token), + }, nil +} diff --git a/pkg/runner/monitor.go b/pkg/runner/monitor.go index 87543eda..48e64b0b 100644 --- a/pkg/runner/monitor.go +++ b/pkg/runner/monitor.go @@ -13,6 +13,10 @@ func (n noopFactory) Start(context.Context, *types.Program, []string, string) (M return noopMonitor{}, nil } +func (n noopFactory) Pause() func() { + return func() {} +} + type noopMonitor struct { } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index fbcbdfa5..57ea7b38 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -21,6 +21,7 @@ import ( type MonitorFactory interface { Start(ctx context.Context, prg *types.Program, env []string, input string) (Monitor, error) + Pause() func() } type Monitor interface { diff --git a/pkg/sdkserver/monitor.go b/pkg/sdkserver/monitor.go index c0aa6090..5b06771c 100644 --- a/pkg/sdkserver/monitor.go +++ b/pkg/sdkserver/monitor.go @@ -44,6 +44,10 @@ func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []str }, nil } +func (s SessionFactory) Pause() func() { + return func() {} +} + type Session struct { id string prj *types.Program diff --git a/pkg/server/server.go b/pkg/server/server.go index d1605886..9f687783 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -324,6 +324,10 @@ func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []str }, nil } +func (s SessionFactory) Pause() func() { + return func() {} +} + type Session struct { id string prj *types.Program diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index ec38ea34..ea17c11c 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -1,6 +1,9 @@ package types -const PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL" +const ( + PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL" + PromptTokenEnvVar = "GPTSCRIPT_PROMPT_TOKEN" +) type Prompt struct { Message string `json:"message,omitempty"`