Skip to content

Commit 70529b7

Browse files
feat: add global model name and global tools fields
1 parent 2558f65 commit 70529b7

File tree

4 files changed

+140
-20
lines changed

4 files changed

+140
-20
lines changed

pkg/loader/loader.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ func readTool(ctx context.Context, prg *types.Program, base *source, targetToolN
133133

134134
// If we didn't get any tools from trying to parse it as OpenAPI, try to parse it as a GPTScript
135135
if len(tools) == 0 {
136-
tools, err = parser.Parse(bytes.NewReader(data))
136+
tools, err = parser.Parse(bytes.NewReader(data), parser.Options{
137+
AssignGlobals: true,
138+
})
137139
if err != nil {
138140
return types.Tool{}, err
139141
}

pkg/parser/parser.go

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"io"
77
"regexp"
8+
"slices"
89
"strconv"
910
"strings"
1011

@@ -83,6 +84,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
8384
tool.Parameters.ModelProvider = true
8485
case "model", "modelname":
8586
tool.Parameters.ModelName = value
87+
case "globalmodel", "globalmodelname":
88+
tool.Parameters.GlobalModelName = value
8689
case "description":
8790
tool.Parameters.Description = value
8891
case "internalprompt":
@@ -95,6 +98,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
9598
tool.Parameters.Export = append(tool.Parameters.Export, csv(strings.ToLower(value))...)
9699
case "tool", "tools":
97100
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(strings.ToLower(value))...)
101+
case "globaltool", "globaltools":
102+
tool.Parameters.GlobalTools = append(tool.Parameters.GlobalTools, csv(strings.ToLower(value))...)
98103
case "exportcontext":
99104
tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(strings.ToLower(value))...)
100105
case "context":
@@ -168,13 +173,76 @@ type context struct {
168173

169174
func (c *context) finish(tools *[]types.Tool) {
170175
c.tool.Instructions = strings.TrimSpace(strings.Join(c.instructions, ""))
171-
if c.tool.Instructions != "" || c.tool.Parameters.Name != "" || len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 {
176+
if c.tool.Instructions != "" || c.tool.Parameters.Name != "" ||
177+
len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 ||
178+
c.tool.GlobalModelName != "" ||
179+
len(c.tool.GlobalTools) > 0 {
172180
*tools = append(*tools, c.tool)
173181
}
174182
*c = context{}
175183
}
176184

177-
func Parse(input io.Reader) ([]types.Tool, error) {
185+
type Options struct {
186+
AssignGlobals bool
187+
}
188+
189+
func complete(opts ...Options) (result Options) {
190+
for _, opt := range opts {
191+
result.AssignGlobals = types.FirstSet(result.AssignGlobals, opt.AssignGlobals)
192+
}
193+
return
194+
}
195+
196+
func Parse(input io.Reader, opts ...Options) ([]types.Tool, error) {
197+
tools, err := parse(input)
198+
if err != nil {
199+
return nil, err
200+
}
201+
202+
opt := complete(opts...)
203+
204+
if !opt.AssignGlobals {
205+
return tools, nil
206+
}
207+
208+
var (
209+
globalModel string
210+
seenGlobalTools = map[string]struct{}{}
211+
globalTools []string
212+
)
213+
214+
for _, tool := range tools {
215+
if tool.GlobalModelName != "" {
216+
if globalModel != "" {
217+
return nil, fmt.Errorf("global model name defined multiple times")
218+
}
219+
globalModel = tool.GlobalModelName
220+
}
221+
for _, globalTool := range tool.GlobalTools {
222+
if _, ok := seenGlobalTools[globalTool]; ok {
223+
continue
224+
}
225+
seenGlobalTools[globalTool] = struct{}{}
226+
globalTools = append(globalTools, globalTool)
227+
}
228+
}
229+
230+
for i, tool := range tools {
231+
if globalModel != "" && tool.ModelName == "" {
232+
tool.ModelName = globalModel
233+
}
234+
for _, globalTool := range globalTools {
235+
if !slices.Contains(tool.Tools, globalTool) {
236+
tool.Tools = append(tool.Tools, globalTool)
237+
}
238+
}
239+
tools[i] = tool
240+
}
241+
242+
return tools, nil
243+
}
244+
245+
func parse(input io.Reader) ([]types.Tool, error) {
178246
scan := bufio.NewScanner(input)
179247

180248
var (

pkg/parser/parser_test.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,49 @@ import (
99
"github.com/stretchr/testify/require"
1010
)
1111

12-
func TestParse(t *testing.T) {
12+
func TestParseGlobals(t *testing.T) {
13+
var input = `
14+
global tools: foo, bar
15+
global model: the model
16+
---
17+
name: bar
18+
tools: bar
19+
`
20+
out, err := Parse(strings.NewReader(input), Options{
21+
AssignGlobals: true,
22+
})
23+
require.NoError(t, err)
24+
autogold.Expect([]types.Tool{
25+
{
26+
Parameters: types.Parameters{
27+
ModelName: "the model",
28+
Tools: []string{
29+
"foo",
30+
"bar",
31+
},
32+
GlobalTools: []string{
33+
"foo",
34+
"bar",
35+
},
36+
GlobalModelName: "the model",
37+
},
38+
Source: types.ToolSource{LineNo: 1},
39+
},
40+
{
41+
Parameters: types.Parameters{
42+
Name: "bar",
43+
ModelName: "the model",
44+
Tools: []string{
45+
"bar",
46+
"foo",
47+
},
48+
},
49+
Source: types.ToolSource{LineNo: 5},
50+
},
51+
}).Equal(t, out)
52+
}
53+
54+
func TestParseSkip(t *testing.T) {
1355
var input = `
1456
first
1557
---

pkg/types/tool.go

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -107,22 +107,24 @@ func (p Program) SetBlocking() Program {
107107
type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error)
108108

109109
type Parameters struct {
110-
Name string `json:"name,omitempty"`
111-
Description string `json:"description,omitempty"`
112-
MaxTokens int `json:"maxTokens,omitempty"`
113-
ModelName string `json:"modelName,omitempty"`
114-
ModelProvider bool `json:"modelProvider,omitempty"`
115-
JSONResponse bool `json:"jsonResponse,omitempty"`
116-
Temperature *float32 `json:"temperature,omitempty"`
117-
Cache *bool `json:"cache,omitempty"`
118-
InternalPrompt *bool `json:"internalPrompt"`
119-
Arguments *openapi3.Schema `json:"arguments,omitempty"`
120-
Tools []string `json:"tools,omitempty"`
121-
Context []string `json:"context,omitempty"`
122-
ExportContext []string `json:"exportContext,omitempty"`
123-
Export []string `json:"export,omitempty"`
124-
Credentials []string `json:"credentials,omitempty"`
125-
Blocking bool `json:"-"`
110+
Name string `json:"name,omitempty"`
111+
Description string `json:"description,omitempty"`
112+
MaxTokens int `json:"maxTokens,omitempty"`
113+
ModelName string `json:"modelName,omitempty"`
114+
ModelProvider bool `json:"modelProvider,omitempty"`
115+
JSONResponse bool `json:"jsonResponse,omitempty"`
116+
Temperature *float32 `json:"temperature,omitempty"`
117+
Cache *bool `json:"cache,omitempty"`
118+
InternalPrompt *bool `json:"internalPrompt"`
119+
Arguments *openapi3.Schema `json:"arguments,omitempty"`
120+
Tools []string `json:"tools,omitempty"`
121+
GlobalTools []string `json:"globalTools,omitempty"`
122+
GlobalModelName string `json:"globalModelName,omitempty"`
123+
Context []string `json:"context,omitempty"`
124+
ExportContext []string `json:"exportContext,omitempty"`
125+
Export []string `json:"export,omitempty"`
126+
Credentials []string `json:"credentials,omitempty"`
127+
Blocking bool `json:"-"`
126128
}
127129

128130
type Tool struct {
@@ -150,6 +152,12 @@ func (t Tool) GetToolIDsFromNames(names []string) (result []string, _ error) {
150152

151153
func (t Tool) String() string {
152154
buf := &strings.Builder{}
155+
if t.Parameters.GlobalModelName != "" {
156+
_, _ = fmt.Fprintf(buf, "Global Model Name: %s\n", t.Parameters.GlobalModelName)
157+
}
158+
if len(t.Parameters.GlobalTools) != 0 {
159+
_, _ = fmt.Fprintf(buf, "Global Tools: %s\n", strings.Join(t.Parameters.GlobalTools, ", "))
160+
}
153161
if t.Parameters.Name != "" {
154162
_, _ = fmt.Fprintf(buf, "Name: %s\n", t.Parameters.Name)
155163
}

0 commit comments

Comments
 (0)