Skip to content

Commit b480cc0

Browse files
Merge pull request #113 from ibuildthecloud/main
Add export parameter
2 parents a1410a3 + 86f8f5c commit b480cc0

File tree

19 files changed

+449
-128
lines changed

19 files changed

+449
-128
lines changed

pkg/engine/engine.go

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,23 @@ type Return struct {
4545
}
4646

4747
type Call struct {
48-
ToolName string `json:"toolName,omitempty"`
49-
Input string `json:"input,omitempty"`
48+
ToolID string `json:"toolID,omitempty"`
49+
Input string `json:"input,omitempty"`
5050
}
5151

5252
type CallResult struct {
53-
ID string `json:"id,omitempty"`
53+
ToolID string `json:"toolID,omitempty"`
54+
CallID string `json:"callID,omitempty"`
5455
Result string `json:"result,omitempty"`
5556
}
5657

5758
type Context struct {
58-
ID string
59-
Ctx context.Context
60-
Parent *Context
61-
Program *types.Program
62-
Tool types.Tool
59+
ID string
60+
Ctx context.Context
61+
Parent *Context
62+
Program *types.Program
63+
Tool types.Tool
64+
toolNames map[string]struct{}
6365
}
6466

6567
func (c *Context) ParentID() string {
@@ -97,10 +99,10 @@ func NewContext(ctx context.Context, prg *types.Program) Context {
9799
return callCtx
98100
}
99101

100-
func (c *Context) SubCall(ctx context.Context, toolName, callID string) (Context, error) {
101-
tool, err := c.getTool(toolName)
102-
if err != nil {
103-
return Context{}, err
102+
func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context, error) {
103+
tool, ok := c.Program.ToolSet[toolID]
104+
if !ok {
105+
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)
104106
}
105107
return Context{
106108
ID: callID,
@@ -111,8 +113,8 @@ func (c *Context) SubCall(ctx context.Context, toolName, callID string) (Context
111113
}, nil
112114
}
113115

114-
func (c *Context) getTool(name string) (types.Tool, error) {
115-
toolID, ok := c.Tool.ToolMapping[name]
116+
func (c *Context) getTool(parent types.Tool, name string) (types.Tool, error) {
117+
toolID, ok := parent.ToolMapping[name]
116118
if !ok {
117119
return types.Tool{}, &ErrToolNotFound{
118120
ToolName: name,
@@ -127,6 +129,45 @@ func (c *Context) getTool(name string) (types.Tool, error) {
127129
return tool, nil
128130
}
129131

132+
func (c *Context) appendTool(completion *types.CompletionRequest, parentTool types.Tool, subToolName string) error {
133+
subTool, err := c.getTool(parentTool, subToolName)
134+
if err != nil {
135+
return err
136+
}
137+
138+
args := subTool.Parameters.Arguments
139+
if args == nil && !subTool.IsCommand() {
140+
args = &system.DefaultToolSchema
141+
}
142+
143+
for _, existingTool := range completion.Tools {
144+
if existingTool.Function.ToolID == subTool.ID {
145+
return nil
146+
}
147+
}
148+
149+
if c.toolNames == nil {
150+
c.toolNames = map[string]struct{}{}
151+
}
152+
153+
completion.Tools = append(completion.Tools, types.CompletionTool{
154+
Function: types.CompletionFunctionDefinition{
155+
ToolID: subTool.ID,
156+
Name: PickToolName(subToolName, c.toolNames),
157+
Description: subTool.Parameters.Description,
158+
Parameters: args,
159+
},
160+
})
161+
162+
for _, export := range subTool.Export {
163+
if err := c.appendTool(completion, subTool, export); err != nil {
164+
return err
165+
}
166+
}
167+
168+
return nil
169+
}
170+
130171
func (e *Engine) Start(ctx Context, input string) (*Return, error) {
131172
tool := ctx.Tool
132173

@@ -155,21 +196,9 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
155196
}
156197

157198
for _, subToolName := range tool.Parameters.Tools {
158-
subTool, err := ctx.getTool(subToolName)
159-
if err != nil {
199+
if err := ctx.appendTool(&completion, ctx.Tool, subToolName); err != nil {
160200
return nil, err
161201
}
162-
args := subTool.Parameters.Arguments
163-
if args == nil && !subTool.IsCommand() {
164-
args = &system.DefaultToolSchema
165-
}
166-
completion.Tools = append(completion.Tools, types.CompletionTool{
167-
Function: types.CompletionFunctionDefinition{
168-
Name: subToolName,
169-
Description: subTool.Parameters.Description,
170-
Parameters: args,
171-
},
172-
})
173202
}
174203

175204
if tool.Instructions != "" {
@@ -225,10 +254,19 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
225254
state.Pending = map[string]types.CompletionToolCall{}
226255
for _, content := range resp.Content {
227256
if content.ToolCall != nil {
257+
var toolID string
258+
for _, tool := range state.Completion.Tools {
259+
if tool.Function.Name == content.ToolCall.Function.Name {
260+
toolID = tool.Function.ToolID
261+
}
262+
}
263+
if toolID == "" {
264+
return nil, fmt.Errorf("failed to find tool id for tool %s in tool_call result", content.ToolCall.Function.Name)
265+
}
228266
state.Pending[content.ToolCall.ID] = *content.ToolCall
229267
ret.Calls[content.ToolCall.ID] = Call{
230-
ToolName: content.ToolCall.Function.Name,
231-
Input: content.ToolCall.Function.Arguments,
268+
ToolID: toolID,
269+
Input: content.ToolCall.Function.Arguments,
232270
}
233271
} else {
234272
cp := content.Text
@@ -247,7 +285,7 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
247285
}
248286

249287
for _, result := range results {
250-
state.Results[result.ID] = result
288+
state.Results[result.CallID] = result
251289
}
252290

253291
ret := Return{
@@ -262,8 +300,8 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
262300
for id, pending := range state.Pending {
263301
if _, ok := state.Results[id]; !ok {
264302
ret.Calls[id] = Call{
265-
ToolName: pending.Function.Name,
266-
Input: pending.Function.Arguments,
303+
ToolID: state.Completion.Tools[*pending.Index].Function.ToolID,
304+
Input: pending.Function.Arguments,
267305
}
268306
}
269307
}

pkg/engine/toolname.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package engine
2+
3+
import (
4+
"crypto/md5"
5+
"encoding/hex"
6+
"path/filepath"
7+
"regexp"
8+
"strings"
9+
10+
"github.com/gptscript-ai/gptscript/pkg/system"
11+
)
12+
13+
var (
14+
validToolName = regexp.MustCompile("^[a-zA-Z0-9_-]{1,64}$")
15+
invalidChars = regexp.MustCompile("[^a-zA-Z0-9_-]+")
16+
)
17+
18+
func ToolNormalizer(tool string) string {
19+
parts := strings.Split(tool, "/")
20+
tool = parts[len(parts)-1]
21+
if strings.HasSuffix(tool, system.Suffix) {
22+
tool = strings.TrimSuffix(tool, filepath.Ext(tool))
23+
}
24+
25+
if validToolName.MatchString(tool) {
26+
return tool
27+
}
28+
29+
name := invalidChars.ReplaceAllString(tool, "-")
30+
if len(name) > 55 {
31+
name = name[:55]
32+
}
33+
34+
hash := md5.Sum([]byte(tool))
35+
hexed := hex.EncodeToString(hash[:])
36+
37+
return name + "-" + hexed[:8]
38+
}
39+
40+
func PickToolName(toolName string, existing map[string]struct{}) string {
41+
if toolName == "" {
42+
toolName = "external"
43+
}
44+
45+
for {
46+
testName := ToolNormalizer(toolName)
47+
if _, ok := existing[testName]; !ok {
48+
existing[testName] = struct{}{}
49+
return testName
50+
}
51+
toolName += "0"
52+
}
53+
}

pkg/loader/loader.go

Lines changed: 26 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package loader
33
import (
44
"bytes"
55
"context"
6-
"crypto/md5"
76
"crypto/sha256"
87
"encoding/hex"
98
"encoding/json"
@@ -13,14 +12,12 @@ import (
1312
"io/fs"
1413
"os"
1514
"path/filepath"
16-
"regexp"
1715
"strings"
1816

1917
"github.com/gptscript-ai/gptscript/pkg/assemble"
2018
"github.com/gptscript-ai/gptscript/pkg/builtin"
2119
"github.com/gptscript-ai/gptscript/pkg/engine"
2220
"github.com/gptscript-ai/gptscript/pkg/parser"
23-
"github.com/gptscript-ai/gptscript/pkg/system"
2421
"github.com/gptscript-ai/gptscript/pkg/types"
2522
)
2623

@@ -182,48 +179,6 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN
182179
return link(ctx, prg, base, mainTool, localTools)
183180
}
184181

185-
var (
186-
validToolName = regexp.MustCompile("^[a-zA-Z0-9_-]{1,64}$")
187-
invalidChars = regexp.MustCompile("[^a-zA-Z0-9_-]+")
188-
)
189-
190-
func ToolNormalizer(tool string) string {
191-
parts := strings.Split(tool, "/")
192-
tool = parts[len(parts)-1]
193-
if strings.HasSuffix(tool, system.Suffix) {
194-
tool = strings.TrimSuffix(tool, filepath.Ext(tool))
195-
}
196-
197-
if validToolName.MatchString(tool) {
198-
return tool
199-
}
200-
201-
name := invalidChars.ReplaceAllString(tool, "-")
202-
if len(name) > 55 {
203-
name = name[:55]
204-
}
205-
206-
hash := md5.Sum([]byte(tool))
207-
hexed := hex.EncodeToString(hash[:])
208-
209-
return name + "-" + hexed[:8]
210-
}
211-
212-
func pickToolName(toolName string, existing map[string]struct{}) string {
213-
if toolName == "" {
214-
toolName = "external"
215-
}
216-
217-
for {
218-
testName := ToolNormalizer(toolName)
219-
if _, ok := existing[testName]; !ok {
220-
existing[testName] = struct{}{}
221-
return testName
222-
}
223-
toolName += "0"
224-
}
225-
}
226-
227182
func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) {
228183
if existing, ok := prg.ToolSet[tool.ID]; ok {
229184
return existing, nil
@@ -240,50 +195,39 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
240195
// The below is done in two loops so that local names stay as the tool names
241196
// and don't get mangled by external references
242197

243-
for _, targetToolName := range tool.Parameters.Tools {
198+
for _, targetToolName := range append(tool.Parameters.Tools, tool.Parameters.Export...) {
244199
localTool, ok := localTools[targetToolName]
245-
if !ok {
246-
continue
247-
}
200+
if ok {
201+
var linkedTool types.Tool
202+
if existing, ok := prg.ToolSet[localTool.ID]; ok {
203+
linkedTool = existing
204+
} else {
205+
var err error
206+
linkedTool, err = link(ctx, prg, base, localTool, localTools)
207+
if err != nil {
208+
return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err)
209+
}
210+
}
248211

249-
var linkedTool types.Tool
250-
if existing, ok := prg.ToolSet[localTool.ID]; ok {
251-
linkedTool = existing
212+
tool.ToolMapping[targetToolName] = linkedTool.ID
213+
toolNames[targetToolName] = struct{}{}
252214
} else {
253-
var err error
254-
linkedTool, err = link(ctx, prg, base, localTool, localTools)
255-
if err != nil {
256-
return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err)
215+
subTool, toolName, ok := strings.Cut(targetToolName, " from ")
216+
if ok {
217+
toolName = strings.TrimSpace(toolName)
218+
subTool = strings.TrimSpace(subTool)
219+
} else {
220+
toolName = targetToolName
221+
subTool = ""
257222
}
258-
}
259-
260-
tool.ToolMapping[targetToolName] = linkedTool.ID
261-
toolNames[targetToolName] = struct{}{}
262-
}
263-
264-
for i, targetToolName := range tool.Parameters.Tools {
265-
_, ok := localTools[targetToolName]
266-
if ok {
267-
continue
268-
}
269223

270-
subTool, toolName, ok := strings.Cut(targetToolName, " from ")
271-
if ok {
272-
toolName = strings.TrimSpace(toolName)
273-
subTool = strings.TrimSpace(subTool)
274-
} else {
275-
toolName = targetToolName
276-
subTool = ""
277-
}
224+
resolvedTool, err := resolve(ctx, prg, base, toolName, subTool)
225+
if err != nil {
226+
return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err)
227+
}
278228

279-
resolvedTool, err := resolve(ctx, prg, base, toolName, subTool)
280-
if err != nil {
281-
return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err)
229+
tool.ToolMapping[targetToolName] = resolvedTool.ID
282230
}
283-
284-
newToolName := pickToolName(toolName, toolNames)
285-
tool.ToolMapping[newToolName] = resolvedTool.ID
286-
tool.Parameters.Tools[i] = newToolName
287231
}
288232

289233
for _, localTool := range localTools {

0 commit comments

Comments
 (0)