Skip to content

feat: add ability to pass args to context tools #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Engine struct {
}

type State struct {
Input string `json:"input,omitempty"`
Completion types.CompletionRequest `json:"completion,omitempty"`
Pending map[string]types.CompletionToolCall `json:"pending,omitempty"`
Results map[string]CallResult `json:"results,omitempty"`
Expand Down Expand Up @@ -169,9 +170,15 @@ func (c *Context) WrappedContext() context.Context {
return context.WithValue(c.Ctx, engineContext{}, c)
}

func (e *Engine) Start(ctx Context, input string) (*Return, error) {
func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
tool := ctx.Tool

defer func() {
if ret != nil && ret.State != nil {
ret.State.Input = input
}
}()

if tool.IsCommand() {
if tool.IsHTTP() {
return e.runHTTP(ctx.Ctx, ctx.Program, tool, input)
Expand Down Expand Up @@ -321,6 +328,7 @@ func (e *Engine) Continue(ctx Context, state *State, results ...CallResult) (*Re
var added bool

state = &State{
Input: state.Input,
Completion: state.Completion,
Pending: state.Pending,
Results: map[string]CallResult{},
Expand Down
11 changes: 8 additions & 3 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
return tool, nil
}

tool, ok := into.ToolSet[tool.LocalTools[targetToolName]]
tool, ok := into.ToolSet[tool.LocalTools[strings.ToLower(targetToolName)]]
if !ok {
return tool, &types.ErrToolNotFound{
ToolName: targetToolName,
Expand Down Expand Up @@ -217,7 +217,8 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
tool.Parameters.ExportContext,
tool.Parameters.Context,
tool.Parameters.Credentials) {
localTool, ok := localTools[targetToolName]
noArgs, _ := types.SplitArg(targetToolName)
localTool, ok := localTools[strings.ToLower(noArgs)]
if ok {
var linkedTool types.Tool
if existing, ok := prg.ToolSet[localTool.ID]; ok {
Expand All @@ -244,7 +245,7 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
}

for _, localTool := range localTools {
tool.LocalTools[localTool.Parameters.Name] = localTool.ID
tool.LocalTools[strings.ToLower(localTool.Parameters.Name)] = localTool.ID
}

tool = builtin.SetDefaults(tool)
Expand Down Expand Up @@ -327,6 +328,10 @@ func SplitToolRef(targetToolName string) (toolName, subTool string) {
idx = slices.Index(fields, "from")
)

defer func() {
toolName, _ = types.SplitArg(toolName)
}()

if idx == -1 {
return strings.TrimSpace(targetToolName), ""
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/loader/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ func TestHelloWorld(t *testing.T) {
}
}`).Equal(t, toString(prg))
}

func TestParse(t *testing.T) {
tool, subTool := SplitToolRef("a from b with x")
autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool})

tool, subTool = SplitToolRef("a with x")
autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool})
}
108 changes: 100 additions & 8 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,22 +233,106 @@ var (
EventTypeCallFinish = EventType("callFinish")
)

func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string) (result []engine.InputContext, _ error) {
toolIDs, err := callCtx.Program.GetContextToolIDs(callCtx.Tool.ID)
func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
if ref.Arg == "" {
return "", nil
}

targetArgs := prg.ToolSet[ref.ToolID].Arguments
targetKeys := map[string]string{}

if targetArgs == nil {
return "", nil
}

for targetKey := range targetArgs.Properties {
targetKeys[strings.ToLower(targetKey)] = targetKey
}

inputMap := map[string]interface{}{}
outputMap := map[string]interface{}{}

_ = json.Unmarshal([]byte(input), &inputMap)

fields := strings.Fields(ref.Arg)

for i := 0; i < len(fields); i++ {
field := fields[i]
if field == "and" {
continue
}
if field == "as" {
i++
continue
}

var (
keyName string
val any
)

if strings.HasPrefix(field, "$") {
key := strings.TrimPrefix(field, "$")
key = strings.TrimPrefix(key, "{")
key = strings.TrimSuffix(key, "}")
val = inputMap[key]
} else {
val = field
}

if len(fields) > i+1 && fields[i+1] == "as" {
keyName = strings.ToLower(fields[i+2])
}

if len(targetKeys) == 0 {
return "", fmt.Errorf("can not assign arg to context because target tool [%s] has no defined args", ref.ToolID)
}

if keyName == "" {
if len(targetKeys) != 1 {
return "", fmt.Errorf("can not assign arg to context because target tool [%s] has does not have one args. You must use \"as\" syntax to map the arg to a key %v", ref.ToolID, targetKeys)
}
for k := range targetKeys {
keyName = k
}
}

if targetKey, ok := targetKeys[strings.ToLower(keyName)]; ok {
outputMap[targetKey] = val
} else {
return "", fmt.Errorf("can not assign arg to context because target tool [%s] has does not args [%s]", ref.ToolID, keyName)
}
}

if len(outputMap) == 0 {
return "", nil
}

output, err := json.Marshal(outputMap)
return string(output), err
}

func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID)
if err != nil {
return nil, err
}

for _, toolID := range toolIDs {
content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "", engine.ContextToolCategory)
for _, toolRef := range toolRefs {
contextInput, err := getContextInput(callCtx.Program, toolRef, input)
if err != nil {
return nil, err
}

content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
if err != nil {
return nil, err
}
if content.Result == nil {
return nil, fmt.Errorf("context tool can not result in a chat continuation")
}
result = append(result, engine.InputContext{
ToolID: toolID,
ToolID: toolRef.ToolID,
Content: *content.Result,
})
}
Expand Down Expand Up @@ -278,7 +362,7 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
}

var err error
callCtx.InputContext, err = r.getContext(callCtx, monitor, env)
callCtx.InputContext, err = r.getContext(callCtx, monitor, env, input)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -361,8 +445,16 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
}
}

var err error
callCtx.InputContext, err = r.getContext(callCtx, monitor, env)
var (
err error
contentInput string
)

if state.Continuation != nil && state.Continuation.State != nil {
contentInput = state.Continuation.State.Input
}

callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput)
if err != nil {
return nil, err
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func TestSubChat(t *testing.T) {
"state": {
"continuation": {
"state": {
"input": "Hello",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
Expand Down Expand Up @@ -249,6 +250,7 @@ func TestSubChat(t *testing.T) {
"state": {
"continuation": {
"state": {
"input": "Hello",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
Expand Down Expand Up @@ -399,6 +401,7 @@ func TestChat(t *testing.T) {
"state": {
"continuation": {
"state": {
"input": "Hello",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": false,
Expand Down Expand Up @@ -452,6 +455,7 @@ func TestChat(t *testing.T) {
"state": {
"continuation": {
"state": {
"input": "Hello",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": false,
Expand Down Expand Up @@ -530,6 +534,15 @@ func TestContext(t *testing.T) {
assert.Equal(t, "TEST RESULT CALL: 1", x)
}

func TestContextArg(t *testing.T) {
runner := tester.NewRunner(t)
x, err := runner.Run("", `{
"file": "foo.db"
}`)
require.NoError(t, err)
assert.Equal(t, "TEST RESULT CALL: 1", x)
}

func TestCwd(t *testing.T) {
runner := tester.NewRunner(t)

Expand Down
28 changes: 28 additions & 0 deletions pkg/tests/testdata/TestContextArg/call1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
`{
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
"Tools": null,
"Messages": [
{
"role": "system",
"content": [
{
"text": "this is from context -- foo.db\n\nthis is from other context foo.db and then\n\nthis is from other context and then foo.db\n\nThis is from tool"
}
]
},
{
"role": "user",
"content": [
{
"text": "{\n\"file\": \"foo.db\"\n}"
}
]
}
],
"MaxTokens": 0,
"Temperature": null,
"JSONResponse": false,
"Grammar": "",
"Cache": null
}`
6 changes: 6 additions & 0 deletions pkg/tests/testdata/TestContextArg/other.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: fromcontext
args: first: an arg
args: second: an arg

#!/bin/bash
echo this is from other context ${first} and then ${second}
12 changes: 12 additions & 0 deletions pkg/tests/testdata/TestContextArg/test.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
context: fromcontext with ${file}
context: fromcontext from other.gpt with ${file} as first
context: fromcontext from other.gpt with ${file} as second
arg: file: something

This is from tool
---
name: fromcontext
args: first: an arg

#!/bin/bash
echo this is from context -- ${first}
3 changes: 3 additions & 0 deletions pkg/tests/testdata/TestDualSubChat/step1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"state": {
"continuation": {
"state": {
"input": "User 1",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
Expand Down Expand Up @@ -110,6 +111,7 @@
"state": {
"continuation": {
"state": {
"input": "Input to chatbot1",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": false,
Expand Down Expand Up @@ -175,6 +177,7 @@
"state": {
"continuation": {
"state": {
"input": "Input to chatbot2",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": false,
Expand Down
2 changes: 2 additions & 0 deletions pkg/tests/testdata/TestDualSubChat/step2.golden
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"state": {
"continuation": {
"state": {
"input": "User 1",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
Expand Down Expand Up @@ -117,6 +118,7 @@
"state": {
"continuation": {
"state": {
"input": "Input to chatbot2",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": false,
Expand Down
2 changes: 2 additions & 0 deletions pkg/tests/testdata/TestDualSubChat/step3.golden
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"state": {
"continuation": {
"state": {
"input": "User 1",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
Expand Down Expand Up @@ -117,6 +118,7 @@
"state": {
"continuation": {
"state": {
"input": "Input to chatbot2",
"completion": {
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": false,
Expand Down
Loading