diff --git a/.vscode/launch.json b/.vscode/launch.json index cc84991c..669016b3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -15,12 +15,12 @@ ] }, { - "name": "Launch Server", + "name": "Clicky Serves", "type": "go", "request": "launch", "mode": "debug", "program": "main.go", - "args": ["--server"] + "args": ["--debug", "--listen-address", "127.0.0.1:63774", "sys.sdkserver"] } ] } diff --git a/go.mod b/go.mod index bc80c47e..15f88d5f 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.6.0 github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 - github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1 + github.com/gptscript-ai/chat-completion-client v0.0.0-20250224164718-139cb4507b1d github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee diff --git a/go.sum b/go.sum index 7ed757bd..07d8d500 100644 --- a/go.sum +++ b/go.sum @@ -197,8 +197,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA= github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk= -github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1 h1:D8VmhL68Fm6YI7fue4wkzd1TqODn//LtcJtPvWk8BQ8= -github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= +github.com/gptscript-ai/chat-completion-client v0.0.0-20250224164718-139cb4507b1d h1:p5uqZufDIMQzAALblZFkr8fwbnZbFXbBCR1ZMAFylXk= +github.com/gptscript-ai/chat-completion-client v0.0.0-20250224164718-139cb4507b1d/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI= diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index ccbda66a..d14fe7c7 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -269,18 +269,14 @@ func ListTools() (result []types.Tool) { sort.Strings(keys) for _, key := range keys { - t, _ := Builtin(key) + t, _ := DefaultModel(key, "") result = append(result, t) } return } -func Builtin(name string) (types.Tool, bool) { - return BuiltinWithDefaultModel(name, "") -} - -func BuiltinWithDefaultModel(name, defaultModel string) (types.Tool, bool) { +func DefaultModel(name, defaultModel string) (types.Tool, bool) { // Legacy syntax not used anymore name = strings.TrimSuffix(name, "?") t, ok := tools[name] @@ -332,7 +328,7 @@ func SysFind(_ context.Context, _ []string, input string, _ chan<- string) (stri return strings.Join(result, "\n"), nil } -func SysExec(_ context.Context, env []string, input string, progress chan<- string) (string, error) { +func SysExec(ctx context.Context, env []string, input string, progress chan<- string) (string, error) { var params struct { Command string `json:"command,omitempty"` Directory string `json:"directory,omitempty"` @@ -345,14 +341,20 @@ func SysExec(_ context.Context, env []string, input string, progress chan<- stri params.Directory = "." } + commandCtx, _ := engine.FromContext(ctx) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + commandCtx.OnUserCancel(ctx, cancel) + log.Debugf("Running %s in %s", params.Command, params.Directory) var cmd *exec.Cmd - if runtime.GOOS == "windows" { - cmd = exec.Command("cmd.exe", "/c", params.Command) + cmd = exec.CommandContext(ctx, "cmd.exe", "/c", params.Command) } else { - cmd = exec.Command("/bin/sh", "-c", params.Command) + cmd = exec.CommandContext(ctx, "/bin/sh", "-c", params.Command) } var ( @@ -371,7 +373,8 @@ func SysExec(_ context.Context, env []string, input string, progress chan<- stri cmd.Dir = params.Directory cmd.Stdout = combined cmd.Stderr = combined - if err := cmd.Run(); err != nil { + if err := cmd.Run(); err != nil && (ctx.Err() == nil || commandCtx.Ctx.Err() != nil) { + // If the command failed and the context hasn't been canceled, then return the error. return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, &out), nil } return out.String(), nil @@ -420,7 +423,6 @@ func getWorkspaceEnvFileContents(envs []string) ([]string, error) { } return envContents, nil - } func getWorkspaceDir(envs []string) (string, error) { @@ -665,6 +667,7 @@ func DiscardProgress() (progress chan<- string, closeFunc func()) { ch := make(chan string) go func() { for range ch { + continue } }() return ch, func() { diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 90e8ee10..e5b4494e 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -105,6 +105,13 @@ func (c *Client) Store(ctx context.Context, key, value any) error { return nil } + select { + // If the context has been canceled, then don't try to save. + case <-ctx.Done(): + return nil + default: + } + if c.noop || IsNoCache(ctx) { keyValue, err := c.cacheKey(key) if err == nil { diff --git a/pkg/chat/chat.go b/pkg/chat/chat.go index 1e1fe63f..e36f107b 100644 --- a/pkg/chat/chat.go +++ b/pkg/chat/chat.go @@ -17,7 +17,7 @@ type Prompter interface { } type Chatter interface { - Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (resp runner.ChatResponse, err error) + Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string, opts runner.RunOptions) (resp runner.ChatResponse, err error) } type GetProgram func() (types.Program, error) @@ -74,7 +74,7 @@ func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg } } - resp, err = chatter.Chat(ctx, prevState, prog, env, input) + resp, err = chatter.Chat(ctx, prevState, prog, env, input, runner.RunOptions{}) if err != nil { return err } diff --git a/pkg/cli/eval.go b/pkg/cli/eval.go index c649a505..4afdf112 100644 --- a/pkg/cli/eval.go +++ b/pkg/cli/eval.go @@ -10,6 +10,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/input" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/spf13/cobra" ) @@ -56,13 +57,13 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error { return err } - runner, err := gptscript.New(cmd.Context(), opts) + g, err := gptscript.New(cmd.Context(), opts) if err != nil { return err } prg, err := loader.ProgramFromSource(cmd.Context(), tool.String(), "", loader.Options{ - Cache: runner.Cache, + Cache: g.Cache, }) if err != nil { return err @@ -74,14 +75,14 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error { } if e.Chat { - return chat.Start(cmd.Context(), nil, runner, func() (types.Program, error) { + return chat.Start(cmd.Context(), nil, g, func() (types.Program, error) { return loader.ProgramFromSource(cmd.Context(), tool.String(), "", loader.Options{ - Cache: runner.Cache, + Cache: g.Cache, }) }, os.Environ(), toolInput, "") } - toolOutput, err := runner.Run(cmd.Context(), prg, opts.Env, toolInput) + toolOutput, err := g.Run(cmd.Context(), prg, opts.Env, toolInput, runner.RunOptions{}) if err != nil { return err } diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 4bd04509..16f9152d 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -469,7 +469,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { // This chat in a stateless mode if r.SaveChatStateFile == "-" || r.SaveChatStateFile == "stdout" { - resp, err := gptScript.Chat(cmd.Context(), chatState, prg, gptOpt.Env, toolInput) + resp, err := gptScript.Chat(cmd.Context(), chatState, prg, gptOpt.Env, toolInput, runner.RunOptions{}) if err != nil { return err } @@ -511,7 +511,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { gptScript.ExtraEnv = nil } - s, err := gptScript.Run(cmd.Context(), prg, gptOpt.Env, toolInput) + s, err := gptScript.Run(cmd.Context(), prg, gptOpt.Env, toolInput, runner.RunOptions{}) if err != nil { return err } diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 010c1ace..368b1c98 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -119,10 +119,14 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate instructions = append(instructions, inputContext.Content) } - var extraEnv = []string{ + extraEnv := []string{ strings.TrimSpace("GPTSCRIPT_CONTEXT=" + strings.Join(instructions, "\n")), } - cmd, stop, err := e.newCommand(ctx.Ctx, extraEnv, tool, input, true) + + commandCtx, cancel := context.WithCancel(ctx.Ctx) + defer cancel() + + cmd, stop, err := e.newCommand(commandCtx, extraEnv, tool, input, true) if err != nil { if toolCategory == NoCategory && ctx.Parent != nil { return fmt.Sprintf("ERROR: got (%v) while parsing command", err), nil @@ -155,18 +159,22 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate cmd.Stdout = io.MultiWriter(stdout, stdoutAndErr, progressOut) cmd.Stderr = io.MultiWriter(stdoutAndErr, progressOut, os.Stderr) result = stdout + defer func() { + combinedOutput = stdoutAndErr.String() + }() + + ctx.OnUserCancel(commandCtx, cancel) - if err := cmd.Run(); err != nil { + if err := cmd.Run(); err != nil && (commandCtx.Err() == nil || ctx.Ctx.Err() != nil) { + // If the command failed and the context hasn't been canceled, then return the error. if toolCategory == NoCategory && ctx.Parent != nil { // If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry. return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, stdoutAndErr), nil } log.Errorf("failed to run tool [%s] cmd %v: %v", tool.Parameters.Name, cmd.Args, err) - combinedOutput = stdoutAndErr.String() return "", fmt.Errorf("ERROR: %s: %w", stdoutAndErr, err) } - combinedOutput = stdoutAndErr.String() return result.String(), IsChatFinishMessage(result.String()) } diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index b7877da3..6f991be0 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -229,7 +229,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { return url, fmt.Errorf("timeout waiting for 200 response from GET %s", url) } -func (e *Engine) runDaemon(ctx context.Context, prg *types.Program, tool types.Tool, input string) (cmdRet *Return, cmdErr error) { +func (e *Engine) runDaemon(ctx Context, tool types.Tool, input string) (cmdRet *Return, cmdErr error) { url, err := e.startDaemon(tool) if err != nil { return nil, err @@ -238,5 +238,5 @@ func (e *Engine) runDaemon(ctx context.Context, prg *types.Program, tool types.T tool.Instructions = strings.Join(append([]string{ types.CommandPrefix + url, }, strings.Split(tool.Instructions, "\n")[1:]...), "\n") - return e.runHTTP(ctx, prg, tool, input) + return e.runHTTP(ctx, tool, input) } diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 7b0d86d0..778b1e7e 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -92,7 +92,8 @@ type Context struct { Engine *Engine Program *types.Program // Input is saved only so that we can render display text, don't use otherwise - Input string + Input string + userCancel <-chan struct{} } type ChatHistory struct { @@ -188,6 +189,18 @@ func (c *Context) MarshalJSON() ([]byte, error) { return json.Marshal(c.GetCallContext()) } +func (c *Context) OnUserCancel(ctx context.Context, cancel func()) { + go func() { + select { + case <-ctx.Done(): + // If the context is canceled, then nothing to do. + case <-c.userCancel: + // If the user is requesting a cancel, then cancel the context. + cancel() + } + }() +} + type toolCategoryKey struct{} func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Context { @@ -199,7 +212,7 @@ func ToolCategoryFromContext(ctx context.Context) ToolCategory { return category } -func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) { +func NewContext(ctx context.Context, prg *types.Program, input string, userCancel <-chan struct{}) (Context, error) { category := ToolCategoryFromContext(ctx) callCtx := Context{ @@ -208,9 +221,10 @@ func NewContext(ctx context.Context, prg *types.Program, input string) (Context, Tool: prg.ToolSet[prg.EntryToolID], ToolCategory: category, }, - Ctx: ctx, - Program: prg, - Input: input, + Ctx: ctx, + Program: prg, + Input: input, + userCancel: userCancel, } agentGroup, err := callCtx.Tool.GetToolsByType(prg, types.ToolTypeAgent) @@ -251,6 +265,7 @@ func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID stri Program: c.Program, CurrentReturn: c.CurrentReturn, Input: input, + userCancel: c.userCancel, }, nil } @@ -292,32 +307,37 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*Return, error) { if tool.IsHTTP() { - return e.runHTTP(ctx.Ctx, ctx.Program, tool, input) + return e.runHTTP(ctx, tool, input) } else if tool.IsDaemon() { - return e.runDaemon(ctx.Ctx, ctx.Program, tool, input) + return e.runDaemon(ctx, tool, input) } else if tool.IsOpenAPI() { - return e.runOpenAPI(tool, input) + return e.runOpenAPI(ctx, tool, input) } else if tool.IsEcho() { return e.runEcho(tool) } else if tool.IsCall() { return e.runCall(ctx, tool, input) } s, err := e.runCommand(ctx, tool, input, ctx.ToolCategory) - if err != nil { - return nil, err - } return &Return{ Result: &s, - }, nil + }, err } -func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) { +func (e *Engine) Start(ctx Context, input string) (ret *Return, err error) { tool := ctx.Tool defer func() { if ret != nil && ret.State != nil { ret.State.Input = input } + select { + case <-ctx.userCancel: + if ret.Result == nil { + ret.Result = new(string) + } + *ret.Result += "\n\nABORTED BY USER" + default: + } }() if tool.IsCommand() { @@ -344,7 +364,7 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) { }) } - return e.complete(ctx.Ctx, &State{ + return e.complete(ctx, &State{ Completion: completion, }) } @@ -376,7 +396,7 @@ func addUpdateSystem(ctx Context, tool types.Tool, msgs []types.CompletionMessag return append([]types.CompletionMessage{msg}, msgs...) } -func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) { +func (e *Engine) complete(ctx Context, state *State) (*Return, error) { var ( progress = make(chan types.CompletionStatus) ret = Return{ @@ -429,7 +449,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) { return &ret, nil } - resp, err := e.Model.Call(ctx, state.Completion, e.Env, progress) + resp, err := e.Model.Call(ctx.WrappedContext(e), state.Completion, e.Env, progress) if err != nil { return nil, fmt.Errorf("failed calling model for completion: %w", err) } @@ -474,7 +494,17 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) { return &ret, nil } -func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Return, error) { +func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (ret *Return, _ error) { + defer func() { + select { + case <-ctx.userCancel: + if ret.Result == nil { + ret.Result = new(string) + } + *ret.Result += "\n\nABORTED BY USER" + default: + } + }() if ctx.Tool.IsCommand() { var input string if len(results) == 1 { @@ -508,7 +538,7 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re } } - ret := Return{ + ret = &Return{ State: state, Calls: map[string]Call{}, } @@ -524,7 +554,7 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re if len(ret.Calls) > 0 { // Outstanding tool calls still pending - return &ret, nil + return ret, nil } for _, content := range state.Completion.Messages[len(state.Completion.Messages)-1].Content { @@ -559,5 +589,5 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re return nil, err } - return e.complete(ctx.Ctx, state) + return e.complete(ctx, state) } diff --git a/pkg/engine/http.go b/pkg/engine/http.go index f301f978..9e59b70a 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -1,7 +1,6 @@ package engine import ( - "context" "encoding/json" "fmt" "io" @@ -17,7 +16,7 @@ import ( const DaemonURLSuffix = ".daemon.gptscript.local" -func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Tool, input string) (cmdRet *Return, cmdErr error) { +func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Return, cmdErr error) { envMap := map[string]string{} for _, env := range appendInputAsEnv(nil, input) { @@ -47,7 +46,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too if !ok || len(referencedToolRefs) != 1 { return nil, fmt.Errorf("invalid reference [%s] to tool [%s] from [%s], missing \"tools: %s\" parameter", toolURL, referencedToolName, tool.Source, referencedToolName) } - referencedTool, ok := prg.ToolSet[referencedToolRefs[0].ToolID] + referencedTool, ok := ctx.Program.ToolSet[referencedToolRefs[0].ToolID] if !ok { return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname()) } @@ -81,7 +80,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too input = body } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, toolURL, strings.NewReader(input)) + req, err := http.NewRequestWithContext(ctx.Ctx, http.MethodPost, toolURL, strings.NewReader(input)) if err != nil { return nil, err } @@ -121,6 +120,13 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too req.Header.Set("Content-Type", "text/plain") } + // If the user canceled the run, then don't make the request. + select { + case <-ctx.userCancel: + return &Return{}, nil + default: + } + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err diff --git a/pkg/engine/openapi.go b/pkg/engine/openapi.go index a9a1a644..2e79bc38 100644 --- a/pkg/engine/openapi.go +++ b/pkg/engine/openapi.go @@ -145,7 +145,7 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error } res = &Return{ - Result: ptr(result), + Result: &result, } } @@ -156,7 +156,7 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error // The tool itself will have instructions regarding the HTTP request that needs to be made. // The tools Instructions field will be in the format "#!sys.openapi '{Instructions JSON}'", // where {Instructions JSON} is a JSON string of type OpenAPIInstructions. -func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { +func (e *Engine) runOpenAPI(ctx Context, tool types.Tool, input string) (*Return, error) { if os.Getenv("GPTSCRIPT_OPENAPI_REVAMP") == "true" { return e.runOpenAPIRevamp(tool, input) } @@ -266,6 +266,13 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) { req.Body = io.NopCloser(&body) } + // If the user canceled the run, then don't make the request. + select { + case <-ctx.userCancel: + return &Return{}, nil + default: + } + // Make the request resp, err := http.DefaultClient.Do(req) if err != nil { diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 4669e5ab..f92f9324 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -244,22 +244,22 @@ func makeAbsolute(path string) (string, error) { return filepath.Abs(path) } -func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string) (runner.ChatResponse, error) { +func (g *GPTScript) Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, envs []string, input string, opts runner.RunOptions) (runner.ChatResponse, error) { envs, err := g.getEnv(envs) if err != nil { return runner.ChatResponse{}, err } - return g.Runner.Chat(ctx, prevState, prg, envs, input) + return g.Runner.Chat(ctx, prevState, prg, envs, input, opts) } -func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) { +func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string, opts runner.RunOptions) (string, error) { envs, err := g.getEnv(envs) if err != nil { return "", err } - return g.Runner.Run(ctx, prg, envs, input) + return g.Runner.Run(ctx, prg, envs, input, opts) } func (g *GPTScript) Close(closeDaemons bool) { @@ -319,7 +319,7 @@ func (s *simpleRunner) Load(ctx context.Context, toolName string) (prg types.Pro } func (s *simpleRunner) Run(ctx context.Context, prg types.Program, input string) (output string, err error) { - return s.runner.Run(ctx, prg, s.env, input) + return s.runner.Run(ctx, prg, s.env, input, runner.RunOptions{}) } type noopModel struct { diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 5a907f5b..902d0ed9 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -141,7 +141,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName, defaultModel into.ToolSet = make(map[string]types.Tool, len(ext.ToolSet)) for k, v := range ext.ToolSet { - if builtinTool, ok := builtin.BuiltinWithDefaultModel(k, defaultModel); ok { + if builtinTool, ok := builtin.DefaultModel(k, defaultModel); ok { v = builtinTool } into.ToolSet[k] = v @@ -471,7 +471,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { if subTool == "" { - t, ok := builtin.BuiltinWithDefaultModel(name, defaultModel) + t, ok := builtin.DefaultModel(name, defaultModel) if ok { prg.ToolSet[t.ID] = t return []types.Tool{t}, nil diff --git a/pkg/openai/client.go b/pkg/openai/client.go index db911962..ec6b2668 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -15,6 +15,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/credentials" + "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/hash" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/prompt" @@ -583,10 +584,21 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, slog.Debug("calling openai", "message", request.Messages) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + engineCtx, ok := engine.FromContext(ctx) + if ok { + engineCtx.OnUserCancel(ctx, cancel) + } + if !streamResponse { request.StreamOptions = nil resp, err := c.c.CreateChatCompletion(ctx, request, headers, retryOpts...) if err != nil { + if errors.Is(err, context.Canceled) { + err = nil + } return types.CompletionMessage{}, err } return appendMessage(types.CompletionMessage{}, openai.ChatCompletionStreamResponse{ @@ -612,6 +624,9 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, stream, err := c.c.CreateChatCompletionStream(ctx, request, headers, retryOpts...) if err != nil { + if errors.Is(err, context.Canceled) { + err = nil + } return types.CompletionMessage{}, err } defer stream.Close() @@ -619,11 +634,12 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, var ( partialMessage types.CompletionMessage start = time.Now() - last []string ) for { response, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { + // If the stream is finished, either because we got an EOF or the context was canceled, + // then we're done. The cache won't save the response if the context was canceled. return partialMessage, c.cache.Store(ctx, c.cacheKey(request), partialMessage) } else if err != nil { return types.CompletionMessage{}, err @@ -631,7 +647,6 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, partialMessage = appendMessage(partialMessage, response) if partial != nil { if time.Since(start) > 100*time.Millisecond { - last = last[:0] partial <- types.CompletionStatus{ CompletionID: transactionID, PartialResponse: &partialMessage, diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 93f612ef..441a01dd 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -166,7 +166,7 @@ func (c *Client) load(ctx context.Context, toolName string, env ...string) (*ope return nil, err } - url, err := c.runner.Run(engine.WithToolCategory(ctx, engine.ProviderToolCategory), prg.SetBlocking(), c.envs, "") + url, err := c.runner.Run(engine.WithToolCategory(ctx, engine.ProviderToolCategory), prg.SetBlocking(), c.envs, "", runner.RunOptions{}) if err != nil { return nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e2699cf6..df3ef172 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -39,6 +39,10 @@ type Options struct { Authorizer AuthorizerFunc `usage:"-"` } +type RunOptions struct { + UserCancel <-chan struct{} +} + type AuthorizerResponse struct { Accept bool Message string @@ -130,7 +134,7 @@ type ChatResponse struct { type ChatState interface{} -func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Program, env []string, input string) (resp ChatResponse, err error) { +func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Program, env []string, input string, opts RunOptions) (resp ChatResponse, err error) { var state *State defer func() { @@ -167,7 +171,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra monitor.Stop(ctx, resp.Content, err) }() - callCtx, err := engine.NewContext(ctx, &prg, input) + callCtx, err := engine.NewContext(ctx, &prg, input, opts.UserCancel) if err != nil { return resp, err } @@ -210,8 +214,8 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra }, nil } -func (r *Runner) Run(ctx context.Context, prg types.Program, env []string, input string) (output string, err error) { - resp, err := r.Chat(ctx, nil, prg, env, input) +func (r *Runner) Run(ctx context.Context, prg types.Program, env []string, input string, opts RunOptions) (output string, err error) { + resp, err := r.Chat(ctx, nil, prg, env, input, opts) if err != nil { return "", err } @@ -651,8 +655,11 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher { return newParallelDispatcher(ctx) } -func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) { - var resultLock sync.Mutex +func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (*State, []SubCallResult, error) { + var ( + resultLock sync.Mutex + callResults []SubCallResult + ) if state.Continuation != nil { callCtx.LastReturn = state.Continuation @@ -666,8 +673,6 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, for _, subCall := range state.SubCalls { if subCall.CallID == state.SubCallID { found = true - subState := *subCall.State - subState.ResumeInput = state.ResumeInput result, err := r.subCallResume(callCtx.Ctx, callCtx, monitor, env, subCall.ToolID, subCall.CallID, subCall.State.WithResumeInput(state.ResumeInput), toolCategory) if err != nil { return nil, nil, err diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go index c4178801..b923490b 100644 --- a/pkg/sdkserver/datasets.go +++ b/pkg/sdkserver/datasets.go @@ -9,6 +9,7 @@ import ( gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/runner" ) func (s *server) getDatasetTool(req datasetRequest) string { @@ -79,7 +80,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input, runner.RunOptions{}) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -147,7 +148,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input, runner.RunOptions{}) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -207,7 +208,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input, runner.RunOptions{}) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return @@ -270,7 +271,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) { return } - result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input) + result, err := g.Run(r.Context(), prg, s.getServerToolsEnv(req.Env), req.Input, runner.RunOptions{}) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) return diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 1431b73b..d520e97a 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -36,6 +36,9 @@ type server struct { lock sync.RWMutex waitingToConfirm map[string]chan runner.AuthorizerResponse waitingToPrompt map[string]chan map[string]string + + runningLock sync.Mutex + running map[string]chan struct{} } func (s *server) addRoutes(mux *http.ServeMux) { @@ -52,6 +55,7 @@ func (s *server) addRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /run", s.execHandler) mux.HandleFunc("POST /evaluate", s.execHandler) + mux.HandleFunc("POST /abort/{run_id}", s.abort) mux.HandleFunc("POST /load", s.load) @@ -164,6 +168,17 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { ctx := gserver.ContextWithNewRunID(r.Context()) runID := gserver.RunIDFromContext(ctx) + cancel := make(chan struct{}) + s.runningLock.Lock() + s.running[runID] = cancel + s.runningLock.Unlock() + + defer func() { + s.runningLock.Lock() + delete(s.running, runID) + s.runningLock.Unlock() + close(cancel) + }() // Ensure chat state is not empty. if reqObject.ChatState == "" { @@ -214,7 +229,30 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) { opts.Runner.Authorizer = s.authorize } - s.execAndStream(ctx, programLoader, logger, w, opts, reqObject.ChatState, reqObject.Input, reqObject.SubTool, def) + s.execAndStream(ctx, programLoader, logger, w, opts, reqObject.ChatState, reqObject.Input, reqObject.SubTool, def, cancel) +} + +// abort will abort the run in a way such that the chat state will be returned. +func (s *server) abort(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + runID := r.PathValue("run_id") + if runID == "" { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("run_id is required")) + return + } + + s.runningLock.Lock() + cancel := s.running[runID] + delete(s.running, runID) + s.runningLock.Unlock() + + if cancel == nil { + writeResponse(logger, w, "run not found") + return + } + + close(cancel) + writeResponse(logger, w, "run aborted") } // load will load the file and return the corresponding Program. diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index 1c0f7c4b..93c9996b 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -24,7 +24,7 @@ func loaderWithLocation(f loaderFunc, loc string) loaderFunc { } } -func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer) { +func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer, cancel <-chan struct{}) { g, err := gptscript.New(ctx, s.gptscriptOpts, opts) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err)) @@ -48,7 +48,9 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo defer events.Close() go func() { - run, err := g.Chat(ctx, chatState, prg, opts.Env, input) + run, err := g.Chat(ctx, chatState, prg, opts.Env, input, runner.RunOptions{ + UserCancel: cancel, + }) if err != nil { errChan <- err } else { @@ -58,21 +60,19 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo close(programOutput) }() - processEventStreamOutput(ctx, logger, w, gserver.RunIDFromContext(ctx), events.C, programOutput, errChan) + processEventStreamOutput(logger, w, gserver.RunIDFromContext(ctx), events.C, programOutput, errChan) } // processEventStreamOutput will stream the events of the tool to the response as server sent events. // If an error occurs, then an event with the error will also be sent. -func processEventStreamOutput(ctx context.Context, logger mvl.Logger, w http.ResponseWriter, id string, events <-chan event, output <-chan runner.ChatResponse, errChan chan error) { +func processEventStreamOutput(logger mvl.Logger, w http.ResponseWriter, id string, events <-chan event, output <-chan runner.ChatResponse, errChan chan error) { run := newRun(id) setStreamingHeaders(w) - streamEvents(ctx, logger, w, run, events) + streamEvents(logger, w, run, events) - var out runner.ChatResponse select { - case <-ctx.Done(): - case out = <-output: + case out := <-output: run.processStdout(out) writeServerSentEvent(logger, w, map[string]any{ @@ -85,47 +85,27 @@ func processEventStreamOutput(ctx context.Context, logger mvl.Logger, w http.Res } // Now that we have received all events, send the DONE event. - _, err := w.Write([]byte("data: [DONE]\n\n")) - if err == nil { - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - } + writeServerSentEvent(logger, w, "[DONE]") logger.Debugf("wrote DONE event") } // streamEvents will stream the events of the tool to the response as server sent events. -func streamEvents(ctx context.Context, logger mvl.Logger, w http.ResponseWriter, run *runInfo, events <-chan event) { +func streamEvents(logger mvl.Logger, w http.ResponseWriter, run *runInfo, events <-chan event) { logger.Debugf("receiving events") - for { - select { - case <-ctx.Done(): - logger.Debugf("context canceled while receiving events") - go func() { - //nolint:revive - for range events { - } - }() - return - case e, ok := <-events: - if ok && e.RunID != run.ID { - continue - } - - if !ok { - logger.Debugf("done receiving events") - return - } - - writeServerSentEvent(logger, w, run.process(e)) - - if e.Type == runner.EventTypeRunFinish { - logger.Debugf("finished receiving events") - return - } + for e := range events { + if e.RunID != run.ID { + continue + } + + writeServerSentEvent(logger, w, run.process(e)) + + if e.Type == runner.EventTypeRunFinish { + break } } + + logger.Debugf("done receiving events") } func writeResponse(logger mvl.Logger, w http.ResponseWriter, v any) { diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index 79d6daf7..41066d30 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -118,6 +118,7 @@ func run(ctx context.Context, listener net.Listener, opts Options) error { runtimeManager: runtimes.Default(opts.Options.Cache.CacheDir, opts.SystemToolsDir), waitingToConfirm: make(map[string]chan runner.AuthorizerResponse), waitingToPrompt: make(map[string]chan map[string]string), + running: make(map[string]chan struct{}), } defer s.close() diff --git a/pkg/sdkserver/workspaces.go b/pkg/sdkserver/workspaces.go index 162853f7..f1846051 100644 --- a/pkg/sdkserver/workspaces.go +++ b/pkg/sdkserver/workspaces.go @@ -7,6 +7,7 @@ import ( gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/runner" ) func (s *server) getWorkspaceTool(req workspaceCommonRequest) string { @@ -65,6 +66,7 @@ func (s *server) createWorkspace(w http.ResponseWriter, r *http.Request) { prg, s.getServerToolsEnv(reqObject.Env), string(b), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -100,6 +102,7 @@ func (s *server) deleteWorkspace(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s"}`, reqObject.ID, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -137,6 +140,7 @@ func (s *server) listWorkspaceContents(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s", "ls_prefix": "%s"}`, reqObject.ID, reqObject.Prefix, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -173,6 +177,7 @@ func (s *server) removeAllWithPrefixInWorkspace(w http.ResponseWriter, r *http.R `{"workspace_id": "%s", "prefix": "%s"}`, reqObject.ID, reqObject.Prefix, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -212,6 +217,7 @@ func (s *server) writeFileInWorkspace(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s", "file_path": "%s", "body": "%s", "create_revision": %t, "latest_revision_id": "%s"}`, reqObject.ID, reqObject.FilePath, reqObject.Contents, reqObject.CreateRevision == nil || *reqObject.CreateRevision, reqObject.LatestRevisionID, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -248,6 +254,7 @@ func (s *server) removeFileInWorkspace(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s", "file_path": "%s"}`, reqObject.ID, reqObject.FilePath, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -284,6 +291,7 @@ func (s *server) readFileInWorkspace(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s", "file_path": "%s"}`, reqObject.ID, reqObject.FilePath, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -315,6 +323,7 @@ func (s *server) readFileWithRevisionInWorkspace(w http.ResponseWriter, r *http. `{"workspace_id": "%s", "file_path": "%s", "with_latest_revision_id": "true"}`, reqObject.ID, reqObject.FilePath, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -352,6 +361,7 @@ func (s *server) statFileInWorkspace(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s", "file_path": "%s", "with_latest_revision_id": "%v"}`, reqObject.ID, reqObject.FilePath, reqObject.WithLatestRevisionID, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -387,6 +397,7 @@ func (s *server) listRevisions(w http.ResponseWriter, r *http.Request) { `{"workspace_id": "%s", "file_path": "%s"}`, reqObject.ID, reqObject.FilePath, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -424,6 +435,7 @@ func (s *server) getRevisionForFileInWorkspace(w http.ResponseWriter, r *http.Re `{"workspace_id": "%s", "file_path": "%s", "revision_id": "%s"}`, reqObject.ID, reqObject.FilePath, reqObject.RevisionID, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) @@ -461,6 +473,7 @@ func (s *server) deleteRevisionForFileInWorkspace(w http.ResponseWriter, r *http `{"workspace_id": "%s", "file_path": "%s", "revision_id": "%s"}`, reqObject.ID, reqObject.FilePath, reqObject.RevisionID, ), + runner.RunOptions{}, ) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index 165f86c8..f5de8e10 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/tests/tester" "github.com/hexops/autogold/v2" "github.com/stretchr/testify/require" @@ -28,10 +29,10 @@ echo This is the input: ${GPTSCRIPT_INPUT} `, "") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1", runner.RunOptions{}) r.AssertStep(t, resp, err) - resp, err = r.Chat(context.Background(), resp.State, prg, nil, "input 2") + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "input 2", runner.RunOptions{}) r.AssertStep(t, resp, err) } @@ -54,7 +55,7 @@ name: realcontext Yo dawg`, "") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1", runner.RunOptions{}) r.AssertStep(t, resp, err) } @@ -76,9 +77,9 @@ echo ${FOO}:${INPUT} `, "") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"123"}`) + resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"123"}`, runner.RunOptions{}) r.AssertStep(t, resp, err) - resp, err = r.Chat(context.Background(), nil, prg, nil, `"foo":"123"}`) + resp, err = r.Chat(context.Background(), nil, prg, nil, `"foo":"123"}`, runner.RunOptions{}) r.AssertStep(t, resp, err) } @@ -110,7 +111,7 @@ echo '{"env": {"CRED2": "that also worked"}}' `, "") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "") + resp, err := r.Chat(context.Background(), nil, prg, nil, "", runner.RunOptions{}) r.AssertStep(t, resp, err) } @@ -144,7 +145,7 @@ echo "${GPTSCRIPT_INPUT}" `, "") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"baz", "start": true}`) + resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"baz", "start": true}`, runner.RunOptions{}) r.AssertStep(t, resp, err) data := map[string]any{} diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 18871ed6..ce3cebe6 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -12,6 +12,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/tests/tester" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/hexops/autogold/v2" @@ -143,7 +144,7 @@ func TestDualSubChat(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "User 1") + resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "User 1", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -157,7 +158,7 @@ func TestDualSubChat(t *testing.T) { }, }) - resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 2") + resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 2", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -168,7 +169,7 @@ func TestDualSubChat(t *testing.T) { Text: "Assistant 3", }) - resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 3") + resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 3", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -184,7 +185,7 @@ func TestDualSubChat(t *testing.T) { Text: "And we're done", }) - resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 4") + resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 4", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.True(t, resp.Done) @@ -213,7 +214,7 @@ func TestContextSubChat(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - _, err = r.Chat(context.Background(), nil, prg, os.Environ(), "User 1") + _, err = r.Chat(context.Background(), nil, prg, os.Environ(), "User 1", runner.RunOptions{}) autogold.Expect("invalid state: context tool [testdata/TestContextSubChat/test.gpt:subtool] can not result in a continuation").Equal(t, err.Error()) } @@ -232,7 +233,7 @@ func TestSubChat(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "Hello") + resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "Hello", runner.RunOptions{}) require.NoError(t, err) autogold.Expect(`{ @@ -357,7 +358,7 @@ func TestSubChat(t *testing.T) { } }`).Equal(t, toJSONString(t, resp)) - resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 1") + resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 1", runner.RunOptions{}) require.NoError(t, err) autogold.Expect(`{ @@ -512,7 +513,7 @@ func TestChat(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "Hello") + resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "Hello", runner.RunOptions{}) require.NoError(t, err) autogold.Expect(`{ @@ -564,7 +565,7 @@ func TestChat(t *testing.T) { } }`).Equal(t, toJSONString(t, resp)) - resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 1") + resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 1", runner.RunOptions{}) require.NoError(t, err) autogold.Expect(`{ @@ -740,7 +741,7 @@ func TestAgentOnly(t *testing.T) { }, }) - resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -767,7 +768,7 @@ func TestAgents(t *testing.T) { }, }) - resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -785,14 +786,14 @@ func TestInput(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "You're stupid") + resp, err := r.Chat(context.Background(), nil, prg, nil, "You're stupid", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) autogold.Expect("TEST RESULT CALL: 1").Equal(t, resp.Content) autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1")) - resp, err = r.Chat(context.Background(), resp.State, prg, nil, "You're ugly") + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "You're ugly", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -813,7 +814,7 @@ func TestOutput(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -824,7 +825,7 @@ func TestOutput(t *testing.T) { r.RespondWith(tester.Result{ Text: "Response 2", }) - resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 2") + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 2", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -837,7 +838,7 @@ func TestOutput(t *testing.T) { Message: "Chat Done", }, }) - resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 3") + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 3", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.True(t, resp.Done) @@ -885,7 +886,7 @@ func TestSysContext(t *testing.T) { prg, err := r.Load("") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -977,7 +978,7 @@ tools: sys.ls, sys.read, sys.write `, "") require.NoError(t, err) - resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1") + resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) @@ -991,7 +992,7 @@ tools: sys.ls, sys.write `, "") require.NoError(t, err) - resp, err = r.Chat(context.Background(), resp.State, prg, nil, "input 2") + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "input 2", runner.RunOptions{}) require.NoError(t, err) r.AssertResponded(t) assert.False(t, resp.Done) diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index 44ec4e3c..f59c0b14 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -159,7 +159,7 @@ func (r *Runner) Run(script, input string) (string, error) { return "", err } - return r.Runner.Run(context.Background(), prg, os.Environ(), input) + return r.Runner.Run(context.Background(), prg, os.Environ(), input, runner.RunOptions{}) } func (r *Runner) AssertResponded(t *testing.T) {