Skip to content

Commit 981de74

Browse files
Merge pull request #287 from ibuildthecloud/main
feat: add ability to pass args to context tools
2 parents e730bbd + 0c1123c commit 981de74

File tree

12 files changed

+251
-29
lines changed

12 files changed

+251
-29
lines changed

pkg/engine/engine.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type Engine struct {
3232
}
3333

3434
type State struct {
35+
Input string `json:"input,omitempty"`
3536
Completion types.CompletionRequest `json:"completion,omitempty"`
3637
Pending map[string]types.CompletionToolCall `json:"pending,omitempty"`
3738
Results map[string]CallResult `json:"results,omitempty"`
@@ -169,9 +170,15 @@ func (c *Context) WrappedContext() context.Context {
169170
return context.WithValue(c.Ctx, engineContext{}, c)
170171
}
171172

172-
func (e *Engine) Start(ctx Context, input string) (*Return, error) {
173+
func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
173174
tool := ctx.Tool
174175

176+
defer func() {
177+
if ret != nil && ret.State != nil {
178+
ret.State.Input = input
179+
}
180+
}()
181+
175182
if tool.IsCommand() {
176183
if tool.IsHTTP() {
177184
return e.runHTTP(ctx.Ctx, ctx.Program, tool, input)
@@ -321,6 +328,7 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re
321328
var added bool
322329

323330
state = &State{
331+
Input: state.Input,
324332
Completion: state.Completion,
325333
Pending: state.Pending,
326334
Results: map[string]CallResult{},

pkg/loader/loader.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
9999
return tool, nil
100100
}
101101

102-
tool, ok := into.ToolSet[tool.LocalTools[targetToolName]]
102+
tool, ok := into.ToolSet[tool.LocalTools[strings.ToLower(targetToolName)]]
103103
if !ok {
104104
return tool, &types.ErrToolNotFound{
105105
ToolName: targetToolName,
@@ -217,7 +217,8 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
217217
tool.Parameters.ExportContext,
218218
tool.Parameters.Context,
219219
tool.Parameters.Credentials) {
220-
localTool, ok := localTools[targetToolName]
220+
noArgs, _ := types.SplitArg(targetToolName)
221+
localTool, ok := localTools[strings.ToLower(noArgs)]
221222
if ok {
222223
var linkedTool types.Tool
223224
if existing, ok := prg.ToolSet[localTool.ID]; ok {
@@ -244,7 +245,7 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
244245
}
245246

246247
for _, localTool := range localTools {
247-
tool.LocalTools[localTool.Parameters.Name] = localTool.ID
248+
tool.LocalTools[strings.ToLower(localTool.Parameters.Name)] = localTool.ID
248249
}
249250

250251
tool = builtin.SetDefaults(tool)
@@ -327,6 +328,10 @@ func SplitToolRef(targetToolName string) (toolName, subTool string) {
327328
idx = slices.Index(fields, "from")
328329
)
329330

331+
defer func() {
332+
toolName, _ = types.SplitArg(toolName)
333+
}()
334+
330335
if idx == -1 {
331336
return strings.TrimSpace(targetToolName), ""
332337
}

pkg/loader/loader_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,11 @@ func TestHelloWorld(t *testing.T) {
9797
}
9898
}`).Equal(t, toString(prg))
9999
}
100+
101+
func TestParse(t *testing.T) {
102+
tool, subTool := SplitToolRef("a from b with x")
103+
autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool})
104+
105+
tool, subTool = SplitToolRef("a with x")
106+
autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool})
107+
}

pkg/runner/runner.go

Lines changed: 100 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,22 +233,106 @@ var (
233233
EventTypeCallFinish = EventType("callFinish")
234234
)
235235

236-
func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string) (result []engine.InputContext, _ error) {
237-
toolIDs, err := callCtx.Program.GetContextToolIDs(callCtx.Tool.ID)
236+
func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
237+
if ref.Arg == "" {
238+
return "", nil
239+
}
240+
241+
targetArgs := prg.ToolSet[ref.ToolID].Arguments
242+
targetKeys := map[string]string{}
243+
244+
if targetArgs == nil {
245+
return "", nil
246+
}
247+
248+
for targetKey := range targetArgs.Properties {
249+
targetKeys[strings.ToLower(targetKey)] = targetKey
250+
}
251+
252+
inputMap := map[string]interface{}{}
253+
outputMap := map[string]interface{}{}
254+
255+
_ = json.Unmarshal([]byte(input), &inputMap)
256+
257+
fields := strings.Fields(ref.Arg)
258+
259+
for i := 0; i < len(fields); i++ {
260+
field := fields[i]
261+
if field == "and" {
262+
continue
263+
}
264+
if field == "as" {
265+
i++
266+
continue
267+
}
268+
269+
var (
270+
keyName string
271+
val any
272+
)
273+
274+
if strings.HasPrefix(field, "$") {
275+
key := strings.TrimPrefix(field, "$")
276+
key = strings.TrimPrefix(key, "{")
277+
key = strings.TrimSuffix(key, "}")
278+
val = inputMap[key]
279+
} else {
280+
val = field
281+
}
282+
283+
if len(fields) > i+1 && fields[i+1] == "as" {
284+
keyName = strings.ToLower(fields[i+2])
285+
}
286+
287+
if len(targetKeys) == 0 {
288+
return "", fmt.Errorf("can not assign arg to context because target tool [%s] has no defined args", ref.ToolID)
289+
}
290+
291+
if keyName == "" {
292+
if len(targetKeys) != 1 {
293+
return "", fmt.Errorf("can not assign arg to context because target tool [%s] has does not have one args. You must use \"as\" syntax to map the arg to a key %v", ref.ToolID, targetKeys)
294+
}
295+
for k := range targetKeys {
296+
keyName = k
297+
}
298+
}
299+
300+
if targetKey, ok := targetKeys[strings.ToLower(keyName)]; ok {
301+
outputMap[targetKey] = val
302+
} else {
303+
return "", fmt.Errorf("can not assign arg to context because target tool [%s] has does not args [%s]", ref.ToolID, keyName)
304+
}
305+
}
306+
307+
if len(outputMap) == 0 {
308+
return "", nil
309+
}
310+
311+
output, err := json.Marshal(outputMap)
312+
return string(output), err
313+
}
314+
315+
func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
316+
toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID)
238317
if err != nil {
239318
return nil, err
240319
}
241320

242-
for _, toolID := range toolIDs {
243-
content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "", engine.ContextToolCategory)
321+
for _, toolRef := range toolRefs {
322+
contextInput, err := getContextInput(callCtx.Program, toolRef, input)
323+
if err != nil {
324+
return nil, err
325+
}
326+
327+
content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
244328
if err != nil {
245329
return nil, err
246330
}
247331
if content.Result == nil {
248332
return nil, fmt.Errorf("context tool can not result in a chat continuation")
249333
}
250334
result = append(result, engine.InputContext{
251-
ToolID: toolID,
335+
ToolID: toolRef.ToolID,
252336
Content: *content.Result,
253337
})
254338
}
@@ -278,7 +362,7 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
278362
}
279363

280364
var err error
281-
callCtx.InputContext, err = r.getContext(callCtx, monitor, env)
365+
callCtx.InputContext, err = r.getContext(callCtx, monitor, env, input)
282366
if err != nil {
283367
return nil, err
284368
}
@@ -361,8 +445,16 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
361445
}
362446
}
363447

364-
var err error
365-
callCtx.InputContext, err = r.getContext(callCtx, monitor, env)
448+
var (
449+
err error
450+
contentInput string
451+
)
452+
453+
if state.Continuation != nil && state.Continuation.State != nil {
454+
contentInput = state.Continuation.State.Input
455+
}
456+
457+
callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput)
366458
if err != nil {
367459
return nil, err
368460
}

pkg/tests/runner_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ func TestSubChat(t *testing.T) {
127127
"state": {
128128
"continuation": {
129129
"state": {
130+
"input": "Hello",
130131
"completion": {
131132
"Model": "gpt-4-turbo-preview",
132133
"InternalSystemPrompt": null,
@@ -249,6 +250,7 @@ func TestSubChat(t *testing.T) {
249250
"state": {
250251
"continuation": {
251252
"state": {
253+
"input": "Hello",
252254
"completion": {
253255
"Model": "gpt-4-turbo-preview",
254256
"InternalSystemPrompt": null,
@@ -399,6 +401,7 @@ func TestChat(t *testing.T) {
399401
"state": {
400402
"continuation": {
401403
"state": {
404+
"input": "Hello",
402405
"completion": {
403406
"Model": "gpt-4-turbo-preview",
404407
"InternalSystemPrompt": false,
@@ -452,6 +455,7 @@ func TestChat(t *testing.T) {
452455
"state": {
453456
"continuation": {
454457
"state": {
458+
"input": "Hello",
455459
"completion": {
456460
"Model": "gpt-4-turbo-preview",
457461
"InternalSystemPrompt": false,
@@ -530,6 +534,15 @@ func TestContext(t *testing.T) {
530534
assert.Equal(t, "TEST RESULT CALL: 1", x)
531535
}
532536

537+
func TestContextArg(t *testing.T) {
538+
runner := tester.NewRunner(t)
539+
x, err := runner.Run("", `{
540+
"file": "foo.db"
541+
}`)
542+
require.NoError(t, err)
543+
assert.Equal(t, "TEST RESULT CALL: 1", x)
544+
}
545+
533546
func TestCwd(t *testing.T) {
534547
runner := tester.NewRunner(t)
535548

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
`{
2+
"Model": "gpt-4-turbo-preview",
3+
"InternalSystemPrompt": null,
4+
"Tools": null,
5+
"Messages": [
6+
{
7+
"role": "system",
8+
"content": [
9+
{
10+
"text": "this is from context -- foo.db\n\nthis is from other context foo.db and then\n\nthis is from other context and then foo.db\n\nThis is from tool"
11+
}
12+
]
13+
},
14+
{
15+
"role": "user",
16+
"content": [
17+
{
18+
"text": "{\n\"file\": \"foo.db\"\n}"
19+
}
20+
]
21+
}
22+
],
23+
"MaxTokens": 0,
24+
"Temperature": null,
25+
"JSONResponse": false,
26+
"Grammar": "",
27+
"Cache": null
28+
}`
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
name: fromcontext
2+
args: first: an arg
3+
args: second: an arg
4+
5+
#!/bin/bash
6+
echo this is from other context ${first} and then ${second}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
context: fromcontext with ${file}
2+
context: fromcontext from other.gpt with ${file} as first
3+
context: fromcontext from other.gpt with ${file} as second
4+
arg: file: something
5+
6+
This is from tool
7+
---
8+
name: fromcontext
9+
args: first: an arg
10+
11+
#!/bin/bash
12+
echo this is from context -- ${first}

pkg/tests/testdata/TestDualSubChat/step1.golden

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"state": {
66
"continuation": {
77
"state": {
8+
"input": "User 1",
89
"completion": {
910
"Model": "gpt-4-turbo-preview",
1011
"InternalSystemPrompt": null,
@@ -110,6 +111,7 @@
110111
"state": {
111112
"continuation": {
112113
"state": {
114+
"input": "Input to chatbot1",
113115
"completion": {
114116
"Model": "gpt-4-turbo-preview",
115117
"InternalSystemPrompt": false,
@@ -175,6 +177,7 @@
175177
"state": {
176178
"continuation": {
177179
"state": {
180+
"input": "Input to chatbot2",
178181
"completion": {
179182
"Model": "gpt-4-turbo-preview",
180183
"InternalSystemPrompt": false,

pkg/tests/testdata/TestDualSubChat/step2.golden

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"state": {
66
"continuation": {
77
"state": {
8+
"input": "User 1",
89
"completion": {
910
"Model": "gpt-4-turbo-preview",
1011
"InternalSystemPrompt": null,
@@ -117,6 +118,7 @@
117118
"state": {
118119
"continuation": {
119120
"state": {
121+
"input": "Input to chatbot2",
120122
"completion": {
121123
"Model": "gpt-4-turbo-preview",
122124
"InternalSystemPrompt": false,

0 commit comments

Comments
 (0)