From bf43e687efe6bc0f166205bd19ea759bef06f3ff Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Thu, 21 Mar 2024 15:48:52 -0700 Subject: [PATCH] feat: add support for 3rd party model shims --- go.mod | 2 +- pkg/cache/cache.go | 4 + pkg/cli/gptscript.go | 111 +++++++------------------ pkg/engine/cmd.go | 8 +- pkg/engine/http.go | 6 ++ pkg/env/env.go | 4 + pkg/gptscript/gptscript.go | 108 +++++++++++++++++++++++++ pkg/llm/registry.go | 44 +++++----- pkg/loader/loader.go | 22 ++--- pkg/openai/client.go | 21 ++++- pkg/remote/remote.go | 129 ++++++++++++++++++++++++++++++ pkg/repos/runtimes/env.old/env.go | 42 ++++++++++ pkg/server/server.go | 32 +++++--- pkg/types/tool.go | 12 +++ 14 files changed, 415 insertions(+), 130 deletions(-) create mode 100644 pkg/gptscript/gptscript.go create mode 100644 pkg/remote/remote.go create mode 100644 pkg/repos/runtimes/env.old/env.go diff --git a/go.mod b/go.mod index 3441a39c..9ad35830 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 + golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc golang.org/x/sync v0.6.0 golang.org/x/term v0.16.0 ) @@ -66,7 +67,6 @@ require ( github.com/therootcompany/xz v1.0.1 // indirect github.com/ulikunitz/xz v0.5.10 // indirect go4.org v0.0.0-20200411211856-f5505b9728dd // indirect - golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc // indirect golang.org/x/mod v0.15.0 // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/sys v0.16.0 // indirect diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index af006060..f86a7f82 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -58,6 +58,10 @@ func New(opts ...Options) (*Client, error) { }, nil } +func (c *Client) CacheDir() string { + return c.dir +} + func (c *Client) Store(key string, content []byte) error { if c == nil || c.noop { return nil diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 
e2f548c1..f09f661c 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -1,7 +1,6 @@ package cli import ( - "context" "fmt" "io" "os" @@ -12,15 +11,12 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/confirm" - "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/input" - "github.com/gptscript-ai/gptscript/pkg/llm" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/monitor" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/openai" - "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" - "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/server" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" @@ -50,8 +46,6 @@ type GPTScript struct { Server bool `usage:"Start server"` ListenAddress string `usage:"Server listen address" default:"127.0.0.1:9090"` Chdir string `usage:"Change current working directory" short:"C"` - - _client llm.Client `usage:"-"` } func New() *cobra.Command { @@ -78,33 +72,6 @@ func (r *GPTScript) Customize(cmd *cobra.Command) { } } -func (r *GPTScript) getClient(ctx context.Context) (llm.Client, error) { - if r._client != nil { - return r._client, nil - } - - cacheClient, err := cache.New(cache.Options(r.CacheOptions)) - if err != nil { - return nil, err - } - - oaClient, err := openai.NewClient(openai.Options(r.OpenAIOptions), openai.Options{ - Cache: cacheClient, - }) - if err != nil { - return nil, err - } - - registry := llm.NewRegistry() - - if err := registry.AddClient(ctx, oaClient); err != nil { - return nil, err - } - - r._client = registry - return r._client, nil -} - func (r *GPTScript) listTools() error { var lines []string for _, tool := 
range builtin.ListTools() { @@ -114,24 +81,6 @@ func (r *GPTScript) listTools() error { return nil } -func (r *GPTScript) listModels(ctx context.Context) error { - c, err := r.getClient(ctx) - if err != nil { - return err - } - - models, err := c.ListModels(ctx) - if err != nil { - return err - } - - for _, model := range models { - fmt.Println(model) - } - - return nil -} - func (r *GPTScript) Pre(*cobra.Command, []string) error { // chdir as soon as possible if r.Chdir != "" { @@ -164,37 +113,50 @@ func (r *GPTScript) Pre(*cobra.Command, []string) error { } func (r *GPTScript) Run(cmd *cobra.Command, args []string) error { - defer engine.CloseDaemons() - - if r.ListModels { - return r.listModels(cmd.Context()) - } - - if r.ListTools { - return r.listTools() + gptOpt := gptscript.Options{ + Cache: cache.Options(r.CacheOptions), + OpenAI: openai.Options(r.OpenAIOptions), + Monitor: monitor.Options(r.DisplayOptions), + Quiet: r.Quiet, + Env: os.Environ(), } if r.Server { - c, err := r.getClient(cmd.Context()) - if err != nil { - return err - } - s, err := server.New(c, server.Options{ + s, err := server.New(&server.Options{ ListenAddress: r.ListenAddress, + GPTScript: gptOpt, }) if err != nil { return err } + defer s.Close() return s.Start(cmd.Context()) } + gptScript, err := gptscript.New(&gptOpt) + if err != nil { + return err + } + defer gptScript.Close() + + if r.ListModels { + models, err := gptScript.ListModels(cmd.Context()) + if err != nil { + return err + } + fmt.Println(strings.Join(models, "\n")) + } + + if r.ListTools { + return r.listTools() + } + if len(args) == 0 { return cmd.Help() } var ( prg types.Program - err error ) if args[0] == "-" { @@ -227,21 +189,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error { return assemble.Assemble(prg, out) } - client, err := r.getClient(cmd.Context()) - if err != nil { - return err - } - - runner, err := runner.New(client, runner.Options{ - MonitorFactory: 
monitor.NewConsole(monitor.Options(r.DisplayOptions), monitor.Options{ - DisplayProgress: !*r.Quiet, - }), - RuntimeManager: runtimes.Default(cache.Complete(cache.Options(r.CacheOptions)).CacheDir), - }) - if err != nil { - return err - } - toolInput, err := input.FromCLI(r.Input, args) if err != nil { return err @@ -251,7 +198,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error { if r.Confirm { ctx = confirm.WithConfirm(ctx, confirm.TextPrompt{}) } - s, err := runner.Run(ctx, prg, os.Environ(), toolInput) + s, err := gptScript.Run(ctx, prg, os.Environ(), toolInput) if err != nil { return err } diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 5dc099a7..c2dbfd8a 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -109,13 +109,13 @@ var ignoreENV = map[string]struct{}{ "GPTSCRIPT_TOOL_DIR": {}, } -func appendEnv(env []string, k, v string) []string { - for _, k := range []string{k, strings.ToUpper(strings.ReplaceAll(k, "-", "_"))} { +func appendEnv(envs []string, k, v string) []string { + for _, k := range []string{k, env.ToEnvLike(k)} { if _, ignore := ignoreENV[k]; !ignore { - env = append(env, k+"="+v) + envs = append(envs, k+"="+v) } } - return env + return envs } func appendInputAsEnv(env []string, input string) []string { diff --git a/pkg/engine/http.go b/pkg/engine/http.go index c4ef9153..8d53d283 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -55,6 +55,12 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too toolURL = parsed.String() } + if tool.Blocking { + return &Return{ + Result: &toolURL, + }, nil + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, toolURL, strings.NewReader(input)) if err != nil { return nil, err diff --git a/pkg/env/env.go b/pkg/env/env.go index f1b71c5b..fcad5836 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -12,6 +12,10 @@ func execEquals(bin, check string) bool { bin == check+".exe" } +func ToEnvLike(v string) string { + return 
strings.ToUpper(strings.ReplaceAll(v, "-", "_")) +} + func Matches(cmd []string, bin string) bool { switch len(cmd) { case 0: diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go new file mode 100644 index 00000000..6e593880 --- /dev/null +++ b/pkg/gptscript/gptscript.go @@ -0,0 +1,108 @@ +package gptscript + +import ( + "context" + "os" + + "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/llm" + "github.com/gptscript-ai/gptscript/pkg/monitor" + "github.com/gptscript-ai/gptscript/pkg/openai" + "github.com/gptscript-ai/gptscript/pkg/remote" + "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" + "github.com/gptscript-ai/gptscript/pkg/runner" + "github.com/gptscript-ai/gptscript/pkg/types" +) + +type GPTScript struct { + Registry *llm.Registry + Runner *runner.Runner +} + +type Options struct { + Cache cache.Options + OpenAI openai.Options + Monitor monitor.Options + Runner runner.Options + Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"` + Env []string `usage:"-"` +} + +func complete(opts *Options) (result *Options) { + result = opts + if result == nil { + result = &Options{} + } + if result.Quiet == nil { + result.Quiet = new(bool) + } + if len(result.Env) == 0 { + result.Env = os.Environ() + } + return +} + +func New(opts *Options) (*GPTScript, error) { + opts = complete(opts) + + registry := llm.NewRegistry() + + cacheClient, err := cache.New(opts.Cache) + if err != nil { + return nil, err + } + + oAIClient, err := openai.NewClient(append([]openai.Options{opts.OpenAI}, openai.Options{ + Cache: cacheClient, + })...) 
+ if err != nil { + return nil, err + } + + if err := registry.AddClient(oAIClient); err != nil { + return nil, err + } + + if opts.Runner.MonitorFactory == nil { + opts.Runner.MonitorFactory = monitor.NewConsole(append([]monitor.Options{opts.Monitor}, monitor.Options{ + DisplayProgress: !*opts.Quiet, + })...) + } + + if opts.Runner.RuntimeManager == nil { + opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir()) + } + + runner, err := runner.New(registry, opts.Runner) + if err != nil { + return nil, err + } + + remoteClient := remote.New(runner, opts.Env, cacheClient) + + if err := registry.AddClient(remoteClient); err != nil { + return nil, err + } + + return &GPTScript{ + Registry: registry, + Runner: runner, + }, nil +} + +func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) { + return g.Runner.Run(ctx, prg, envs, input) +} + +func (g *GPTScript) Close() { + engine.CloseDaemons() +} + +func (g *GPTScript) GetModel() engine.Model { + return g.Registry +} + +func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) { + return g.Registry.ListModels(ctx) +} diff --git a/pkg/llm/registry.go b/pkg/llm/registry.go index 20a3ae3b..166d2197 100644 --- a/pkg/llm/registry.go +++ b/pkg/llm/registry.go @@ -2,6 +2,7 @@ package llm import ( "context" + "errors" "fmt" "sort" @@ -11,32 +12,29 @@ import ( type Client interface { Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) ListModels(ctx context.Context) (result []string, _ error) + Supports(ctx context.Context, modelName string) (bool, error) } type Registry struct { - clientsByModel map[string]Client + clients []Client } func NewRegistry() *Registry { - return &Registry{ - clientsByModel: map[string]Client{}, - } + return &Registry{} } -func (r *Registry) AddClient(ctx context.Context, client Client) error { - models, err := client.ListModels(ctx) - if 
err != nil { - return err - } - for _, model := range models { - r.clientsByModel[model] = client - } +func (r *Registry) AddClient(client Client) error { + r.clients = append(r.clients, client) return nil } -func (r *Registry) ListModels(_ context.Context) (result []string, _ error) { - for k := range r.clientsByModel { - result = append(result, k) +func (r *Registry) ListModels(ctx context.Context) (result []string, _ error) { + for _, v := range r.clients { + models, err := v.ListModels(ctx) + if err != nil { + return nil, err + } + result = append(result, models...) } sort.Strings(result) return result, nil @@ -46,9 +44,17 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ if messageRequest.Model == "" { return nil, fmt.Errorf("model is required") } - client, ok := r.clientsByModel[messageRequest.Model] - if !ok { - return nil, fmt.Errorf("model not found: %s", messageRequest.Model) + var errs []error + for _, client := range r.clients { + ok, err := client.Supports(ctx, messageRequest.Model) + if err != nil { + errs = append(errs, err) + } else if ok { + return client.Call(ctx, messageRequest, status) + } + } + if len(errs) == 0 { + return nil, fmt.Errorf("failed to find a model provider for model [%s]", messageRequest.Model) } - return client.Call(ctx, messageRequest, status) + return nil, errors.Join(errs...) 
} diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index cfd45ef3..e851cac4 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -200,15 +200,7 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool tool.ToolMapping[targetToolName] = linkedTool.ID toolNames[targetToolName] = struct{}{} } else { - subTool, toolName, ok := strings.Cut(targetToolName, " from ") - if ok { - toolName = strings.TrimSpace(toolName) - subTool = strings.TrimSpace(subTool) - } else { - toolName = targetToolName - subTool = "" - } - + toolName, subTool := SplitToolRef(targetToolName) resolvedTool, err := resolve(ctx, prg, base, toolName, subTool) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err) @@ -292,3 +284,15 @@ func input(ctx context.Context, base *source, name string) (*source, error) { return nil, fmt.Errorf("can not load tools path=%s name=%s", base.Path, name) } + +func SplitToolRef(targetToolName string) (toolName, subTool string) { + subTool, toolName, ok := strings.Cut(targetToolName, " from ") + if ok { + toolName = strings.TrimSpace(toolName) + subTool = strings.TrimSpace(subTool) + } else { + toolName = targetToolName + subTool = "" + } + return +} diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 22f640b3..3aeace0c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -119,6 +119,14 @@ func NewClient(opts ...Options) (*Client, error) { }, nil } +func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { + models, err := c.ListModels(ctx) + if err != nil { + return false, err + } + return slices.Contains(models, modelName), nil +} + func (c *Client) ListModels(ctx context.Context) (result []string, _ error) { models, err := c.c.ListModels(ctx) if err != nil { @@ -335,6 +343,13 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques result = appendMessage(result, response) } + for i, content := range 
result.Content { + if content.ToolCall != nil && content.ToolCall.ID == "" { + content.ToolCall.ID = "call_" + hash.ID(content.ToolCall.Function.Name, content.ToolCall.Function.Arguments)[:8] + result.Content[i] = content + } + } + status <- types.CompletionStatus{ CompletionID: id, Chunks: response, @@ -354,10 +369,10 @@ func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionSt msg.Role = types.CompletionMessageRoleType(override(string(msg.Role), delta.Role)) for _, tool := range delta.ToolCalls { - if tool.Index == nil { - continue + idx := 0 + if tool.Index != nil { + idx = *tool.Index } - idx := *tool.Index for len(msg.Content)-1 < idx { msg.Content = append(msg.Content, types.ContentPart{ ToolCall: &types.CompletionToolCall{ diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go new file mode 100644 index 00000000..adccf5ad --- /dev/null +++ b/pkg/remote/remote.go @@ -0,0 +1,129 @@ +package remote + +import ( + "context" + "fmt" + "slices" + "sort" + "strings" + "sync" + + "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/openai" + "github.com/gptscript-ai/gptscript/pkg/runner" + "github.com/gptscript-ai/gptscript/pkg/types" + "golang.org/x/exp/maps" +) + +type Client struct { + clientsLock sync.Mutex + cache *cache.Client + clients map[string]*openai.Client + models map[string]*openai.Client + runner *runner.Runner + envs []string +} + +func New(r *runner.Runner, envs []string, cache *cache.Client) *Client { + return &Client{ + cache: cache, + runner: r, + envs: envs, + } +} + +func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { + c.clientsLock.Lock() + client, ok := c.models[messageRequest.Model] + c.clientsLock.Unlock() + + if !ok { + return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) + } + + _, 
modelName := loader.SplitToolRef(messageRequest.Model)
+	messageRequest.Model = modelName
+	return client.Call(ctx, messageRequest, status)
+}
+
+func (c *Client) ListModels(_ context.Context) (result []string, _ error) {
+	c.clientsLock.Lock()
+	defer c.clientsLock.Unlock()
+
+	keys := maps.Keys(c.models)
+	sort.Strings(keys)
+	return keys, nil
+}
+
+func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
+	toolName, modelNameSuffix := loader.SplitToolRef(modelName)
+	if modelNameSuffix == "" {
+		return false, nil
+	}
+
+	client, err := c.load(ctx, toolName)
+	if err != nil {
+		return false, err
+	}
+
+	models, err := client.ListModels(ctx)
+	if err != nil {
+		return false, err
+	}
+
+	if !slices.Contains(models, modelNameSuffix) {
+		return false, fmt.Errorf("failed to find model [%s], supported [%s]", modelNameSuffix, strings.Join(models, ", "))
+	}
+
+	c.clientsLock.Lock()
+	defer c.clientsLock.Unlock()
+
+	if c.models == nil {
+		c.models = map[string]*openai.Client{}
+	}
+
+	c.models[modelName] = client
+	return true, nil
+}
+
+func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
+	c.clientsLock.Lock()
+	defer c.clientsLock.Unlock()
+
+	client, ok := c.clients[toolName]
+	if ok {
+		return client, nil
+	}
+
+	prg, err := loader.Program(ctx, toolName, "")
+	if err != nil {
+		return nil, err
+	}
+
+	url, err := c.runner.Run(ctx, prg.SetBlocking(), c.envs, "")
+	if err != nil {
+		return nil, err
+	}
+
+	if strings.HasSuffix(url, "/") {
+		url += "v1"
+	} else {
+		url += "/v1"
+	}
+
+	client, err = openai.NewClient(openai.Options{
+		BaseURL: url,
+		Cache:   c.cache,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	if c.clients == nil {
+		c.clients = make(map[string]*openai.Client)
+	}
+
+	c.clients[toolName] = client
+	return client, nil
+}
diff --git a/pkg/repos/runtimes/env.old/env.go b/pkg/repos/runtimes/env.old/env.go
new file mode 100644
index 00000000..8c5c0ae3
--- /dev/null
+++ 
b/pkg/repos/runtimes/env.old/env.go @@ -0,0 +1,42 @@ +package env + +import ( + "fmt" + "os" + "strings" +) + +func execEquals(bin, check string) bool { + return bin == check || + bin == check+".exe" +} + +func Matches(cmd []string, bin string) bool { + switch len(cmd) { + case 0: + return false + case 1: + return execEquals(cmd[0], bin) + } + if cmd[0] == bin { + return true + } + if cmd[0] == "/usr/bin/env" || cmd[0] == "/bin/env" { + return execEquals(cmd[1], bin) + } + return false +} + +func AppendPath(env []string, binPath string) []string { + var newEnv []string + for _, path := range env { + for _, prefix := range []string{"PATH=", "Path="} { + v, ok := strings.CutPrefix(path, prefix) + if ok { + newEnv = append(newEnv, fmt.Sprintf(prefix+"%s%s%s", + binPath, string(os.PathListSeparator), v)) + } + } + } + return newEnv +} diff --git a/pkg/server/server.go b/pkg/server/server.go index d24a8d06..a8ba9619 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -17,7 +17,7 @@ import ( "github.com/acorn-io/broadcaster" "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" - "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/gptscript" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/system" @@ -29,25 +29,29 @@ import ( type Options struct { ListenAddress string + GPTScript gptscript.Options } -func complete(opts []Options) (result Options) { - for _, opt := range opts { - result.ListenAddress = types.FirstSet(opt.ListenAddress, result.ListenAddress) +func complete(opts *Options) (result *Options) { + result = opts + if result == nil { + result = &Options{} } + + result.ListenAddress = types.FirstSet(result.ListenAddress, result.ListenAddress) if result.ListenAddress == "" { result.ListenAddress = "127.0.0.1:9090" } + return } -func New(model engine.Model, opts 
...Options) (*Server, error) { +func New(opts *Options) (*Server, error) { events := broadcaster.New[Event]() + opts = complete(opts) + opts.GPTScript.Runner.MonitorFactory = NewSessionFactory(events) - opt := complete(opts) - r, err := runner.New(model, runner.Options{ - MonitorFactory: NewSessionFactory(events), - }) + g, err := gptscript.New(&opts.GPTScript) if err != nil { return nil, err } @@ -55,8 +59,8 @@ func New(model engine.Model, opts ...Options) (*Server, error) { return &Server{ melody: melody.New(), events: events, - runner: r, - listenAddress: opt.ListenAddress, + runner: g, + listenAddress: opts.ListenAddress, }, nil } @@ -72,7 +76,7 @@ type Event struct { type Server struct { ctx context.Context melody *melody.Melody - runner *runner.Runner + runner *gptscript.GPTScript events *broadcaster.Broadcaster[Event] listenAddress string } @@ -91,6 +95,10 @@ func IDFromContext(ctx context.Context) string { return ctx.Value(execKey{}).(string) } +func (s *Server) Close() { + s.runner.Close() +} + func (s *Server) list(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("Content-Type", "application/json") enc := json.NewEncoder(rw) diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 00d1ec2f..7b58d21f 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -5,6 +5,8 @@ import ( "fmt" "sort" "strings" + + "golang.org/x/exp/maps" ) const ( @@ -21,6 +23,15 @@ type Program struct { Exports map[string]string `json:"exports,omitempty"` } +func (p Program) SetBlocking() Program { + tool := p.ToolSet[p.EntryToolID] + tool.Blocking = true + tools := maps.Clone(p.ToolSet) + tools[p.EntryToolID] = tool + p.ToolSet = tools + return p +} + type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error) type Parameters struct { @@ -35,6 +46,7 @@ type Parameters struct { Arguments *JSONSchema `json:"arguments,omitempty"` Tools []string `json:"tools,omitempty"` Export []string `json:"export,omitempty"` + Blocking bool `json:"-"` } type 
Tool struct {