diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index c123fae6..d71951fa 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -32,6 +32,7 @@ type Engine struct { } type State struct { + Input string `json:"input,omitempty"` Completion types.CompletionRequest `json:"completion,omitempty"` Pending map[string]types.CompletionToolCall `json:"pending,omitempty"` Results map[string]CallResult `json:"results,omitempty"` @@ -169,9 +170,15 @@ func (c *Context) WrappedContext() context.Context { return context.WithValue(c.Ctx, engineContext{}, c) } -func (e *Engine) Start(ctx Context, input string) (*Return, error) { +func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) { tool := ctx.Tool + defer func() { + if ret != nil && ret.State != nil { + ret.State.Input = input + } + }() + if tool.IsCommand() { if tool.IsHTTP() { return e.runHTTP(ctx.Ctx, ctx.Program, tool, input) @@ -321,6 +328,7 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re var added bool state = &State{ + Input: state.Input, Completion: state.Completion, Pending: state.Pending, Results: map[string]CallResult{}, diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 544cf676..e5a9bb1e 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -99,7 +99,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types return tool, nil } - tool, ok := into.ToolSet[tool.LocalTools[targetToolName]] + tool, ok := into.ToolSet[tool.LocalTools[strings.ToLower(targetToolName)]] if !ok { return tool, &types.ErrToolNotFound{ ToolName: targetToolName, @@ -217,7 +217,8 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool tool.Parameters.ExportContext, tool.Parameters.Context, tool.Parameters.Credentials) { - localTool, ok := localTools[targetToolName] + noArgs, _ := types.SplitArg(targetToolName) + localTool, ok := localTools[strings.ToLower(noArgs)] if ok { var linkedTool types.Tool if existing, ok := prg.ToolSet[localTool.ID]; ok { @@ -244,7 +245,7 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool } for _, localTool := range localTools { - tool.LocalTools[localTool.Parameters.Name] = localTool.ID + tool.LocalTools[strings.ToLower(localTool.Parameters.Name)] = localTool.ID } tool = builtin.SetDefaults(tool) @@ -327,6 +328,10 @@ func SplitToolRef(targetToolName string) (toolName, subTool string) { idx = slices.Index(fields, "from") ) + defer func() { + toolName, _ = types.SplitArg(toolName) + }() + if idx == -1 { return strings.TrimSpace(targetToolName), "" } diff --git a/pkg/loader/loader_test.go b/pkg/loader/loader_test.go index c8e44647..0ce0ce8f 100644 --- a/pkg/loader/loader_test.go +++ b/pkg/loader/loader_test.go @@ -97,3 +97,11 @@ func TestHelloWorld(t *testing.T) { } }`).Equal(t, toString(prg)) } + +func TestParse(t *testing.T) { + tool, subTool := SplitToolRef("a from b with x") + autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool}) + + tool, subTool = SplitToolRef("a with x") + autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool}) +} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 80e3a3cf..f4227ea0 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -233,14 +233,98 @@ var ( EventTypeCallFinish = EventType("callFinish") ) -func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string) (result []engine.InputContext, _ error) { - toolIDs, err := callCtx.Program.GetContextToolIDs(callCtx.Tool.ID) +func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) { + if ref.Arg == "" { + return "", nil + } + + targetArgs := prg.ToolSet[ref.ToolID].Arguments + targetKeys := map[string]string{} + + if targetArgs == nil { + return "", nil + } + + for targetKey := range targetArgs.Properties { + targetKeys[strings.ToLower(targetKey)] = targetKey + } + + inputMap := map[string]interface{}{} + outputMap := map[string]interface{}{} + + _ = json.Unmarshal([]byte(input), &inputMap) + + fields := strings.Fields(ref.Arg) + + for i := 0; i < len(fields); i++ { + field := fields[i] + if field == "and" { + continue + } + if field == "as" { + i++ + continue + } + + var ( + keyName string + val any + ) + + if strings.HasPrefix(field, "$") { + key := strings.TrimPrefix(field, "$") + key = strings.TrimPrefix(key, "{") + key = strings.TrimSuffix(key, "}") + val = inputMap[key] + } else { + val = field + } + + if len(fields) > i+1 && fields[i+1] == "as" { + keyName = strings.ToLower(fields[i+2]) + } + + if len(targetKeys) == 0 { + return "", fmt.Errorf("can not assign arg to context because target tool [%s] has no defined args", ref.ToolID) + } + + if keyName == "" { + if len(targetKeys) != 1 { + return "", fmt.Errorf("can not assign arg to context because target tool [%s] has does not have one args. You must use \"as\" syntax to map the arg to a key %v", ref.ToolID, targetKeys) + } + for k := range targetKeys { + keyName = k + } + } + + if targetKey, ok := targetKeys[strings.ToLower(keyName)]; ok { + outputMap[targetKey] = val + } else { + return "", fmt.Errorf("can not assign arg to context because target tool [%s] has does not args [%s]", ref.ToolID, keyName) + } + } + + if len(outputMap) == 0 { + return "", nil + } + + output, err := json.Marshal(outputMap) + return string(output), err +} + +func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) { + toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID) if err != nil { return nil, err } - for _, toolID := range toolIDs { - content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "", engine.ContextToolCategory) + for _, toolRef := range toolRefs { + contextInput, err := getContextInput(callCtx.Program, toolRef, input) + if err != nil { + return nil, err + } + + content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory) if err != nil { return nil, err } @@ -248,7 +332,7 @@ func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []strin return nil, fmt.Errorf("context tool can not result in a chat continuation") } result = append(result, engine.InputContext{ - ToolID: toolID, + ToolID: toolRef.ToolID, Content: *content.Result, }) } @@ -278,7 +362,7 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in } var err error - callCtx.InputContext, err = r.getContext(callCtx, monitor, env) + callCtx.InputContext, err = r.getContext(callCtx, monitor, env, input) if err != nil { return nil, err } @@ -361,8 +445,16 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s } } - var err error - callCtx.InputContext, err = r.getContext(callCtx, monitor, env) + var ( + err error + contentInput string + ) + + if state.Continuation != nil && state.Continuation.State != nil { + contentInput = state.Continuation.State.Input + } + + callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput) if err != nil { return nil, err } diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 199668a4..a67491b7 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -127,6 +127,7 @@ func TestSubChat(t *testing.T) { "state": { "continuation": { "state": { + "input": "Hello", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": null, @@ -249,6 +250,7 @@ func TestSubChat(t *testing.T) { "state": { "continuation": { "state": { + "input": "Hello", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": null, @@ -399,6 +401,7 @@ func TestChat(t *testing.T) { "state": { "continuation": { "state": { + "input": "Hello", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": false, @@ -452,6 +455,7 @@ func TestChat(t *testing.T) { "state": { "continuation": { "state": { + "input": "Hello", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": false, @@ -530,6 +534,15 @@ func TestContext(t *testing.T) { assert.Equal(t, "TEST RESULT CALL: 1", x) } +func TestContextArg(t *testing.T) { + runner := tester.NewRunner(t) + x, err := runner.Run("", `{ +"file": "foo.db" +}`) + require.NoError(t, err) + assert.Equal(t, "TEST RESULT CALL: 1", x) +} + func TestCwd(t *testing.T) { runner := tester.NewRunner(t) diff --git a/pkg/tests/testdata/TestContextArg/call1.golden b/pkg/tests/testdata/TestContextArg/call1.golden new file mode 100644 index 00000000..f0207cbf --- /dev/null +++ b/pkg/tests/testdata/TestContextArg/call1.golden @@ -0,0 +1,28 @@ +`{ + "Model": "gpt-4-turbo-preview", + "InternalSystemPrompt": null, + "Tools": null, + "Messages": [ + { + "role": "system", + "content": [ + { + "text": "this is from context -- foo.db\n\nthis is from other context foo.db and then\n\nthis is from other context and then foo.db\n\nThis is from tool" + } + ] + }, + { + "role": "user", + "content": [ + { + "text": "{\n\"file\": \"foo.db\"\n}" + } + ] + } + ], + "MaxTokens": 0, + "Temperature": null, + "JSONResponse": false, + "Grammar": "", + "Cache": null +}` diff --git a/pkg/tests/testdata/TestContextArg/other.gpt b/pkg/tests/testdata/TestContextArg/other.gpt new file mode 100644 index 00000000..b1acd66a --- /dev/null +++ b/pkg/tests/testdata/TestContextArg/other.gpt @@ -0,0 +1,6 @@ +name: fromcontext +args: first: an arg +args: second: an arg + +#!/bin/bash +echo this is from other context ${first} and then ${second} \ No newline at end of file diff --git a/pkg/tests/testdata/TestContextArg/test.gpt b/pkg/tests/testdata/TestContextArg/test.gpt new file mode 100644 index 00000000..9569aaf9 --- /dev/null +++ b/pkg/tests/testdata/TestContextArg/test.gpt @@ -0,0 +1,12 @@ +context: fromcontext with ${file} +context: fromcontext from other.gpt with ${file} as first +context: fromcontext from other.gpt with ${file} as second +arg: file: something + +This is from tool +--- +name: fromcontext +args: first: an arg + +#!/bin/bash +echo this is from context -- ${first} \ No newline at end of file diff --git a/pkg/tests/testdata/TestDualSubChat/step1.golden b/pkg/tests/testdata/TestDualSubChat/step1.golden index 503f946a..798bfe74 100644 --- a/pkg/tests/testdata/TestDualSubChat/step1.golden +++ b/pkg/tests/testdata/TestDualSubChat/step1.golden @@ -5,6 +5,7 @@ "state": { "continuation": { "state": { + "input": "User 1", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": null, @@ -110,6 +111,7 @@ "state": { "continuation": { "state": { + "input": "Input to chatbot1", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": false, @@ -175,6 +177,7 @@ "state": { "continuation": { "state": { + "input": "Input to chatbot2", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": false, diff --git a/pkg/tests/testdata/TestDualSubChat/step2.golden b/pkg/tests/testdata/TestDualSubChat/step2.golden index 5fe0a132..9d4e8738 100644 --- a/pkg/tests/testdata/TestDualSubChat/step2.golden +++ b/pkg/tests/testdata/TestDualSubChat/step2.golden @@ -5,6 +5,7 @@ "state": { "continuation": { "state": { + "input": "User 1", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": null, @@ -117,6 +118,7 @@ "state": { "continuation": { "state": { + "input": "Input to chatbot2", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": false, diff --git a/pkg/tests/testdata/TestDualSubChat/step3.golden b/pkg/tests/testdata/TestDualSubChat/step3.golden index d7732186..1a0e393d 100644 --- a/pkg/tests/testdata/TestDualSubChat/step3.golden +++ b/pkg/tests/testdata/TestDualSubChat/step3.golden @@ -5,6 +5,7 @@ "state": { "continuation": { "state": { + "input": "User 1", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": null, @@ -117,6 +118,7 @@ "state": { "continuation": { "state": { + "input": "Input to chatbot2", "completion": { "Model": "gpt-4-turbo-preview", "InternalSystemPrompt": false, diff --git a/pkg/types/tool.go b/pkg/types/tool.go index f619799f..3b43b8d6 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -3,6 +3,7 @@ package types import ( "context" "fmt" + "slices" "sort" "strings" @@ -54,38 +55,61 @@ func (p Program) ChatName() string { return p.Name } -func (p Program) GetContextToolIDs(toolID string) (result []string, _ error) { - seen := map[string]struct{}{} +type ToolReference struct { + Reference string + Arg string + ToolID string +} + +func (p Program) GetContextToolRefs(toolID string) (result []ToolReference, _ error) { + seen := map[struct { + toolID string + arg string + }]struct{}{} tool := p.ToolSet[toolID] - subToolIDs, err := tool.GetToolIDsFromNames(tool.Tools) + subToolRefs, err := tool.GetToolRefsFromNames(tool.Tools) if err != nil { return nil, err } - for _, subToolID := range subToolIDs { - subTool := p.ToolSet[subToolID] - exportContextToolIDs, err := subTool.GetToolIDsFromNames(subTool.ExportContext) + for _, subToolRef := range subToolRefs { + subTool := p.ToolSet[subToolRef.ToolID] + exportContextToolRefs, err := subTool.GetToolRefsFromNames(subTool.ExportContext) if err != nil { return nil, err } - for _, exportContextToolID := range exportContextToolIDs { - if _, ok := seen[exportContextToolID]; !ok { - seen[exportContextToolID] = struct{}{} - result = append(result, exportContextToolID) + for _, exportContextToolRef := range exportContextToolRefs { + key := struct { + toolID string + arg string + }{ + toolID: exportContextToolRef.ToolID, + arg: exportContextToolRef.Arg, + } + if _, ok := seen[key]; !ok { + seen[key] = struct{}{} + result = append(result, exportContextToolRef) } } } - contextToolIDs, err := p.ToolSet[toolID].GetToolIDsFromNames(p.ToolSet[toolID].Context) + contextToolRefs, err := p.ToolSet[toolID].GetToolRefsFromNames(p.ToolSet[toolID].Context) if err != nil { return nil, err } - for _, contextToolID := range contextToolIDs { - if _, ok := seen[contextToolID]; !ok { - seen[contextToolID] = struct{}{} - result = append(result, contextToolID) + for _, contextToolRef := range contextToolRefs { + key := struct { + toolID string + arg string + }{ + toolID: contextToolRef.ToolID, + arg: contextToolRef.Arg, + } + if _, ok := seen[key]; !ok { + seen[key] = struct{}{} + result = append(result, contextToolRef) } } @@ -155,13 +179,32 @@ type Tool struct { WorkingDir string `json:"workingDir,omitempty"` } -func (t Tool) GetToolIDsFromNames(names []string) (result []string, _ error) { +func SplitArg(hasArg string) (prefix, arg string) { + var ( + fields = strings.Fields(hasArg) + idx = slices.Index(fields, "with") + ) + + if idx == -1 { + return strings.TrimSpace(hasArg), "" + } + + return strings.Join(fields[:idx], " "), + strings.Join(fields[idx+1:], " ") +} + +func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ error) { for _, toolName := range names { toolID, ok := t.ToolMapping[toolName] if !ok { return nil, NewErrToolNotFound(toolName) } - result = append(result, toolID) + _, arg := SplitArg(toolName) + result = append(result, ToolReference{ + Arg: arg, + Reference: toolName, + ToolID: toolID, + }) } return }