From 90bd245374235022f8e2623f3b0f6a5933b73be6 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Sat, 18 May 2024 20:32:27 -0700 Subject: [PATCH] chore: assign provider tool category when launch provider --- pkg/engine/engine.go | 14 ++++++++++++-- pkg/remote/remote.go | 3 ++- pkg/runner/runner.go | 2 +- pkg/types/completion.go | 2 +- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index a88e8415..da662a10 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -77,6 +77,7 @@ type Context struct { type ToolCategory string const ( + ProviderToolCategory ToolCategory = "provider" CredentialToolCategory ToolCategory = "credential" ContextToolCategory ToolCategory = "context" NoCategory ToolCategory = "" @@ -120,11 +121,20 @@ func (c *Context) MarshalJSON() ([]byte, error) { return json.Marshal(c.GetCallContext()) } +type toolCategoryKey struct{} + +func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Context { + return context.WithValue(ctx, toolCategoryKey{}, toolCategory) +} + func NewContext(ctx context.Context, prg *types.Program) Context { + category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory) + callCtx := Context{ commonContext: commonContext{ - ID: counter.Next(), - Tool: prg.ToolSet[prg.EntryToolID], + ID: counter.Next(), + Tool: prg.ToolSet[prg.EntryToolID], + ToolCategory: category, }, Ctx: ctx, Program: prg, diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 68072c71..e8fc93dd 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/engine" env2 "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/mvl" @@ -144,7 +145,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - url, err := c.runner.Run(ctx, prg.SetBlocking(), c.envs, "") + url, err := c.runner.Run(engine.WithToolCategory(ctx, engine.ProviderToolCategory), prg.SetBlocking(), c.envs, "") if err != nil { return nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 234f308c..d4a5516b 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -527,7 +527,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s err error ) - state, callResults, err = r.subCalls(callCtx, monitor, env, state, engine.NoCategory) + state, callResults, err = r.subCalls(callCtx, monitor, env, state, callCtx.ToolCategory) if errMessage := (*builtin.ErrChatFinish)(nil); errors.As(err, &errMessage) && callCtx.Tool.Chat { return &State{ Result: &errMessage.Message, diff --git a/pkg/types/completion.go b/pkg/types/completion.go index 2a1ca268..14c7e987 100644 --- a/pkg/types/completion.go +++ b/pkg/types/completion.go @@ -98,7 +98,7 @@ func (in CompletionMessage) String() string { } buf.WriteString(content.Text) if content.ToolCall != nil { - buf.WriteString(fmt.Sprintf("tool call %s -> %s", color.GreenString(content.ToolCall.Function.Name), content.ToolCall.Function.Arguments)) + buf.WriteString(fmt.Sprintf(" %s -> %s", color.GreenString(content.ToolCall.Function.Name), content.ToolCall.Function.Arguments)) } } return buf.String()