From 3a0689e40a4359a7613812fc8b0bd84d34cf57e2 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Sat, 25 May 2024 22:20:21 -0700 Subject: [PATCH] feat: add support for wildcard subtool names --- pkg/engine/engine.go | 11 ++- pkg/engine/http.go | 6 +- pkg/loader/loader.go | 82 +++++++++++------- pkg/loader/loader_test.go | 7 +- pkg/runner/runner.go | 8 +- pkg/tests/runner_test.go | 85 +++++++++++++++++++ pkg/tests/testdata/TestAsterick/call1.golden | 52 ++++++++++++ pkg/tests/testdata/TestAsterick/other.gpt | 14 ++++ pkg/tests/testdata/TestAsterick/test.gpt | 3 + pkg/types/tool.go | 87 +++++++++++++++----- pkg/types/toolname.go | 2 +- 11 files changed, 295 insertions(+), 62 deletions(-) create mode 100644 pkg/tests/testdata/TestAsterick/call1.golden create mode 100644 pkg/tests/testdata/TestAsterick/other.gpt create mode 100644 pkg/tests/testdata/TestAsterick/test.gpt diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 18d8a717..a6eb9bdb 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -109,10 +109,13 @@ func (c *Context) ParentID() string { func (c *Context) GetCallContext() *CallContext { var toolName string if c.Parent != nil { - for name, id := range c.Parent.Tool.ToolMapping { - if id == c.Tool.ID { - toolName = name - break + outer: + for name, refs := range c.Parent.Tool.ToolMapping { + for _, ref := range refs { + if ref.ToolID == c.Tool.ID { + toolName = name + break outer + } } } } diff --git a/pkg/engine/http.go b/pkg/engine/http.go index 8d53d283..94f741dc 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -35,11 +35,11 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) { referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix) - referencedToolID, ok := tool.ToolMapping[referencedToolName] - if !ok { + referencedToolRefs, ok := tool.ToolMapping[referencedToolName] + if !ok || len(referencedToolRefs) != 1 { return nil, fmt.Errorf("invalid reference [%s] to tool [%s] from [%s], missing \"tools: %s\" parameter", toolURL, referencedToolName, tool.Source, referencedToolName) } - referencedTool, ok := prg.ToolSet[referencedToolID] + referencedTool, ok := prg.ToolSet[referencedToolRefs[0].ToolID] if !ok { return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname()) } diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 8fe3a96f..ad80555d 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -181,11 +181,15 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { return openAPIDocument } -func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) (types.Tool, error) { +func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) ([]types.Tool, error) { data := base.Content if bytes.HasPrefix(data, assemble.Header) { - return loadProgram(data, prg, targetToolName) + tool, err := loadProgram(data, prg, targetToolName) + if err != nil { + return nil, err + } + return []types.Tool{tool}, nil } var ( @@ -200,7 +204,7 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base tools, err = getOpenAPITools(openAPIDocument, "") } if err != nil { - return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err) + return nil, fmt.Errorf("error parsing OpenAPI definition: %w", err) } } @@ -222,17 +226,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base AssignGlobals: true, }) if err != nil { - return types.Tool{}, err + return nil, err } } if len(tools) == 0 { - return types.Tool{}, fmt.Errorf("no tools found in %s", base) + return nil, fmt.Errorf("no tools found in %s", base) } var ( - localTools = types.ToolSet{} - mainTool types.Tool + localTools = types.ToolSet{} + targetTools []types.Tool ) for i, tool := range tools { @@ -243,28 +247,38 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base // Probably a better way to come up with an ID tool.ID = tool.Source.Location + ":" + tool.Name - if i == 0 { - mainTool = tool + if i == 0 && targetToolName == "" { + targetTools = append(targetTools, tool) } if i != 0 && tool.Parameters.Name == "" { - return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name")) + return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name")) } if i != 0 && tool.Parameters.GlobalModelName != "" { - return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name")) + return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name")) } if i != 0 && len(tool.Parameters.GlobalTools) > 0 { - return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools")) + return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools")) } - if targetToolName != "" && strings.EqualFold(tool.Parameters.Name, targetToolName) { - mainTool = tool + if targetToolName != "" && tool.Parameters.Name != "" { + if strings.EqualFold(tool.Parameters.Name, targetToolName) { + targetTools = append(targetTools, tool) + } else if strings.Contains(targetToolName, "*") { + match, err := filepath.Match(strings.ToLower(targetToolName), strings.ToLower(tool.Parameters.Name)) + if err != nil { + return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, err) + } + if match { + targetTools = append(targetTools, tool) + } + } } if existing, ok := localTools[strings.ToLower(tool.Parameters.Name)]; ok { - return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, + return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("duplicate tool name [%s] in %s found at lines %d and %d", tool.Parameters.Name, tool.Source.Location, tool.Source.LineNo, existing.Source.LineNo)) } @@ -272,7 +286,18 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base localTools[strings.ToLower(tool.Parameters.Name)] = tool } - return link(ctx, cache, prg, base, mainTool, localTools) + return linkAll(ctx, cache, prg, base, targetTools, localTools) +} + +func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet) (result []types.Tool, _ error) { + for _, tool := range tools { + tool, err := link(ctx, cache, prg, base, tool, localTools) + if err != nil { + return nil, err + } + result = append(result, tool) + } + return } func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) { @@ -280,7 +305,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so return existing, nil } - tool.ToolMapping = map[string]string{} + tool.ToolMapping = map[string][]types.ToolReference{} tool.LocalTools = map[string]string{} toolNames := map[string]struct{}{} @@ -310,16 +335,17 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so } } - tool.ToolMapping[targetToolName] = linkedTool.ID + tool.AddToolMapping(targetToolName, linkedTool) toolNames[targetToolName] = struct{}{} } else { toolName, subTool := types.SplitToolRef(targetToolName) - resolvedTool, err := resolve(ctx, cache, prg, base, toolName, subTool) + resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err) } - - tool.ToolMapping[targetToolName] = resolvedTool.ID + for _, resolvedTool := range resolvedTools { + tool.AddToolMapping(targetToolName, resolvedTool) + } } } @@ -345,14 +371,14 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. prg := types.Program{ ToolSet: types.ToolSet{}, } - tool, err := readTool(ctx, opt.Cache, &prg, &source{ + tools, err := readTool(ctx, opt.Cache, &prg, &source{ Content: []byte(content), Location: "inline", }, subToolName) if err != nil { return types.Program{}, err } - prg.EntryToolID = tool.ID + prg.EntryToolID = tools[0].ID return prg, nil } @@ -385,26 +411,26 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty Name: name, ToolSet: types.ToolSet{}, } - tool, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName) + tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName) if err != nil { return types.Program{}, err } - prg.EntryToolID = tool.ID + prg.EntryToolID = tools[0].ID return prg, nil } -func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) (types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) ([]types.Tool, error) { if subTool == "" { t, ok := builtin.Builtin(name) if ok { prg.ToolSet[t.ID] = t - return t, nil + return []types.Tool{t}, nil } } s, err := input(ctx, cache, base, name) if err != nil { - return types.Tool{}, err + return nil, err } return readTool(ctx, cache, prg, s, subTool) diff --git a/pkg/loader/loader_test.go b/pkg/loader/loader_test.go index b5ae6843..b6a0c4d2 100644 --- a/pkg/loader/loader_test.go +++ b/pkg/loader/loader_test.go @@ -109,7 +109,12 @@ func TestHelloWorld(t *testing.T) { "instructions": "call bob", "id": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/sub/tool.gpt:", "toolMapping": { - "../bob.gpt": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/bob.gpt:" + "../bob.gpt": [ + { + "reference": "../bob.gpt", + "toolID": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/bob.gpt:" + } + ] }, "localTools": { "": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/sub/tool.gpt:" diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 8a06525c..00ba524d 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -834,12 +834,12 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env // If the credential doesn't already exist in the store, run the credential tool in order to get the value, // and save it in the store. if !exists { - credToolID, ok := callCtx.Tool.ToolMapping[credToolName] - if !ok { + credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName] + if !ok || len(credToolRefs) != 1 { return nil, fmt.Errorf("failed to find ID for tool %s", credToolName) } - subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine + subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine if err != nil { return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err) } @@ -874,7 +874,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } // Only store the credential if the tool is on GitHub, and the credential is non-empty. - if isGitHubTool(credToolName) && callCtx.Program.ToolSet[credToolID].Source.Repo != nil { + if isGitHubTool(credToolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil { if isEmpty { log.Warnf("Not saving empty credential for tool %s", credToolName) } else if err := store.Add(*cred); err != nil { diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 7856deff..8cf5dfa5 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -20,6 +20,91 @@ func toJSONString(t *testing.T, v interface{}) string { return string(x) } +func TestAsterick(t *testing.T) { + r := tester.NewRunner(t) + p, err := r.Load("") + require.NoError(t, err) + autogold.Expect(`{ + "name": "testdata/TestAsterick/test.gpt", + "entryToolId": "testdata/TestAsterick/test.gpt:", + "toolSet": { + "testdata/TestAsterick/other.gpt:a": { + "name": "a", + "modelName": "gpt-4o", + "internalPrompt": null, + "instructions": "a", + "id": "testdata/TestAsterick/other.gpt:a", + "localTools": { + "a": "testdata/TestAsterick/other.gpt:a", + "afoo": "testdata/TestAsterick/other.gpt:afoo", + "foo": "testdata/TestAsterick/other.gpt:foo", + "fooa": "testdata/TestAsterick/other.gpt:fooa", + "fooafoo": "testdata/TestAsterick/other.gpt:fooafoo" + }, + "source": { + "location": "testdata/TestAsterick/other.gpt", + "lineNo": 10 + }, + "workingDir": "testdata/TestAsterick" + }, + "testdata/TestAsterick/other.gpt:afoo": { + "name": "afoo", + "modelName": "gpt-4o", + "internalPrompt": null, + "instructions": "afoo", + "id": "testdata/TestAsterick/other.gpt:afoo", + "localTools": { + "a": "testdata/TestAsterick/other.gpt:a", + "afoo": "testdata/TestAsterick/other.gpt:afoo", + "foo": "testdata/TestAsterick/other.gpt:foo", + "fooa": "testdata/TestAsterick/other.gpt:fooa", + "fooafoo": "testdata/TestAsterick/other.gpt:fooafoo" + }, + "source": { + "location": "testdata/TestAsterick/other.gpt", + "lineNo": 4 + }, + "workingDir": "testdata/TestAsterick" + }, + "testdata/TestAsterick/test.gpt:": { + "modelName": "gpt-4o", + "internalPrompt": null, + "tools": [ + "a* from ./other.gpt" + ], + "instructions": "Ask Bob how he is doing and let me know exactly what he said.", + "id": "testdata/TestAsterick/test.gpt:", + "toolMapping": { + "a* from ./other.gpt": [ + { + "reference": "afoo from ./other.gpt", + "toolID": "testdata/TestAsterick/other.gpt:afoo" + }, + { + "reference": "a from ./other.gpt", + "toolID": "testdata/TestAsterick/other.gpt:a" + } + ] + }, + "localTools": { + "": "testdata/TestAsterick/test.gpt:" + }, + "source": { + "location": "testdata/TestAsterick/test.gpt", + "lineNo": 1 + }, + "workingDir": "testdata/TestAsterick" + } + } +}`).Equal(t, toJSONString(t, p)) + + r.RespondWith(tester.Result{ + Text: "hi", + }) + _, err = r.Run("", "") + require.NoError(t, err) +} + func TestDualSubChat(t *testing.T) { r := tester.NewRunner(t) r.RespondWith(tester.Result{ diff --git a/pkg/tests/testdata/TestAsterick/call1.golden b/pkg/tests/testdata/TestAsterick/call1.golden new file mode 100644 index 00000000..c530c3ba --- /dev/null +++ b/pkg/tests/testdata/TestAsterick/call1.golden @@ -0,0 +1,52 @@ +`{ + "model": "gpt-4o", + "tools": [ + { + "function": { + "toolID": "testdata/TestAsterick/other.gpt:afoo", + "name": "afoo", + "parameters": { + "properties": { + "defaultPromptParameter": { + "description": "Prompt to send to the tool or assistant. This may be instructions or question.", + "type": "string" + } + }, + "required": [ + "defaultPromptParameter" + ], + "type": "object" + } + } + }, + { + "function": { + "toolID": "testdata/TestAsterick/other.gpt:a", + "name": "a", + "parameters": { + "properties": { + "defaultPromptParameter": { + "description": "Prompt to send to the tool or assistant. This may be instructions or question.", + "type": "string" + } + }, + "required": [ + "defaultPromptParameter" + ], + "type": "object" + } + } + } + ], + "messages": [ + { + "role": "system", + "content": [ + { + "text": "Ask Bob how he is doing and let me know exactly what he said." + } + ], + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestAsterick/other.gpt b/pkg/tests/testdata/TestAsterick/other.gpt new file mode 100644 index 00000000..af19c81f --- /dev/null +++ b/pkg/tests/testdata/TestAsterick/other.gpt @@ -0,0 +1,14 @@ +name: foo +foo +--- +name: afoo +afoo +--- +name: fooa +fooa +--- +name: a +a +--- +name: fooafoo +fooafoo \ No newline at end of file diff --git a/pkg/tests/testdata/TestAsterick/test.gpt b/pkg/tests/testdata/TestAsterick/test.gpt new file mode 100644 index 00000000..c3128850 --- /dev/null +++ b/pkg/tests/testdata/TestAsterick/test.gpt @@ -0,0 +1,3 @@ +tools: a* from ./other.gpt + +Ask Bob how he is doing and let me know exactly what he said. \ No newline at end of file diff --git a/pkg/types/tool.go b/pkg/types/tool.go index feb3ab09..98bcce64 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -57,10 +57,10 @@ func (p Program) ChatName() string { } type ToolReference struct { - Named string - Reference string - Arg string - ToolID string + Named string `json:"named,omitempty"` + Reference string `json:"reference,omitempty"` + Arg string `json:"arg,omitempty"` + ToolID string `json:"toolID,omitempty"` } func (p Program) GetContextToolRefs(toolID string) ([]ToolReference, error) { @@ -72,15 +72,22 @@ func (p Program) GetCompletionTools() (result []CompletionTool, err error) { Parameters: Parameters{ Tools: []string{"main"}, }, - ToolMapping: map[string]string{ - "main": p.EntryToolID, + ToolMapping: map[string][]ToolReference{ + "main": { + { + Reference: "main", + ToolID: p.EntryToolID, + }, + }, }, }.GetCompletionTools(p) } func (p Program) TopLevelTools() (result []Tool) { for _, tool := range p.ToolSet[p.EntryToolID].LocalTools { - result = append(result, p.ToolSet[tool]) + if target, ok := p.ToolSet[tool]; ok { + result = append(result, target) + } } return } @@ -122,12 +129,46 @@ type Tool struct { Parameters `json:",inline"` Instructions string `json:"instructions,omitempty"` - ID string `json:"id,omitempty"` - ToolMapping map[string]string `json:"toolMapping,omitempty"` - LocalTools map[string]string `json:"localTools,omitempty"` - BuiltinFunc BuiltinFunc `json:"-"` - Source ToolSource `json:"source,omitempty"` - WorkingDir string `json:"workingDir,omitempty"` + ID string `json:"id,omitempty"` + ToolMapping map[string][]ToolReference `json:"toolMapping,omitempty"` + LocalTools map[string]string `json:"localTools,omitempty"` + BuiltinFunc BuiltinFunc `json:"-"` + Source ToolSource `json:"source,omitempty"` + WorkingDir string `json:"workingDir,omitempty"` +} + +func IsMatch(subTool string) bool { + return strings.ContainsAny(subTool, "*?[") +} + +func (t *Tool) AddToolMapping(name string, tool Tool) { + if t.ToolMapping == nil { + t.ToolMapping = map[string][]ToolReference{} + } + + ref := name + _, subTool := SplitToolRef(name) + if IsMatch(subTool) && tool.Name != "" { + ref = strings.Replace(ref, subTool, tool.Name, 1) + } + + if existing, ok := t.ToolMapping[name]; ok { + var found bool + for _, toolRef := range existing { + if toolRef.ToolID == tool.ID && toolRef.Reference == ref { + found = true + break + } + } + if found { + return + } + } + + t.ToolMapping[name] = append(t.ToolMapping[name], ToolReference{ + Reference: ref, + ToolID: tool.ID, + }) } func SplitArg(hasArg string) (prefix, arg string) { @@ -151,21 +192,25 @@ func SplitArg(hasArg string) (prefix, arg string) { func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ error) { for _, toolName := range names { - toolID, ok := t.ToolMapping[toolName] - if !ok { + toolRefs, ok := t.ToolMapping[toolName] + if !ok || len(toolRefs) == 0 { return nil, NewErrToolNotFound(toolName) } _, arg := SplitArg(toolName) named, ok := strings.CutPrefix(arg, "as ") if !ok { named = "" + } else if len(toolRefs) > 1 { + return nil, fmt.Errorf("can not combine 'as' syntax with wildcard: %s", toolName) + } + for _, toolRef := range toolRefs { + result = append(result, ToolReference{ + Named: named, + Arg: arg, + Reference: toolRef.Reference, + ToolID: toolRef.ToolID, + }) } - result = append(result, ToolReference{ - Named: named, - Arg: arg, - Reference: toolName, - ToolID: toolID, - }) } return } diff --git a/pkg/types/toolname.go b/pkg/types/toolname.go index c3a0efc8..4d794c37 100644 --- a/pkg/types/toolname.go +++ b/pkg/types/toolname.go @@ -86,8 +86,8 @@ func PickToolName(toolName string, existing map[string]struct{}) string { toolName = "external" } + testName := ToolNormalizer(toolName) for { - testName := ToolNormalizer(toolName) if _, ok := existing[testName]; !ok { existing[testName] = struct{}{} return testName