diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index ebc83454..24ee9302 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -7,21 +7,12 @@ import ( "sync" "sync/atomic" - "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) var completionID int64 -type ErrToolNotFound struct { - ToolName string -} - -func (e *ErrToolNotFound) Error() string { - return fmt.Sprintf("tool not found: %s", e.ToolName) -} - type Model interface { Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) } @@ -62,12 +53,11 @@ type CallResult struct { } type Context struct { - ID string - Ctx context.Context - Parent *Context - Program *types.Program - Tool types.Tool - toolNames map[string]struct{} + ID string + Ctx context.Context + Parent *Context + Program *types.Program + Tool types.Tool } func (c *Context) ParentID() string { @@ -119,65 +109,6 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context, }, nil } -func (c *Context) getTool(parent types.Tool, name string) (types.Tool, error) { - toolID, ok := parent.ToolMapping[name] - if !ok { - return types.Tool{}, &ErrToolNotFound{ - ToolName: name, - } - } - tool, ok := c.Program.ToolSet[toolID] - if !ok { - return types.Tool{}, &ErrToolNotFound{ - ToolName: name, - } - } - return tool, nil -} - -func (c *Context) appendTool(completion *types.CompletionRequest, parentTool types.Tool, subToolName string) error { - subTool, err := c.getTool(parentTool, subToolName) - if err != nil { - return err - } - - args := subTool.Parameters.Arguments - if args == nil && !subTool.IsCommand() { - args = &system.DefaultToolSchema - } - - for _, existingTool := range completion.Tools { - if existingTool.Function.ToolID == subTool.ID { - return nil - } - } - - if c.toolNames == nil { - c.toolNames = map[string]struct{}{} - } - - if subTool.Instructions == "" { - log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID) - } else { - completion.Tools = append(completion.Tools, types.CompletionTool{ - Function: types.CompletionFunctionDefinition{ - ToolID: subTool.ID, - Name: PickToolName(subToolName, c.toolNames), - Description: subTool.Parameters.Description, - Parameters: args, - }, - }) - } - - for _, export := range subTool.Export { - if err := c.appendTool(completion, subTool, export); err != nil { - return err - } - } - - return nil -} - func (e *Engine) Start(ctx Context, input string) (*Return, error) { tool := ctx.Tool @@ -207,10 +138,10 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) { InternalSystemPrompt: tool.Parameters.InternalPrompt, } - for _, subToolName := range tool.Parameters.Tools { - if err := ctx.appendTool(&completion, ctx.Tool, subToolName); err != nil { - return nil, err - } + var err error + completion.Tools, err = tool.GetCompletionTools(*ctx.Program) + if err != nil { + return nil, err } if tool.Instructions != "" { diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 41248957..57d1866e 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -17,7 +17,6 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/gptscript-ai/gptscript/pkg/assemble" "github.com/gptscript-ai/gptscript/pkg/builtin" - "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/parser" "github.com/gptscript-ai/gptscript/pkg/types" "gopkg.in/yaml.v3" @@ -109,7 +108,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types tool, ok := into.ToolSet[tool.LocalTools[targetToolName]] if !ok { - return tool, &engine.ErrToolNotFound{ + return tool, &types.ErrToolNotFound{ ToolName: targetToolName, } } diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 6e911977..539d0c76 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -3,7 +3,6 @@ package tests import ( "testing" - "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/tests/tester" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/stretchr/testify/assert" @@ -15,7 +14,7 @@ func TestCwd(t *testing.T) { runner.RespondWith(tester.Result{ Func: types.CompletionFunctionCall{ - Name: engine.ToolNormalizer("./subtool/test.gpt"), + Name: types.ToolNormalizer("./subtool/test.gpt"), }, }) runner.RespondWith(tester.Result{ diff --git a/pkg/types/completion.go b/pkg/types/completion.go index e370e9cb..e4bb92a5 100644 --- a/pkg/types/completion.go +++ b/pkg/types/completion.go @@ -28,7 +28,6 @@ type CompletionFunctionDefinition struct { ToolID string `json:"toolID,omitempty"` Name string `json:"name"` Description string `json:"description,omitempty"` - Domain string `json:"domain,omitempty"` Parameters *openapi3.Schema `json:"parameters"` } diff --git a/pkg/types/log.go b/pkg/types/log.go new file mode 100644 index 00000000..ba3ff8c5 --- /dev/null +++ b/pkg/types/log.go @@ -0,0 +1,5 @@ +package types + +import "github.com/gptscript-ai/gptscript/pkg/mvl" + +var log = mvl.Package() diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 777e7953..ebd5812e 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/getkin/kin-openapi/openapi3" + "github.com/gptscript-ai/gptscript/pkg/system" "golang.org/x/exp/maps" ) @@ -16,6 +17,14 @@ const ( CommandPrefix = "#!" ) +type ErrToolNotFound struct { + ToolName string +} + +func (e *ErrToolNotFound) Error() string { + return fmt.Sprintf("tool not found: %s", e.ToolName) +} + type ToolSet map[string]Tool type Program struct { @@ -24,6 +33,17 @@ type Program struct { ToolSet ToolSet `json:"toolSet,omitempty"` } +func (p Program) GetCompletionTools() (result []CompletionTool, err error) { + return Tool{ + Parameters: Parameters{ + Tools: []string{"main"}, + }, + ToolMapping: map[string]string{ + "main": p.EntryToolID, + }, + }.GetCompletionTools(p) +} + func (p Program) TopLevelTools() (result []Tool) { for _, tool := range p.ToolSet[p.EntryToolID].LocalTools { result = append(result, p.ToolSet[tool]) @@ -124,6 +144,75 @@ func (t Tool) String() string { return buf.String() } +func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) { + toolNames := map[string]struct{}{} + + for _, subToolName := range t.Parameters.Tools { + result, err = appendTool(result, prg, t, subToolName, toolNames) + if err != nil { + return nil, err + } + } + + return result, nil +} + +func getTool(prg Program, parent Tool, name string) (Tool, error) { + toolID, ok := parent.ToolMapping[name] + if !ok { + return Tool{}, &ErrToolNotFound{ + ToolName: name, + } + } + tool, ok := prg.ToolSet[toolID] + if !ok { + return Tool{}, &ErrToolNotFound{ + ToolName: name, + } + } + return tool, nil +} + +func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) { + subTool, err := getTool(prg, parentTool, subToolName) + if err != nil { + return nil, err + } + + args := subTool.Parameters.Arguments + if args == nil && !subTool.IsCommand() { + args = &system.DefaultToolSchema + } + + for _, existingTool := range completionTools { + if existingTool.Function.ToolID == subTool.ID { + return completionTools, nil + } + } + + if subTool.Instructions == "" { + log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID) + } else { + completionTools = append(completionTools, CompletionTool{ + Function: CompletionFunctionDefinition{ + ToolID: subTool.ID, + Name: PickToolName(subToolName, toolNames), + Description: subTool.Parameters.Description, + Parameters: args, + }, + }) + } + + for _, export := range subTool.Export { + completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames) + if err != nil { + return nil, err + } + } + + return completionTools, nil +} + type Repo struct { // VCS The VCS type, such as "git" VCS string diff --git a/pkg/engine/toolname.go b/pkg/types/toolname.go similarity index 98% rename from pkg/engine/toolname.go rename to pkg/types/toolname.go index 69357c50..8f46dd19 100644 --- a/pkg/engine/toolname.go +++ b/pkg/types/toolname.go @@ -1,4 +1,4 @@ -package engine +package types import ( "path/filepath"