diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 9941aad0..5f05d6ec 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -678,6 +678,13 @@ func TestContextArg(t *testing.T) { assert.Equal(t, "TEST RESULT CALL: 1", x) } +func TestToolAs(t *testing.T) { + runner := tester.NewRunner(t) + x, err := runner.Run("", `{}`) + require.NoError(t, err) + assert.Equal(t, "TEST RESULT CALL: 1", x) +} + func TestCwd(t *testing.T) { runner := tester.NewRunner(t) diff --git a/pkg/tests/testdata/TestToolAs/call1.golden b/pkg/tests/testdata/TestToolAs/call1.golden new file mode 100644 index 00000000..390abe3b --- /dev/null +++ b/pkg/tests/testdata/TestToolAs/call1.golden @@ -0,0 +1,67 @@ +`{ + "Model": "gpt-4-turbo", + "InternalSystemPrompt": null, + "Tools": [ + { + "function": { + "toolID": "testdata/TestToolAs/test.gpt:6", + "name": "local", + "parameters": { + "properties": { + "defaultPromptParameter": { + "description": "Prompt to send to the tool or assistant. This may be instructions or question.", + "type": "string" + } + }, + "required": [ + "defaultPromptParameter" + ], + "type": "object" + } + } + }, + { + "function": { + "toolID": "testdata/TestToolAs/other.gpt:1", + "name": "remote", + "parameters": { + "properties": { + "defaultPromptParameter": { + "description": "Prompt to send to the tool or assistant. This may be instructions or question.", + "type": "string" + } + }, + "required": [ + "defaultPromptParameter" + ], + "type": "object" + } + } + } + ], + "Messages": [ + { + "role": "system", + "content": [ + { + "text": "A tool" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "{}" + } + ], + "usage": {} + } + ], + "MaxTokens": 0, + "Temperature": null, + "JSONResponse": false, + "Grammar": "", + "Cache": null +}` diff --git a/pkg/tests/testdata/TestToolAs/other.gpt b/pkg/tests/testdata/TestToolAs/other.gpt new file mode 100644 index 00000000..a0a4c910 --- /dev/null +++ b/pkg/tests/testdata/TestToolAs/other.gpt @@ -0,0 +1 @@ +other file \ No newline at end of file diff --git a/pkg/tests/testdata/TestToolAs/test.gpt b/pkg/tests/testdata/TestToolAs/test.gpt new file mode 100644 index 00000000..0417f4a4 --- /dev/null +++ b/pkg/tests/testdata/TestToolAs/test.gpt @@ -0,0 +1,8 @@ +tools: infile as local, ./other.gpt as remote + +A tool + +--- +name: infile + +infile tool \ No newline at end of file diff --git a/pkg/types/tool.go b/pkg/types/tool.go index e698bdc4..8085db0c 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -57,6 +57,7 @@ func (p Program) ChatName() string { } type ToolReference struct { + Named string Reference string Arg string ToolID string @@ -184,9 +185,14 @@ func SplitArg(hasArg string) (prefix, arg string) { var ( fields = strings.Fields(hasArg) idx = slices.Index(fields, "with") + asIdx = slices.Index(fields, "as") ) if idx == -1 { + if asIdx != -1 { + return strings.Join(fields[:asIdx], " "), + strings.Join(fields[asIdx:], " ") + } return strings.TrimSpace(hasArg), "" } @@ -201,7 +207,12 @@ func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ er return nil, NewErrToolNotFound(toolName) } _, arg := SplitArg(toolName) + named, ok := strings.CutPrefix(arg, "as ") + if !ok { + named = "" + } result = append(result, ToolReference{ + Named: named, Arg: arg, Reference: toolName, ToolID: toolID, @@ -287,8 +298,13 @@ func (t Tool) String() string { func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) { toolNames := map[string]struct{}{} - for _, subToolName := range t.Parameters.Tools { - result, err = appendTool(result, prg, t, subToolName, toolNames) + subToolRefs, err := t.GetToolRefsFromNames(t.Parameters.Tools) + if err != nil { + return nil, err + } + + for _, subToolRef := range subToolRefs { + result, err = appendTool(result, prg, t, subToolRef.Reference, toolNames, subToolRef.Named) if err != nil { return nil, err } @@ -327,7 +343,7 @@ func appendExports(completionTools []CompletionTool, prg Program, parentTool Too } for _, export := range subTool.Export { - completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames) + completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames, "") if err != nil { return nil, err } @@ -336,7 +352,7 @@ func appendExports(completionTools []CompletionTool, prg Program, parentTool Too return completionTools, nil } -func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) { +func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}, asName string) ([]CompletionTool, error) { subTool, err := getTool(prg, parentTool, subToolName) if err != nil { return nil, err @@ -356,10 +372,14 @@ func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, if subTool.Instructions == "" { log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID) } else { + name := subToolName + if asName != "" { + name = asName + } completionTools = append(completionTools, CompletionTool{ Function: CompletionFunctionDefinition{ ToolID: subTool.ID, - Name: PickToolName(subToolName, toolNames), + Name: PickToolName(name, toolNames), Description: subTool.Parameters.Description, Parameters: args, }, @@ -367,7 +387,7 @@ func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, } for _, export := range subTool.Export { - completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames) + completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames, "") if err != nil { return nil, err } diff --git a/pkg/types/toolname_test.go b/pkg/types/toolname_test.go index da3b3a0e..6d276931 100644 --- a/pkg/types/toolname_test.go +++ b/pkg/types/toolname_test.go @@ -20,6 +20,12 @@ func TestParse(t *testing.T) { tool, subTool := SplitToolRef("a from b with x") autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool}) + tool, subTool = SplitToolRef("a from b with x as other") + autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool}) + tool, subTool = SplitToolRef("a with x") autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool}) + + tool, subTool = SplitToolRef("a with x as other") + autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool}) }