Skip to content

Add export parameter #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 70 additions & 32 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,23 @@ type Return struct {
}

type Call struct {
ToolName string `json:"toolName,omitempty"`
Input string `json:"input,omitempty"`
ToolID string `json:"toolID,omitempty"`
Input string `json:"input,omitempty"`
}

type CallResult struct {
ID string `json:"id,omitempty"`
ToolID string `json:"toolID,omitempty"`
CallID string `json:"callID,omitempty"`
Result string `json:"result,omitempty"`
}

type Context struct {
ID string
Ctx context.Context
Parent *Context
Program *types.Program
Tool types.Tool
ID string
Ctx context.Context
Parent *Context
Program *types.Program
Tool types.Tool
toolNames map[string]struct{}
}

func (c *Context) ParentID() string {
Expand Down Expand Up @@ -97,10 +99,10 @@ func NewContext(ctx context.Context, prg *types.Program) Context {
return callCtx
}

func (c *Context) SubCall(ctx context.Context, toolName, callID string) (Context, error) {
tool, err := c.getTool(toolName)
if err != nil {
return Context{}, err
func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context, error) {
tool, ok := c.Program.ToolSet[toolID]
if !ok {
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)
}
return Context{
ID: callID,
Expand All @@ -111,8 +113,8 @@ func (c *Context) SubCall(ctx context.Context, toolName, callID string) (Context
}, nil
}

func (c *Context) getTool(name string) (types.Tool, error) {
toolID, ok := c.Tool.ToolMapping[name]
func (c *Context) getTool(parent types.Tool, name string) (types.Tool, error) {
toolID, ok := parent.ToolMapping[name]
if !ok {
return types.Tool{}, &ErrToolNotFound{
ToolName: name,
Expand All @@ -127,6 +129,45 @@ func (c *Context) getTool(name string) (types.Tool, error) {
return tool, nil
}

func (c *Context) appendTool(completion *types.CompletionRequest, parentTool types.Tool, subToolName string) error {
subTool, err := c.getTool(parentTool, subToolName)
if err != nil {
return err
}

args := subTool.Parameters.Arguments
if args == nil && !subTool.IsCommand() {
args = &system.DefaultToolSchema
}

for _, existingTool := range completion.Tools {
if existingTool.Function.ToolID == subTool.ID {
return nil
}
}

if c.toolNames == nil {
c.toolNames = map[string]struct{}{}
}

completion.Tools = append(completion.Tools, types.CompletionTool{
Function: types.CompletionFunctionDefinition{
ToolID: subTool.ID,
Name: PickToolName(subToolName, c.toolNames),
Description: subTool.Parameters.Description,
Parameters: args,
},
})

for _, export := range subTool.Export {
if err := c.appendTool(completion, subTool, export); err != nil {
return err
}
}

return nil
}

func (e *Engine) Start(ctx Context, input string) (*Return, error) {
tool := ctx.Tool

Expand Down Expand Up @@ -155,21 +196,9 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
}

for _, subToolName := range tool.Parameters.Tools {
subTool, err := ctx.getTool(subToolName)
if err != nil {
if err := ctx.appendTool(&completion, ctx.Tool, subToolName); err != nil {
return nil, err
}
args := subTool.Parameters.Arguments
if args == nil && !subTool.IsCommand() {
args = &system.DefaultToolSchema
}
completion.Tools = append(completion.Tools, types.CompletionTool{
Function: types.CompletionFunctionDefinition{
Name: subToolName,
Description: subTool.Parameters.Description,
Parameters: args,
},
})
}

if tool.Instructions != "" {
Expand Down Expand Up @@ -225,10 +254,19 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
state.Pending = map[string]types.CompletionToolCall{}
for _, content := range resp.Content {
if content.ToolCall != nil {
var toolID string
for _, tool := range state.Completion.Tools {
if tool.Function.Name == content.ToolCall.Function.Name {
toolID = tool.Function.ToolID
}
}
if toolID == "" {
return nil, fmt.Errorf("failed to find tool id for tool %s in tool_call result", content.ToolCall.Function.Name)
}
state.Pending[content.ToolCall.ID] = *content.ToolCall
ret.Calls[content.ToolCall.ID] = Call{
ToolName: content.ToolCall.Function.Name,
Input: content.ToolCall.Function.Arguments,
ToolID: toolID,
Input: content.ToolCall.Function.Arguments,
}
} else {
cp := content.Text
Expand All @@ -247,7 +285,7 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
}

for _, result := range results {
state.Results[result.ID] = result
state.Results[result.CallID] = result
}

ret := Return{
Expand All @@ -262,8 +300,8 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
for id, pending := range state.Pending {
if _, ok := state.Results[id]; !ok {
ret.Calls[id] = Call{
ToolName: pending.Function.Name,
Input: pending.Function.Arguments,
ToolID: state.Completion.Tools[*pending.Index].Function.ToolID,
Input: pending.Function.Arguments,
}
}
}
Expand Down
53 changes: 53 additions & 0 deletions pkg/engine/toolname.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package engine

import (
"crypto/md5"
"encoding/hex"
"path/filepath"
"regexp"
"strings"

"github.com/gptscript-ai/gptscript/pkg/system"
)

var (
validToolName = regexp.MustCompile("^[a-zA-Z0-9_-]{1,64}$")
invalidChars = regexp.MustCompile("[^a-zA-Z0-9_-]+")
)

func ToolNormalizer(tool string) string {
parts := strings.Split(tool, "/")
tool = parts[len(parts)-1]
if strings.HasSuffix(tool, system.Suffix) {
tool = strings.TrimSuffix(tool, filepath.Ext(tool))
}

if validToolName.MatchString(tool) {
return tool
}

name := invalidChars.ReplaceAllString(tool, "-")
if len(name) > 55 {
name = name[:55]
}

hash := md5.Sum([]byte(tool))
hexed := hex.EncodeToString(hash[:])

return name + "-" + hexed[:8]
}

func PickToolName(toolName string, existing map[string]struct{}) string {
if toolName == "" {
toolName = "external"
}

for {
testName := ToolNormalizer(toolName)
if _, ok := existing[testName]; !ok {
existing[testName] = struct{}{}
return testName
}
toolName += "0"
}
}
108 changes: 26 additions & 82 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package loader
import (
"bytes"
"context"
"crypto/md5"
"crypto/sha256"
"encoding/hex"
"encoding/json"
Expand All @@ -13,14 +12,12 @@ import (
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/parser"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
)

Expand Down Expand Up @@ -182,48 +179,6 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN
return link(ctx, prg, base, mainTool, localTools)
}

var (
validToolName = regexp.MustCompile("^[a-zA-Z0-9_-]{1,64}$")
invalidChars = regexp.MustCompile("[^a-zA-Z0-9_-]+")
)

func ToolNormalizer(tool string) string {
parts := strings.Split(tool, "/")
tool = parts[len(parts)-1]
if strings.HasSuffix(tool, system.Suffix) {
tool = strings.TrimSuffix(tool, filepath.Ext(tool))
}

if validToolName.MatchString(tool) {
return tool
}

name := invalidChars.ReplaceAllString(tool, "-")
if len(name) > 55 {
name = name[:55]
}

hash := md5.Sum([]byte(tool))
hexed := hex.EncodeToString(hash[:])

return name + "-" + hexed[:8]
}

func pickToolName(toolName string, existing map[string]struct{}) string {
if toolName == "" {
toolName = "external"
}

for {
testName := ToolNormalizer(toolName)
if _, ok := existing[testName]; !ok {
existing[testName] = struct{}{}
return testName
}
toolName += "0"
}
}

func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) {
if existing, ok := prg.ToolSet[tool.ID]; ok {
return existing, nil
Expand All @@ -240,50 +195,39 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
// The below is done in two loops so that local names stay as the tool names
// and don't get mangled by external references

for _, targetToolName := range tool.Parameters.Tools {
for _, targetToolName := range append(tool.Parameters.Tools, tool.Parameters.Export...) {
localTool, ok := localTools[targetToolName]
if !ok {
continue
}
if ok {
var linkedTool types.Tool
if existing, ok := prg.ToolSet[localTool.ID]; ok {
linkedTool = existing
} else {
var err error
linkedTool, err = link(ctx, prg, base, localTool, localTools)
if err != nil {
return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err)
}
}

var linkedTool types.Tool
if existing, ok := prg.ToolSet[localTool.ID]; ok {
linkedTool = existing
tool.ToolMapping[targetToolName] = linkedTool.ID
toolNames[targetToolName] = struct{}{}
} else {
var err error
linkedTool, err = link(ctx, prg, base, localTool, localTools)
if err != nil {
return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err)
subTool, toolName, ok := strings.Cut(targetToolName, " from ")
if ok {
toolName = strings.TrimSpace(toolName)
subTool = strings.TrimSpace(subTool)
} else {
toolName = targetToolName
subTool = ""
}
}

tool.ToolMapping[targetToolName] = linkedTool.ID
toolNames[targetToolName] = struct{}{}
}

for i, targetToolName := range tool.Parameters.Tools {
_, ok := localTools[targetToolName]
if ok {
continue
}

subTool, toolName, ok := strings.Cut(targetToolName, " from ")
if ok {
toolName = strings.TrimSpace(toolName)
subTool = strings.TrimSpace(subTool)
} else {
toolName = targetToolName
subTool = ""
}
resolvedTool, err := resolve(ctx, prg, base, toolName, subTool)
if err != nil {
return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err)
}

resolvedTool, err := resolve(ctx, prg, base, toolName, subTool)
if err != nil {
return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err)
tool.ToolMapping[targetToolName] = resolvedTool.ID
}

newToolName := pickToolName(toolName, toolNames)
tool.ToolMapping[newToolName] = resolvedTool.ID
tool.Parameters.Tools[i] = newToolName
}

for _, localTool := range localTools {
Expand Down
Loading