diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index bda5c33a..8946c76b 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -265,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) { return SetDefaults(t), ok } -func SysFind(ctx context.Context, env []string, input string) (string, error) { +func SysFind(_ context.Context, _ []string, input string) (string, error) { var result []string var params struct { Pattern string `json:"pattern,omitempty"` @@ -306,7 +306,7 @@ func SysFind(ctx context.Context, env []string, input string) (string, error) { return strings.Join(result, "\n"), nil } -func SysExec(ctx context.Context, env []string, input string) (string, error) { +func SysExec(_ context.Context, env []string, input string) (string, error) { var params struct { Command string `json:"command,omitempty"` Directory string `json:"directory,omitempty"` @@ -412,7 +412,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) { return string(data), nil } -func SysWrite(ctx context.Context, _ []string, input string) (string, error) { +func SysWrite(_ context.Context, _ []string, input string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` Content string `json:"content,omitempty"` @@ -444,7 +444,7 @@ func SysWrite(ctx context.Context, _ []string, input string) (string, error) { return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil } -func SysAppend(ctx context.Context, env []string, input string) (string, error) { +func SysAppend(_ context.Context, _ []string, input string) (string, error) { var params struct { Filename string `json:"filename,omitempty"` Content string `json:"content,omitempty"` @@ -490,7 +490,7 @@ func fixQueries(u string) string { return url.String() } -func SysHTTPGet(ctx context.Context, env []string, input string) (_ string, err error) { +func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) { var params struct { URL string `json:"url,omitempty"` } @@ -534,7 +534,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, }) } -func SysHTTPPost(ctx context.Context, env []string, input string) (_ string, err error) { +func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) { var params struct { URL string `json:"url,omitempty"` Content string `json:"content,omitempty"` @@ -570,7 +570,7 @@ func SysHTTPPost(ctx context.Context, env []string, input string) (_ string, err return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil } -func SysGetenv(ctx context.Context, env []string, input string) (string, error) { +func SysGetenv(_ context.Context, env []string, input string) (string, error) { var params struct { Name string `json:"name,omitempty"` } @@ -636,7 +636,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) { return } -func SysChatFinish(ctx context.Context, env []string, input string) (string, error) { +func SysChatFinish(_ context.Context, _ []string, input string) (string, error) { var params struct { Message string `json:"return,omitempty"` } @@ -650,7 +650,7 @@ func SysChatFinish(ctx context.Context, env []string, input string) (string, err } } -func SysAbort(ctx context.Context, env []string, input string) (string, error) { +func SysAbort(_ context.Context, _ []string, input string) (string, error) { var params struct { Message string `json:"message,omitempty"` } @@ -660,7 +660,7 @@ func SysAbort(ctx context.Context, env []string, input string) (string, error) { return "", fmt.Errorf("ABORT: %s", params.Message) } -func SysRemove(ctx context.Context, env []string, input string) (string, error) { +func SysRemove(_ context.Context, _ []string, input string) (string, error) { var params struct { Location string `json:"location,omitempty"` } @@ -679,7 +679,7 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error) return fmt.Sprintf("Removed file: %s", params.Location), nil } -func SysStat(ctx context.Context, env []string, input string) (string, error) { +func SysStat(_ context.Context, _ []string, input string) (string, error) { var params struct { Filepath string `json:"filepath,omitempty"` } @@ -699,7 +699,7 @@ func SysStat(ctx context.Context, env []string, input string) (string, error) { return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil } -func SysDownload(ctx context.Context, env []string, input string) (_ string, err error) { +func SysDownload(_ context.Context, env []string, input string) (_ string, err error) { var params struct { URL string `json:"url,omitempty"` Location string `json:"location,omitempty"` @@ -772,12 +772,8 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil } -func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) { - data, err := json.Marshal(map[string]any{ - "message": message, - "fields": fields, - "sensitive": sensitive, - }) +func sysPromptHTTP(ctx context.Context, url string, prompt types.Prompt) (_ string, err error) { + data, err := json.Marshal(prompt) if err != nil { return "", err } @@ -792,7 +788,7 @@ func sysPromptHTTP(ctx context.Context, url, message string, fields []string, se if err != nil { return "", err } - resp.Body.Close() + defer resp.Body.Close() if resp.StatusCode != 200 { return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode) @@ -813,8 +809,13 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err } for _, env := range envs { - if url, ok := strings.CutPrefix(env, "GPTSCRIPT_PROMPT_URL="); ok { - return sysPromptHTTP(ctx, url, params.Message, strings.Split(params.Fields, ","), params.Sensitive == "true") + if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok { + httpPrompt := types.Prompt{ + Message: params.Message, + Fields: strings.Split(params.Fields, ","), + Sensitive: params.Sensitive == "true", + } + return sysPromptHTTP(ctx, url, httpPrompt) } } @@ -844,6 +845,6 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err return string(resultsStr), nil } -func SysTimeNow(ctx context.Context, env []string, input string) (string, error) { +func SysTimeNow(context.Context, []string, string) (string, error) { return time.Now().Format(time.RFC3339), nil } diff --git a/pkg/sdkserver/confirm.go b/pkg/sdkserver/confirm.go index 2ed34bb3..05ca6e5d 100644 --- a/pkg/sdkserver/confirm.go +++ b/pkg/sdkserver/confirm.go @@ -39,14 +39,16 @@ func (s *server) authorize(ctx engine.Context, input string) (runner.AuthorizerR s.lock.Unlock() }(ctx.ID) - s.events.C <- gserver.Event{ - Event: runner.Event{ - Time: time.Now(), - CallContext: ctx.GetCallContext(), - Type: CallConfirm, + s.events.C <- event{ + Event: gserver.Event{ + Event: runner.Event{ + Time: time.Now(), + CallContext: ctx.GetCallContext(), + Type: CallConfirm, + }, + Input: input, + RunID: runID, }, - Input: input, - RunID: runID, } // Wait for the confirmation to come through. diff --git a/pkg/sdkserver/monitor.go b/pkg/sdkserver/monitor.go new file mode 100644 index 00000000..c0aa6090 --- /dev/null +++ b/pkg/sdkserver/monitor.go @@ -0,0 +1,94 @@ +package sdkserver + +import ( + "context" + "sync" + "time" + + "github.com/acorn-io/broadcaster" + "github.com/gptscript-ai/gptscript/pkg/runner" + gserver "github.com/gptscript-ai/gptscript/pkg/server" + "github.com/gptscript-ai/gptscript/pkg/types" +) + +type SessionFactory struct { + events *broadcaster.Broadcaster[event] +} + +func NewSessionFactory(events *broadcaster.Broadcaster[event]) *SessionFactory { + return &SessionFactory{ + events: events, + } +} + +func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []string, input string) (runner.Monitor, error) { + id := gserver.RunIDFromContext(ctx) + + s.events.C <- event{ + Event: gserver.Event{ + Event: runner.Event{ + Time: time.Now(), + Type: runner.EventTypeRunStart, + }, + RunID: id, + Program: prg, + }, + } + + return &Session{ + id: id, + prj: prg, + env: env, + input: input, + events: s.events, + }, nil +} + +type Session struct { + id string + prj *types.Program + env []string + input string + events *broadcaster.Broadcaster[event] + runLock sync.Mutex +} + +func (s *Session) Event(e runner.Event) { + s.runLock.Lock() + defer s.runLock.Unlock() + s.events.C <- event{ + Event: gserver.Event{ + Event: e, + RunID: s.id, + Input: s.input, + }, + } +} + +func (s *Session) Stop(output string, err error) { + e := event{ + Event: gserver.Event{ + Event: runner.Event{ + Time: time.Now(), + Type: runner.EventTypeRunFinish, + }, + RunID: s.id, + Input: s.input, + Output: output, + }, + } + if err != nil { + e.Err = err.Error() + } + + s.runLock.Lock() + defer s.runLock.Unlock() + s.events.C <- e +} + +func (s *Session) Pause() func() { + s.runLock.Lock() + return func() { + s.runLock.Unlock() + } +} diff --git a/pkg/sdkserver/prompt.go b/pkg/sdkserver/prompt.go new file mode 100644 index 00000000..a5b15162 --- /dev/null +++ b/pkg/sdkserver/prompt.go @@ -0,0 +1,111 @@ +package sdkserver + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + gcontext "github.com/gptscript-ai/gptscript/pkg/context" + "github.com/gptscript-ai/gptscript/pkg/mvl" + "github.com/gptscript-ai/gptscript/pkg/runner" + gserver "github.com/gptscript-ai/gptscript/pkg/server" + "github.com/gptscript-ai/gptscript/pkg/types" +) + +func (s *server) promptResponse(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + id := r.PathValue("id") + + s.lock.RLock() + promptChan := s.waitingToPrompt[id] + s.lock.RUnlock() + + if promptChan == nil { + writeError(logger, w, http.StatusNotFound, fmt.Errorf("no prompt found with id %q", id)) + return + } + + var promptResponse map[string]string + if err := json.NewDecoder(r.Body).Decode(&promptResponse); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err)) + return + } + + // Don't block here because, if the prompter is no longer waiting on this then it will never unblock. + select { + case promptChan <- promptResponse: + w.WriteHeader(http.StatusAccepted) + default: + w.WriteHeader(http.StatusConflict) + } +} + +func (s *server) prompt(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + id := r.PathValue("id") + + s.lock.RLock() + promptChan := s.waitingToPrompt[id] + s.lock.RUnlock() + + if promptChan != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("prompt called multiple times for same ID: %s", id)) + return + } + + var prompt types.Prompt + if err := json.NewDecoder(r.Body).Decode(&prompt); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %v", err)) + return + } + + s.lock.Lock() + promptChan = make(chan map[string]string) + s.waitingToPrompt[id] = promptChan + s.lock.Unlock() + defer func(id string) { + s.lock.Lock() + delete(s.waitingToPrompt, id) + s.lock.Unlock() + }(id) + + s.events.C <- event{ + Prompt: types.Prompt{ + Message: prompt.Message, + Fields: prompt.Fields, + Sensitive: prompt.Sensitive, + }, + Event: gserver.Event{ + RunID: id, + Event: runner.Event{ + Type: Prompt, + Time: time.Now(), + }, + }, + } + + // Wait for the prompt response to come through. + select { + case <-r.Context().Done(): + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("context canceled: %v", r.Context().Err())) + return + case promptResponse := <-promptChan: + writePromptResponse(logger, w, http.StatusOK, promptResponse) + } +} + +func writePromptResponse(logger mvl.Logger, w http.ResponseWriter, code int, resp any) { + b, err := json.Marshal(resp) + if err != nil { + logger.Errorf("failed to marshal response: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(code) + } + + _, err = w.Write(b) + if err != nil { + logger.Errorf("failed to write response: %v", err) + } +} diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 913e5a77..decdda0b 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -26,11 +26,13 @@ import ( const toolRunTimeout = 15 * time.Minute type server struct { - client *gptscript.GPTScript - events *broadcaster.Broadcaster[gserver.Event] + address string + client *gptscript.GPTScript + events *broadcaster.Broadcaster[event] lock sync.RWMutex waitingToConfirm map[string]chan runner.AuthorizerResponse + waitingToPrompt map[string]chan map[string]string } func (s *server) addRoutes(mux *http.ServeMux) { @@ -52,6 +54,8 @@ func (s *server) addRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /fmt", s.fmtDocument) mux.HandleFunc("POST /confirm/{id}", s.confirm) + mux.HandleFunc("POST /prompt/{id}", s.prompt) + mux.HandleFunc("POST /prompt-response/{id}", s.promptResponse) } // health just provides an endpoint for checking whether the server is running and accessible. @@ -148,7 +152,9 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { return } - ctx, cancel := context.WithTimeout(gserver.ContextWithNewRunID(r.Context()), toolRunTimeout) + ctx := gserver.ContextWithNewRunID(r.Context()) + runID := gserver.RunIDFromContext(ctx) + ctx, cancel := context.WithTimeout(ctx, toolRunTimeout) defer cancel() // Ensure chat state is not empty. @@ -156,6 +162,9 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { reqObject.ChatState = "null" } + // Append a prompt URL for this run. + reqObject.Env = append(reqObject.Env, fmt.Sprintf("%s=http://%s/prompt/%s", types.PromptURLEnvVar, s.address, runID)) + logger.Debugf("executing tool: %+v", reqObject) var ( def fmt.Stringer = &reqObject.ToolDef @@ -170,12 +179,12 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { opts := &gptscript.Options{ Cache: reqObject.Options, - Env: reqObject.Env, + Env: append(os.Environ(), reqObject.Env...), Workspace: reqObject.Workspace, CredentialContext: reqObject.CredentialContext, Runner: runner.Options{ // Set the monitor factory so that we can get events from the server. - MonitorFactory: gserver.NewSessionFactory(s.events), + MonitorFactory: NewSessionFactory(s.events), }, } diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index 97194f87..33740899 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -51,7 +51,7 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo // processEventStreamOutput will stream the events of the tool to the response as server sent events. // If an error occurs, then an event with the error will also be sent. -func processEventStreamOutput(ctx context.Context, logger mvl.Logger, w http.ResponseWriter, id string, events <-chan gserver.Event, output <-chan runner.ChatResponse, errChan chan error) { +func processEventStreamOutput(ctx context.Context, logger mvl.Logger, w http.ResponseWriter, id string, events <-chan event, output <-chan runner.ChatResponse, errChan chan error) { run := newRun(id) setStreamingHeaders(w) @@ -82,7 +82,7 @@ func processEventStreamOutput(ctx context.Context, logger mvl.Logger, w http.Res } // streamEvents will stream the events of the tool to the response as server sent events. -func streamEvents(ctx context.Context, logger mvl.Logger, w http.ResponseWriter, run *runInfo, events <-chan gserver.Event) { +func streamEvents(ctx context.Context, logger mvl.Logger, w http.ResponseWriter, run *runInfo, events <-chan event) { logger.Debugf("receiving events") for { select { diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index 2a134af7..1e1cacc1 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -16,7 +16,6 @@ import ( "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/runner" - gserver "github.com/gptscript-ai/gptscript/pkg/server" "github.com/rs/cors" ) @@ -42,8 +41,8 @@ func Start(ctx context.Context, opts Options) error { mvl.SetDebug() } - events := broadcaster.New[gserver.Event]() - opts.Options.Runner.MonitorFactory = gserver.NewSessionFactory(events) + events := broadcaster.New[event]() + opts.Options.Runner.MonitorFactory = NewSessionFactory(events) go events.Start(ctx) g, err := gptscript.New(&opts.Options) @@ -52,9 +51,11 @@ func Start(ctx context.Context, opts Options) error { } s := &server{ + address: opts.ListenAddress, client: g, events: events, waitingToConfirm: make(map[string]chan runner.AuthorizerResponse), + waitingToPrompt: make(map[string]chan map[string]string), } defer s.Close() diff --git a/pkg/sdkserver/types.go b/pkg/sdkserver/types.go index 9f4f5e31..5b98e07b 100644 --- a/pkg/sdkserver/types.go +++ b/pkg/sdkserver/types.go @@ -22,6 +22,7 @@ const ( Error runState = "error" CallConfirm runner.EventType = "callConfirm" + Prompt runner.EventType = "prompt" ) type toolOrFileRequest struct { @@ -89,20 +90,26 @@ func newRun(id string) *runInfo { type runEvent struct { runInfo `json:",inline"` - - Type runner.EventType `json:"type"` + Type runner.EventType `json:"type"` } -func (r *runInfo) process(event gserver.Event) map[string]any { - switch event.Type { +func (r *runInfo) process(e event) map[string]any { + switch e.Type { + case Prompt: + return map[string]any{"prompt": prompt{ + Prompt: e.Prompt, + ID: e.RunID, + Type: e.Type, + Time: e.Time, + }} case runner.EventTypeRunStart: - r.Start = event.Time - r.Program = *event.Program + r.Start = e.Time + r.Program = *e.Program r.State = Running case runner.EventTypeRunFinish: - r.End = event.Time - r.Output = event.Output - r.Error = event.Err + r.End = e.Time + r.Output = e.Output + r.Error = e.Err if r.Error != "" { r.State = Error } else { @@ -110,42 +117,42 @@ func (r *runInfo) process(event gserver.Event) map[string]any { } } - if event.CallContext == nil || event.CallContext.ID == "" { + if e.CallContext == nil || e.CallContext.ID == "" { return map[string]any{"run": runEvent{ runInfo: *r, - Type: event.Type, + Type: e.Type, }} } - call := r.Calls[event.CallContext.ID] - call.CallContext = *event.CallContext - call.Type = event.Type + call := r.Calls[e.CallContext.ID] + call.CallContext = *e.CallContext + call.Type = e.Type - switch event.Type { + switch e.Type { case runner.EventTypeCallStart: - call.Start = event.Time - call.Input = event.Content + call.Start = e.Time + call.Input = e.Content case runner.EventTypeCallSubCalls: - call.setSubCalls(event.ToolSubCalls) + call.setSubCalls(e.ToolSubCalls) case runner.EventTypeCallProgress: - call.setOutput(event.Content) + call.setOutput(e.Content) case runner.EventTypeCallFinish: - call.End = event.Time - call.setOutput(event.Content) + call.End = e.Time + call.setOutput(e.Content) case runner.EventTypeChat: - if event.ChatRequest != nil { - call.LLMRequest = event.ChatRequest + if e.ChatRequest != nil { + call.LLMRequest = e.ChatRequest } - if event.ChatResponse != nil { - call.LLMResponse = event.ChatResponse + if e.ChatResponse != nil { + call.LLMResponse = e.ChatResponse } } - r.Calls[event.CallContext.ID] = call + r.Calls[e.CallContext.ID] = call return map[string]any{"call": call} } @@ -192,3 +199,15 @@ type output struct { Content string `json:"content"` SubCalls map[string]engine.Call `json:"subCalls"` } + +type event struct { + gserver.Event `json:",inline"` + types.Prompt `json:",inline"` +} + +type prompt struct { + types.Prompt `json:",inline"` + ID string `json:"id,omitempty"` + Type runner.EventType `json:"type,omitempty"` + Time time.Time `json:"time,omitempty"` +} diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go new file mode 100644 index 00000000..ec38ea34 --- /dev/null +++ b/pkg/types/prompt.go @@ -0,0 +1,9 @@ +package types + +const PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL" + +type Prompt struct { + Message string `json:"message,omitempty"` + Fields []string `json:"fields,omitempty"` + Sensitive bool `json:"sensitive,omitempty"` +}