From eb701e5ae559bdf96f43243f6a2d81295cd0b7e6 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Wed, 10 Apr 2024 16:25:09 -0700 Subject: [PATCH] feat: add context and export context tool fields --- pkg/engine/engine.go | 36 ++++++-- pkg/loader/loader.go | 24 ++++-- pkg/monitor/display.go | 22 ++--- pkg/parser/parser.go | 4 + pkg/runner/runner.go | 41 +++++++-- pkg/server/server.go | 8 -- pkg/tests/runner_test.go | 12 +++ pkg/tests/testdata/TestContext/call1.golden | 20 +++++ pkg/tests/testdata/TestContext/test.gpt | 8 ++ .../testdata/TestExportContext/call1.golden | 55 ++++++++++++ pkg/tests/testdata/TestExportContext/test.gpt | 30 +++++++ pkg/types/tool.go | 86 +++++++++++++++++++ 12 files changed, 303 insertions(+), 43 deletions(-) create mode 100644 pkg/tests/testdata/TestContext/call1.golden create mode 100644 pkg/tests/testdata/TestContext/test.gpt create mode 100644 pkg/tests/testdata/TestExportContext/call1.golden create mode 100644 pkg/tests/testdata/TestExportContext/test.gpt diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 7bad5222..fda1f06a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "sync" "sync/atomic" @@ -53,11 +54,17 @@ type CallResult struct { } type Context struct { - ID string - Ctx context.Context - Parent *Context - Program *types.Program - Tool types.Tool + ID string + Ctx context.Context + Parent *Context + Program *types.Program + Tool types.Tool + InputContext []InputContext +} + +type InputContext struct { + ToolID string `json:"toolID,omitempty"` + Content string `json:"content,omitempty"` } func (c *Context) ParentID() string { @@ -77,9 +84,10 @@ func (c *Context) MarshalJSON() ([]byte, error) { parentID = c.Parent.ID } return json.Marshal(map[string]any{ - "id": c.ID, - "parentID": parentID, - "tool": c.Tool, + "id": c.ID, + "parentID": parentID, + "tool": c.Tool, + "inputContext": c.InputContext, }) } @@ -155,10 +163,20 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) { return nil, err } + var instructions []string + + for _, context := range ctx.InputContext { + instructions = append(instructions, context.Content) + } + if tool.Instructions != "" { + instructions = append(instructions, tool.Instructions) + } + + if len(instructions) > 0 { completion.Messages = append(completion.Messages, types.CompletionMessage{ Role: types.CompletionMessageRoleTypeSystem, - Content: types.Text(tool.Instructions), + Content: types.Text(strings.Join(instructions, "\n")), }) } diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index 15ad3774..eb06a88e 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -10,6 +10,7 @@ import ( "io/fs" "os" "path/filepath" + "slices" "strings" "github.com/getkin/kin-openapi/openapi3" @@ -195,7 +196,10 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool // The below is done in two loops so that local names stay as the tool names // and don't get mangled by external references - for _, targetToolName := range append(tool.Parameters.Tools, tool.Parameters.Export...) { + for _, targetToolName := range slices.Concat(tool.Parameters.Tools, + tool.Parameters.Export, + tool.Parameters.ExportContext, + tool.Parameters.Context) { localTool, ok := localTools[targetToolName] if ok { var linkedTool types.Tool @@ -301,15 +305,17 @@ func input(ctx context.Context, base *source, name string) (*source, error) { } func SplitToolRef(targetToolName string) (toolName, subTool string) { - subTool, toolName, ok := strings.Cut(strings.ReplaceAll(targetToolName, "\t", " "), " from ") - if ok { - toolName = strings.TrimSpace(toolName) - subTool = strings.TrimSpace(subTool) - } else { - toolName = targetToolName - subTool = "" + var ( + fields = strings.Fields(targetToolName) + idx = slices.Index(fields, "from") + ) + + if idx == -1 { + return strings.TrimSpace(targetToolName), "" } - return + + return strings.Join(fields[idx+1:], " "), + strings.Join(fields[:idx], " ") } func isOpenAPI(data []byte) bool { diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index c315633c..5aaffbb0 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -204,17 +204,17 @@ func (d *display) Event(event runner.Event) { "parentID", currentCall.ParentID, "toolID", currentCall.ToolID) - prettyID, ok := d.callIDMap[currentCall.ID] + _, ok := d.callIDMap[currentCall.ID] if !ok { - prettyID = fmt.Sprint(atomic.AddInt64(&prettyIDCounter, 1)) + prettyID := fmt.Sprint(atomic.AddInt64(&prettyIDCounter, 1)) d.callIDMap[currentCall.ID] = prettyID } callName := callName{ - prettyID: prettyID, - call: ¤tCall, - prg: d.dump.Program, - calls: d.dump.Calls, + prettyIDMap: d.callIDMap, + call: ¤tCall, + prg: d.dump.Program, + calls: d.dump.Calls, } switch event.Type { @@ -327,10 +327,10 @@ func (j jsonDump) String() string { } type callName struct { - prettyID string - call *call - prg *types.Program - calls []call + prettyIDMap map[string]string + call *call + prg *types.Program + calls []call } func (c callName) String() string { @@ -346,7 +346,7 @@ func (c callName) String() string { name = tool.Source.Location } if currentCall.ID != "1" { - name += "(" + c.prettyID + ")" + name += "(" + c.prettyIDMap[currentCall.ID] + ")" } msg = append(msg, name) found := false diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index de890fc1..9557b340 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -93,6 +93,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { tool.Parameters.Export = append(tool.Parameters.Export, csv(strings.ToLower(value))...) case "tool", "tools": tool.Parameters.Tools = append(tool.Parameters.Tools, csv(strings.ToLower(value))...) + case "exportcontext": + tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(strings.ToLower(value))...) + case "context": + tool.Parameters.Context = append(tool.Parameters.Context, csv(strings.ToLower(value))...) case "args", "arg", "param", "params", "parameters", "parameter": if err := addArg(value, tool); err != nil { return false, err diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 10990954..9e6edbd9 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -116,10 +116,35 @@ var ( EventTypeCallFinish = EventType("callFinish") ) +func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string) (result []engine.InputContext, _ error) { + toolIDs, err := callCtx.Program.GetContextToolIDs(callCtx.Tool.ID) + if err != nil { + return nil, err + } + + for _, toolID := range toolIDs { + content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "") + if err != nil { + return nil, err + } + result = append(result, engine.InputContext{ + ToolID: toolID, + Content: content, + }) + } + return result, nil +} + func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) { progress, progressClose := streamProgress(&callCtx, monitor) defer progressClose() + var err error + callCtx.InputContext, err = r.getContext(callCtx, monitor, env) + if err != nil { + return "", err + } + e := engine.Engine{ Model: r.c, RuntimeManager: r.runtimeManager, @@ -221,6 +246,15 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp } } +func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string) (string, error) { + callCtx, err := parentContext.SubCall(ctx, toolID, callID) + if err != nil { + return "", err + } + + return r.call(callCtx, monitor, env, input) +} + func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, lastReturn *engine.Return) (callResults []engine.CallResult, _ error) { var ( resultLock sync.Mutex @@ -229,12 +263,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, eg, subCtx := errgroup.WithContext(callCtx.Ctx) for id, call := range lastReturn.Calls { eg.Go(func() error { - callCtx, err := callCtx.SubCall(subCtx, call.ToolID, id) - if err != nil { - return err - } - - result, err := r.call(callCtx, monitor, env, call.Input) + result, err := r.subCall(subCtx, callCtx, monitor, env, call.ToolID, call.Input, id) if err != nil { return err } diff --git a/pkg/server/server.go b/pkg/server/server.go index a8ba9619..6686e380 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -87,14 +87,6 @@ var ( type execKey struct{} -func ContextWithNewID(ctx context.Context) context.Context { - return context.WithValue(ctx, execKey{}, fmt.Sprint(atomic.AddInt64(&execID, 1))) -} - -func IDFromContext(ctx context.Context) string { - return ctx.Value(execKey{}).(string) -} - func (s *Server) Close() { s.runner.Close() } diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 539d0c76..a9662e79 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -9,6 +9,18 @@ import ( "github.com/stretchr/testify/require" ) +func TestExportContext(t *testing.T) { + runner := tester.NewRunner(t) + x := runner.RunDefault() + assert.Equal(t, "TEST RESULT CALL: 1", x) +} + +func TestContext(t *testing.T) { + runner := tester.NewRunner(t) + x := runner.RunDefault() + assert.Equal(t, "TEST RESULT CALL: 1", x) +} + func TestCwd(t *testing.T) { runner := tester.NewRunner(t) diff --git a/pkg/tests/testdata/TestContext/call1.golden b/pkg/tests/testdata/TestContext/call1.golden new file mode 100644 index 00000000..916953f5 --- /dev/null +++ b/pkg/tests/testdata/TestContext/call1.golden @@ -0,0 +1,20 @@ +`{ + "Model": "gpt-4-turbo-preview", + "InternalSystemPrompt": null, + "Tools": null, + "Messages": [ + { + "role": "system", + "content": [ + { + "text": "this is from context\n\nThis is from tool" + } + ] + } + ], + "MaxTokens": 0, + "Temperature": null, + "JSONResponse": false, + "Grammar": "", + "Cache": null +}` diff --git a/pkg/tests/testdata/TestContext/test.gpt b/pkg/tests/testdata/TestContext/test.gpt new file mode 100644 index 00000000..1b276653 --- /dev/null +++ b/pkg/tests/testdata/TestContext/test.gpt @@ -0,0 +1,8 @@ +context: fromcontext + +This is from tool +--- +name: fromcontext + +#!/bin/bash +echo this is from context \ No newline at end of file diff --git a/pkg/tests/testdata/TestExportContext/call1.golden b/pkg/tests/testdata/TestExportContext/call1.golden new file mode 100644 index 00000000..f03d45ea --- /dev/null +++ b/pkg/tests/testdata/TestExportContext/call1.golden @@ -0,0 +1,55 @@ +`{ + "Model": "gpt-4-turbo-preview", + "InternalSystemPrompt": null, + "Tools": [ + { + "function": { + "toolID": "testdata/TestExportContext/test.gpt:21", + "name": "subtool", + "parameters": { + "properties": { + "defaultPromptParameter": { + "description": "Prompt to send to the tool or assistant. This may be instructions or question.", + "type": "string" + } + }, + "required": [ + "defaultPromptParameter" + ], + "type": "object" + } + } + }, + { + "function": { + "toolID": "testdata/TestExportContext/test.gpt:14", + "name": "sampletool", + "description": "sample", + "parameters": { + "properties": { + "foo": { + "description": "foo description", + "type": "string" + } + }, + "type": "object" + } + } + } + ], + "Messages": [ + { + "role": "system", + "content": [ + { + "text": "this is from external context\n\nthis is from context\n\nThis is from tool" + } + ] + } + ], + "MaxTokens": 0, + "Temperature": null, + "JSONResponse": false, + "Grammar": "", + "Cache": null +}` diff --git a/pkg/tests/testdata/TestExportContext/test.gpt b/pkg/tests/testdata/TestExportContext/test.gpt new file mode 100644 index 00000000..f015b223 --- /dev/null +++ b/pkg/tests/testdata/TestExportContext/test.gpt @@ -0,0 +1,30 @@ +tools: subtool +context: fromcontext + +This is from tool + +--- +name: fromcontext +export: sampletool + +#!/bin/bash +echo this is from context + +--- +name: sampletool +description: sample +args: foo: foo description + +Dummy body + +--- +name: subtool +export context: fromexportcontext + +Dummy body + +--- +name: fromexportcontext + +#!/bin/bash +echo this is from external context diff --git a/pkg/types/tool.go b/pkg/types/tool.go index ebd5812e..f7469be9 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -21,6 +21,12 @@ type ErrToolNotFound struct { ToolName string } +func NewErrToolNotFound(toolName string) *ErrToolNotFound { + return &ErrToolNotFound{ + ToolName: toolName, + } +} + func (e *ErrToolNotFound) Error() string { return fmt.Sprintf("tool not found: %s", e.ToolName) } @@ -33,6 +39,44 @@ type Program struct { ToolSet ToolSet `json:"toolSet,omitempty"` } +func (p Program) GetContextToolIDs(toolID string) (result []string, _ error) { + seen := map[string]struct{}{} + tool := p.ToolSet[toolID] + + subToolIDs, err := tool.GetToolIDsFromNames(tool.Tools) + if err != nil { + return nil, err + } + + for _, subToolID := range subToolIDs { + subTool := p.ToolSet[subToolID] + exportContextToolIDs, err := subTool.GetToolIDsFromNames(subTool.ExportContext) + if err != nil { + return nil, err + } + for _, exportContextToolID := range exportContextToolIDs { + if _, ok := seen[exportContextToolID]; !ok { + seen[exportContextToolID] = struct{}{} + result = append(result, exportContextToolID) + } + } + } + + contextToolIDs, err := p.ToolSet[toolID].GetToolIDsFromNames(p.ToolSet[toolID].Context) + if err != nil { + return nil, err + } + + for _, contextToolID := range contextToolIDs { + if _, ok := seen[contextToolID]; !ok { + seen[contextToolID] = struct{}{} + result = append(result, contextToolID) + } + } + + return +} + func (p Program) GetCompletionTools() (result []CompletionTool, err error) { return Tool{ Parameters: Parameters{ @@ -74,6 +118,8 @@ type Parameters struct { InternalPrompt *bool `json:"internalPrompt"` Arguments *openapi3.Schema `json:"arguments,omitempty"` Tools []string `json:"tools,omitempty"` + Context []string `json:"context,omitempty"` + ExportContext []string `json:"exportContext,omitempty"` Export []string `json:"export,omitempty"` Blocking bool `json:"-"` } @@ -90,6 +136,17 @@ type Tool struct { WorkingDir string `json:"workingDir,omitempty"` } +func (t Tool) GetToolIDsFromNames(names []string) (result []string, _ error) { + for _, toolName := range names { + toolID, ok := t.ToolMapping[toolName] + if !ok { + return nil, NewErrToolNotFound(toolName) + } + result = append(result, toolID) + } + return +} + func (t Tool) String() string { buf := &strings.Builder{} if t.Parameters.Name != "" { @@ -104,6 +161,12 @@ func (t Tool) String() string { if len(t.Parameters.Export) != 0 { _, _ = fmt.Fprintf(buf, "Export: %s\n", strings.Join(t.Parameters.Export, ", ")) } + if len(t.Parameters.ExportContext) != 0 { + _, _ = fmt.Fprintf(buf, "Export Context: %s\n", strings.Join(t.Parameters.ExportContext, ", ")) + } + if len(t.Parameters.Context) != 0 { + _, _ = fmt.Fprintf(buf, "Context: %s\n", strings.Join(t.Parameters.Context, ", ")) + } if t.Parameters.MaxTokens != 0 { _, _ = fmt.Fprintf(buf, "Max Tokens: %d\n", t.Parameters.MaxTokens) } @@ -154,6 +217,13 @@ func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err erro } } + for _, subToolName := range t.Parameters.Context { + result, err = appendExports(result, prg, t, subToolName, toolNames) + if err != nil { + return nil, err + } + } + return result, nil } @@ -173,6 +243,22 @@ func getTool(prg Program, parent Tool, name string) (Tool, error) { return tool, nil } +func appendExports(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) { + subTool, err := getTool(prg, parentTool, subToolName) + if err != nil { + return nil, err + } + + for _, export := range subTool.Export { + completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames) + if err != nil { + return nil, err + } + } + + return completionTools, nil +} + func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) { subTool, err := getTool(prg, parentTool, subToolName) if err != nil {