diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 3ed7725b..aaebca80 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -133,7 +133,9 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN // If we didn't get any tools from trying to parse it as OpenAPI, try to parse it as a GPTScript if len(tools) == 0 { - tools, err = parser.Parse(bytes.NewReader(data)) + tools, err = parser.Parse(bytes.NewReader(data), parser.Options{ + AssignGlobals: true, + }) if err != nil { return types.Tool{}, err } diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 5ab5295d..4f3415cd 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "regexp" + "slices" "strconv" "strings" @@ -83,6 +84,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { tool.Parameters.ModelProvider = true case "model", "modelname": tool.Parameters.ModelName = value + case "globalmodel", "globalmodelname": + tool.Parameters.GlobalModelName = value case "description": tool.Parameters.Description = value case "internalprompt": @@ -95,6 +98,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { tool.Parameters.Export = append(tool.Parameters.Export, csv(strings.ToLower(value))...) case "tool", "tools": tool.Parameters.Tools = append(tool.Parameters.Tools, csv(strings.ToLower(value))...) + case "globaltool", "globaltools": + tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(strings.ToLower(value))...) case "exportcontext": tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(strings.ToLower(value))...) case "context": @@ -168,13 +173,76 @@ type context struct { func (c *context) finish(tools *[]types.Tool) { c.tool.Instructions = strings.TrimSpace(strings.Join(c.instructions, "")) - if c.tool.Instructions != "" || c.tool.Parameters.Name != "" || len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 { + if c.tool.Instructions != "" || c.tool.Parameters.Name != "" || + len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 || + c.tool.GlobalModelName != "" || + len(c.tool.GlobalTools) > 0 { *tools = append(*tools, c.tool) } *c = context{} } -func Parse(input io.Reader) ([]types.Tool, error) { +type Options struct { + AssignGlobals bool +} + +func complete(opts ...Options) (result Options) { + for _, opt := range opts { + result.AssignGlobals = types.FirstSet(result.AssignGlobals, opt.AssignGlobals) + } + return +} + +func Parse(input io.Reader, opts ...Options) ([]types.Tool, error) { + tools, err := parse(input) + if err != nil { + return nil, err + } + + opt := complete(opts...) + + if !opt.AssignGlobals { + return tools, nil + } + + var ( + globalModel string + seenGlobalTools = map[string]struct{}{} + globalTools []string + ) + + for _, tool := range tools { + if tool.GlobalModelName != "" { + if globalModel != "" { + return nil, fmt.Errorf("global model name defined multiple times") + } + globalModel = tool.GlobalModelName + } + for _, globalTool := range tool.GlobalTools { + if _, ok := seenGlobalTools[globalTool]; ok { + continue + } + seenGlobalTools[globalTool] = struct{}{} + globalTools = append(globalTools, globalTool) + } + } + + for i, tool := range tools { + if globalModel != "" && tool.ModelName == "" { + tool.ModelName = globalModel + } + for _, globalTool := range globalTools { + if !slices.Contains(tool.Tools, globalTool) { + tool.Tools = append(tool.Tools, globalTool) + } + } + tools[i] = tool + } + + return tools, nil +} + +func parse(input io.Reader) ([]types.Tool, error) { scan := bufio.NewScanner(input) var ( diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index 02a61fd3..0127b566 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -9,7 +9,49 @@ import ( "github.com/stretchr/testify/require" ) -func TestParse(t *testing.T) { +func TestParseGlobals(t *testing.T) { + var input = ` +global tools: foo, bar +global model: the model +--- +name: bar +tools: bar +` + out, err := Parse(strings.NewReader(input), Options{ + AssignGlobals: true, + }) + require.NoError(t, err) + autogold.Expect([]types.Tool{ + { + Parameters: types.Parameters{ + ModelName: "the model", + Tools: []string{ + "foo", + "bar", + }, + GlobalTools: []string{ + "foo", + "bar", + }, + GlobalModelName: "the model", + }, + Source: types.ToolSource{LineNo: 1}, + }, + { + Parameters: types.Parameters{ + Name: "bar", + ModelName: "the model", + Tools: []string{ + "bar", + "foo", + }, + }, + Source: types.ToolSource{LineNo: 5}, + }, + }).Equal(t, out) +} + +func TestParseSkip(t *testing.T) { var input = ` first --- diff --git a/pkg/types/tool.go b/pkg/types/tool.go index d1553827..21008aff 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -107,22 +107,24 @@ func (p Program) SetBlocking() Program { type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error) type Parameters struct { - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - MaxTokens int `json:"maxTokens,omitempty"` - ModelName string `json:"modelName,omitempty"` - ModelProvider bool `json:"modelProvider,omitempty"` - JSONResponse bool `json:"jsonResponse,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` - Cache *bool `json:"cache,omitempty"` - InternalPrompt *bool `json:"internalPrompt"` - Arguments *openapi3.Schema `json:"arguments,omitempty"` - Tools []string `json:"tools,omitempty"` - Context []string `json:"context,omitempty"` - ExportContext []string `json:"exportContext,omitempty"` - Export []string `json:"export,omitempty"` - Credentials []string `json:"credentials,omitempty"` - Blocking bool `json:"-"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + ModelName string `json:"modelName,omitempty"` + ModelProvider bool `json:"modelProvider,omitempty"` + JSONResponse bool `json:"jsonResponse,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + Cache *bool `json:"cache,omitempty"` + InternalPrompt *bool `json:"internalPrompt"` + Arguments *openapi3.Schema `json:"arguments,omitempty"` + Tools []string `json:"tools,omitempty"` + GlobalTools []string `json:"globalTools,omitempty"` + GlobalModelName string `json:"globalModelName,omitempty"` + Context []string `json:"context,omitempty"` + ExportContext []string `json:"exportContext,omitempty"` + Export []string `json:"export,omitempty"` + Credentials []string `json:"credentials,omitempty"` + Blocking bool `json:"-"` } type Tool struct { @@ -150,6 +152,12 @@ func (t Tool) GetToolIDsFromNames(names []string) (result []string, _ error) { func (t Tool) String() string { buf := &strings.Builder{} + if t.Parameters.GlobalModelName != "" { + _, _ = fmt.Fprintf(buf, "Global Model Name: %s\n", t.Parameters.GlobalModelName) + } + if len(t.Parameters.GlobalTools) != 0 { + _, _ = fmt.Fprintf(buf, "Global Tools: %s\n", strings.Join(t.Parameters.GlobalTools, ", ")) + } if t.Parameters.Name != "" { _, _ = fmt.Fprintf(buf, "Name: %s\n", t.Parameters.Name) }