Skip to content

chore: add program.GetCompletionTools() #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 9 additions & 78 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,12 @@ import (
"sync"
"sync/atomic"

"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/gptscript-ai/gptscript/pkg/version"
)

var completionID int64

type ErrToolNotFound struct {
ToolName string
}

func (e *ErrToolNotFound) Error() string {
return fmt.Sprintf("tool not found: %s", e.ToolName)
}

type Model interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
}
Expand Down Expand Up @@ -62,12 +53,11 @@ type CallResult struct {
}

type Context struct {
ID string
Ctx context.Context
Parent *Context
Program *types.Program
Tool types.Tool
toolNames map[string]struct{}
ID string
Ctx context.Context
Parent *Context
Program *types.Program
Tool types.Tool
}

func (c *Context) ParentID() string {
Expand Down Expand Up @@ -119,65 +109,6 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context,
}, nil
}

func (c *Context) getTool(parent types.Tool, name string) (types.Tool, error) {
toolID, ok := parent.ToolMapping[name]
if !ok {
return types.Tool{}, &ErrToolNotFound{
ToolName: name,
}
}
tool, ok := c.Program.ToolSet[toolID]
if !ok {
return types.Tool{}, &ErrToolNotFound{
ToolName: name,
}
}
return tool, nil
}

func (c *Context) appendTool(completion *types.CompletionRequest, parentTool types.Tool, subToolName string) error {
subTool, err := c.getTool(parentTool, subToolName)
if err != nil {
return err
}

args := subTool.Parameters.Arguments
if args == nil && !subTool.IsCommand() {
args = &system.DefaultToolSchema
}

for _, existingTool := range completion.Tools {
if existingTool.Function.ToolID == subTool.ID {
return nil
}
}

if c.toolNames == nil {
c.toolNames = map[string]struct{}{}
}

if subTool.Instructions == "" {
log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID)
} else {
completion.Tools = append(completion.Tools, types.CompletionTool{
Function: types.CompletionFunctionDefinition{
ToolID: subTool.ID,
Name: PickToolName(subToolName, c.toolNames),
Description: subTool.Parameters.Description,
Parameters: args,
},
})
}

for _, export := range subTool.Export {
if err := c.appendTool(completion, subTool, export); err != nil {
return err
}
}

return nil
}

func (e *Engine) Start(ctx Context, input string) (*Return, error) {
tool := ctx.Tool

Expand Down Expand Up @@ -207,10 +138,10 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
InternalSystemPrompt: tool.Parameters.InternalPrompt,
}

for _, subToolName := range tool.Parameters.Tools {
if err := ctx.appendTool(&completion, ctx.Tool, subToolName); err != nil {
return nil, err
}
var err error
completion.Tools, err = tool.GetCompletionTools(*ctx.Program)
if err != nil {
return nil, err
}

if tool.Instructions != "" {
Expand Down
3 changes: 1 addition & 2 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/getkin/kin-openapi/openapi3"
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/parser"
"github.com/gptscript-ai/gptscript/pkg/types"
"gopkg.in/yaml.v3"
Expand Down Expand Up @@ -109,7 +108,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types

tool, ok := into.ToolSet[tool.LocalTools[targetToolName]]
if !ok {
return tool, &engine.ErrToolNotFound{
return tool, &types.ErrToolNotFound{
ToolName: targetToolName,
}
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package tests
import (
"testing"

"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/tests/tester"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/stretchr/testify/assert"
Expand All @@ -15,7 +14,7 @@ func TestCwd(t *testing.T) {

runner.RespondWith(tester.Result{
Func: types.CompletionFunctionCall{
Name: engine.ToolNormalizer("./subtool/test.gpt"),
Name: types.ToolNormalizer("./subtool/test.gpt"),
},
})
runner.RespondWith(tester.Result{
Expand Down
1 change: 0 additions & 1 deletion pkg/types/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ type CompletionFunctionDefinition struct {
ToolID string `json:"toolID,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Domain string `json:"domain,omitempty"`
Parameters *openapi3.Schema `json:"parameters"`
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/types/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package types

import "github.com/gptscript-ai/gptscript/pkg/mvl"

var log = mvl.Package()
89 changes: 89 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/getkin/kin-openapi/openapi3"
"github.com/gptscript-ai/gptscript/pkg/system"
"golang.org/x/exp/maps"
)

Expand All @@ -16,6 +17,14 @@ const (
CommandPrefix = "#!"
)

type ErrToolNotFound struct {
ToolName string
}

func (e *ErrToolNotFound) Error() string {
return fmt.Sprintf("tool not found: %s", e.ToolName)
}

type ToolSet map[string]Tool

type Program struct {
Expand All @@ -24,6 +33,17 @@ type Program struct {
ToolSet ToolSet `json:"toolSet,omitempty"`
}

func (p Program) GetCompletionTools() (result []CompletionTool, err error) {
return Tool{
Parameters: Parameters{
Tools: []string{"main"},
},
ToolMapping: map[string]string{
"main": p.EntryToolID,
},
}.GetCompletionTools(p)
}

func (p Program) TopLevelTools() (result []Tool) {
for _, tool := range p.ToolSet[p.EntryToolID].LocalTools {
result = append(result, p.ToolSet[tool])
Expand Down Expand Up @@ -124,6 +144,75 @@ func (t Tool) String() string {
return buf.String()
}

func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) {
toolNames := map[string]struct{}{}

for _, subToolName := range t.Parameters.Tools {
result, err = appendTool(result, prg, t, subToolName, toolNames)
if err != nil {
return nil, err
}
}

return result, nil
}

func getTool(prg Program, parent Tool, name string) (Tool, error) {
toolID, ok := parent.ToolMapping[name]
if !ok {
return Tool{}, &ErrToolNotFound{
ToolName: name,
}
}
tool, ok := prg.ToolSet[toolID]
if !ok {
return Tool{}, &ErrToolNotFound{
ToolName: name,
}
}
return tool, nil
}

func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) {
subTool, err := getTool(prg, parentTool, subToolName)
if err != nil {
return nil, err
}

args := subTool.Parameters.Arguments
if args == nil && !subTool.IsCommand() {
args = &system.DefaultToolSchema
}

for _, existingTool := range completionTools {
if existingTool.Function.ToolID == subTool.ID {
return completionTools, nil
}
}

if subTool.Instructions == "" {
log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID)
} else {
completionTools = append(completionTools, CompletionTool{
Function: CompletionFunctionDefinition{
ToolID: subTool.ID,
Name: PickToolName(subToolName, toolNames),
Description: subTool.Parameters.Description,
Parameters: args,
},
})
}

for _, export := range subTool.Export {
completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames)
if err != nil {
return nil, err
}
}

return completionTools, nil
}

type Repo struct {
// VCS The VCS type, such as "git"
VCS string
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/toolname.go → pkg/types/toolname.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package engine
package types

import (
"path/filepath"
Expand Down