diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 2159e0b3..e9bbd1bf 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -242,7 +242,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so tool.ToolMapping[targetToolName] = linkedTool.ID toolNames[targetToolName] = struct{}{} } else { - toolName, subTool := SplitToolRef(targetToolName) + toolName, subTool := types.SplitToolRef(targetToolName) resolvedTool, err := resolve(ctx, cache, prg, base, toolName, subTool) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s at %s: %w", targetToolName, base, err) @@ -295,7 +295,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty opt := complete(opts...) if subToolName == "" { - name, subToolName = SplitToolRef(name) + name, subToolName = types.SplitToolRef(name) } prg := types.Program{ Name: name, @@ -346,24 +346,6 @@ func input(ctx context.Context, cache *cache.Client, base *source, name string) return nil, fmt.Errorf("can not load tools path=%s name=%s", base.Path, name) } -func SplitToolRef(targetToolName string) (toolName, subTool string) { - var ( - fields = strings.Fields(targetToolName) - idx = slices.Index(fields, "from") - ) - - defer func() { - toolName, _ = types.SplitArg(toolName) - }() - - if idx == -1 { - return strings.TrimSpace(targetToolName), "" - } - - return strings.Join(fields[idx+1:], " "), - strings.Join(fields[:idx], " ") -} - func isOpenAPI(data []byte) bool { var fragment struct { Paths map[string]any `json:"paths,omitempty"` diff --git a/pkg/loader/loader_test.go b/pkg/loader/loader_test.go index 6c3bb17b..8b306011 100644 --- a/pkg/loader/loader_test.go +++ b/pkg/loader/loader_test.go @@ -99,11 +99,3 @@ func TestHelloWorld(t *testing.T) { } }`, "MODEL", openai.DefaultModel)).Equal(t, toString(prg)) } - -func TestParse(t *testing.T) { - tool, subTool := SplitToolRef("a from b with x") - autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool}) - - tool, subTool = SplitToolRef("a with x") - autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool}) -} diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 35131efa..68072c71 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -46,7 +46,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) } - _, modelName := loader.SplitToolRef(messageRequest.Model) + _, modelName := types.SplitToolRef(messageRequest.Model) messageRequest.Model = modelName return client.Call(ctx, messageRequest, status) } @@ -71,7 +71,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result [] } func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { - toolName, modelNameSuffix := loader.SplitToolRef(modelName) + toolName, modelNameSuffix := types.SplitToolRef(modelName) if modelNameSuffix == "" { return false, nil } diff --git a/pkg/types/toolname.go b/pkg/types/toolname.go index 390d2913..cd08e8e6 100644 --- a/pkg/types/toolname.go +++ b/pkg/types/toolname.go @@ -3,6 +3,7 @@ package types import ( "path/filepath" "regexp" + "slices" "strings" "github.com/gptscript-ai/gptscript/pkg/system" @@ -14,7 +15,13 @@ var ( ) func ToolNormalizer(tool string) string { - parts := strings.Split(tool, "/") + _, subTool := SplitToolRef(tool) + lastTool := tool + if subTool != "" { + lastTool = subTool + } + + parts := strings.Split(lastTool, "/") tool = parts[len(parts)-1] if strings.HasSuffix(tool, system.Suffix) { tool = strings.TrimSuffix(tool, filepath.Ext(tool)) @@ -43,6 +50,24 @@ func ToolNormalizer(tool string) string { return strings.Join(result, "") } +func SplitToolRef(targetToolName string) (toolName, subTool string) { + var ( + fields = strings.Fields(targetToolName) + idx = slices.Index(fields, "from") + ) + + defer func() { + toolName, _ = SplitArg(toolName) + }() + + if idx == -1 { + return strings.TrimSpace(targetToolName), "" + } + + return strings.Join(fields[idx+1:], " "), + strings.Join(fields[:idx], " ") +} + func PickToolName(toolName string, existing map[string]struct{}) string { if toolName == "" { toolName = "external" diff --git a/pkg/types/toolname_test.go b/pkg/types/toolname_test.go index fe1d2524..da3b3a0e 100644 --- a/pkg/types/toolname_test.go +++ b/pkg/types/toolname_test.go @@ -10,4 +10,16 @@ func TestToolNormalizer(t *testing.T) { autogold.Expect("bobTool").Equal(t, ToolNormalizer("bob-tool")) autogold.Expect("bobTool").Equal(t, ToolNormalizer("bob_tool")) autogold.Expect("bobTool").Equal(t, ToolNormalizer("BOB tOOL")) + autogold.Expect("barList").Equal(t, ToolNormalizer("bar_list from ./foo.yaml")) + autogold.Expect("barList").Equal(t, ToolNormalizer("bar_list from ./foo.gpt")) + autogold.Expect("write").Equal(t, ToolNormalizer("sys.write")) + autogold.Expect("gpt4VVision").Equal(t, ToolNormalizer("github.com/gptscript-ai/gpt4-v-vision")) +} + +func TestParse(t *testing.T) { + tool, subTool := SplitToolRef("a from b with x") + autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool}) + + tool, subTool = SplitToolRef("a with x") + autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool}) }