Skip to content

feat: add context and export context tool fields #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -53,11 +54,17 @@ type CallResult struct {
}

type Context struct {
ID string
Ctx context.Context
Parent *Context
Program *types.Program
Tool types.Tool
ID string
Ctx context.Context
Parent *Context
Program *types.Program
Tool types.Tool
InputContext []InputContext
}

type InputContext struct {
ToolID string `json:"toolID,omitempty"`
Content string `json:"content,omitempty"`
}

func (c *Context) ParentID() string {
Expand All @@ -77,9 +84,10 @@ func (c *Context) MarshalJSON() ([]byte, error) {
parentID = c.Parent.ID
}
return json.Marshal(map[string]any{
"id": c.ID,
"parentID": parentID,
"tool": c.Tool,
"id": c.ID,
"parentID": parentID,
"tool": c.Tool,
"inputContext": c.InputContext,
})
}

Expand Down Expand Up @@ -155,10 +163,20 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
return nil, err
}

var instructions []string

for _, context := range ctx.InputContext {
instructions = append(instructions, context.Content)
}

if tool.Instructions != "" {
instructions = append(instructions, tool.Instructions)
}

if len(instructions) > 0 {
completion.Messages = append(completion.Messages, types.CompletionMessage{
Role: types.CompletionMessageRoleTypeSystem,
Content: types.Text(tool.Instructions),
Content: types.Text(strings.Join(instructions, "\n")),
})
}

Expand Down
24 changes: 15 additions & 9 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io/fs"
"os"
"path/filepath"
"slices"
"strings"

"github.com/getkin/kin-openapi/openapi3"
Expand Down Expand Up @@ -195,7 +196,10 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
// The below is done in two loops so that local names stay as the tool names
// and don't get mangled by external references

for _, targetToolName := range append(tool.Parameters.Tools, tool.Parameters.Export...) {
for _, targetToolName := range slices.Concat(tool.Parameters.Tools,
tool.Parameters.Export,
tool.Parameters.ExportContext,
tool.Parameters.Context) {
localTool, ok := localTools[targetToolName]
if ok {
var linkedTool types.Tool
Expand Down Expand Up @@ -301,15 +305,17 @@ func input(ctx context.Context, base *source, name string) (*source, error) {
}

func SplitToolRef(targetToolName string) (toolName, subTool string) {
subTool, toolName, ok := strings.Cut(strings.ReplaceAll(targetToolName, "\t", " "), " from ")
if ok {
toolName = strings.TrimSpace(toolName)
subTool = strings.TrimSpace(subTool)
} else {
toolName = targetToolName
subTool = ""
var (
fields = strings.Fields(targetToolName)
idx = slices.Index(fields, "from")
)

if idx == -1 {
return strings.TrimSpace(targetToolName), ""
}
return

return strings.Join(fields[idx+1:], " "),
strings.Join(fields[:idx], " ")
}

func isOpenAPI(data []byte) bool {
Expand Down
22 changes: 11 additions & 11 deletions pkg/monitor/display.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,17 @@ func (d *display) Event(event runner.Event) {
"parentID", currentCall.ParentID,
"toolID", currentCall.ToolID)

prettyID, ok := d.callIDMap[currentCall.ID]
_, ok := d.callIDMap[currentCall.ID]
if !ok {
prettyID = fmt.Sprint(atomic.AddInt64(&prettyIDCounter, 1))
prettyID := fmt.Sprint(atomic.AddInt64(&prettyIDCounter, 1))
d.callIDMap[currentCall.ID] = prettyID
}

callName := callName{
prettyID: prettyID,
call: &currentCall,
prg: d.dump.Program,
calls: d.dump.Calls,
prettyIDMap: d.callIDMap,
call: &currentCall,
prg: d.dump.Program,
calls: d.dump.Calls,
}

switch event.Type {
Expand Down Expand Up @@ -327,10 +327,10 @@ func (j jsonDump) String() string {
}

type callName struct {
prettyID string
call *call
prg *types.Program
calls []call
prettyIDMap map[string]string
call *call
prg *types.Program
calls []call
}

func (c callName) String() string {
Expand All @@ -346,7 +346,7 @@ func (c callName) String() string {
name = tool.Source.Location
}
if currentCall.ID != "1" {
name += "(" + c.prettyID + ")"
name += "(" + c.prettyIDMap[currentCall.ID] + ")"
}
msg = append(msg, name)
found := false
Expand Down
4 changes: 4 additions & 0 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
tool.Parameters.Export = append(tool.Parameters.Export, csv(strings.ToLower(value))...)
case "tool", "tools":
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(strings.ToLower(value))...)
case "exportcontext":
tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(strings.ToLower(value))...)
case "context":
tool.Parameters.Context = append(tool.Parameters.Context, csv(strings.ToLower(value))...)
case "args", "arg", "param", "params", "parameters", "parameter":
if err := addArg(value, tool); err != nil {
return false, err
Expand Down
41 changes: 35 additions & 6 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,35 @@ var (
EventTypeCallFinish = EventType("callFinish")
)

func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string) (result []engine.InputContext, _ error) {
toolIDs, err := callCtx.Program.GetContextToolIDs(callCtx.Tool.ID)
if err != nil {
return nil, err
}

for _, toolID := range toolIDs {
content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "")
if err != nil {
return nil, err
}
result = append(result, engine.InputContext{
ToolID: toolID,
Content: content,
})
}
return result, nil
}

func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
progress, progressClose := streamProgress(&callCtx, monitor)
defer progressClose()

var err error
callCtx.InputContext, err = r.getContext(callCtx, monitor, env)
if err != nil {
return "", err
}

e := engine.Engine{
Model: r.c,
RuntimeManager: r.runtimeManager,
Expand Down Expand Up @@ -221,6 +246,15 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
}
}

func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string) (string, error) {
callCtx, err := parentContext.SubCall(ctx, toolID, callID)
if err != nil {
return "", err
}

return r.call(callCtx, monitor, env, input)
}

func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, lastReturn *engine.Return) (callResults []engine.CallResult, _ error) {
var (
resultLock sync.Mutex
Expand All @@ -229,12 +263,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
eg, subCtx := errgroup.WithContext(callCtx.Ctx)
for id, call := range lastReturn.Calls {
eg.Go(func() error {
callCtx, err := callCtx.SubCall(subCtx, call.ToolID, id)
if err != nil {
return err
}

result, err := r.call(callCtx, monitor, env, call.Input)
result, err := r.subCall(subCtx, callCtx, monitor, env, call.ToolID, call.Input, id)
if err != nil {
return err
}
Expand Down
8 changes: 0 additions & 8 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,6 @@ var (

type execKey struct{}

func ContextWithNewID(ctx context.Context) context.Context {
return context.WithValue(ctx, execKey{}, fmt.Sprint(atomic.AddInt64(&execID, 1)))
}

func IDFromContext(ctx context.Context) string {
return ctx.Value(execKey{}).(string)
}

func (s *Server) Close() {
s.runner.Close()
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ import (
"github.com/stretchr/testify/require"
)

func TestExportContext(t *testing.T) {
runner := tester.NewRunner(t)
x := runner.RunDefault()
assert.Equal(t, "TEST RESULT CALL: 1", x)
}

func TestContext(t *testing.T) {
runner := tester.NewRunner(t)
x := runner.RunDefault()
assert.Equal(t, "TEST RESULT CALL: 1", x)
}

func TestCwd(t *testing.T) {
runner := tester.NewRunner(t)

Expand Down
20 changes: 20 additions & 0 deletions pkg/tests/testdata/TestContext/call1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
`{
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
"Tools": null,
"Messages": [
{
"role": "system",
"content": [
{
"text": "this is from context\n\nThis is from tool"
}
]
}
],
"MaxTokens": 0,
"Temperature": null,
"JSONResponse": false,
"Grammar": "",
"Cache": null
}`
8 changes: 8 additions & 0 deletions pkg/tests/testdata/TestContext/test.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
context: fromcontext

This is from tool
---
name: fromcontext

#!/bin/bash
echo this is from context
55 changes: 55 additions & 0 deletions pkg/tests/testdata/TestExportContext/call1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
`{
"Model": "gpt-4-turbo-preview",
"InternalSystemPrompt": null,
"Tools": [
{
"function": {
"toolID": "testdata/TestExportContext/test.gpt:21",
"name": "subtool",
"parameters": {
"properties": {
"defaultPromptParameter": {
"description": "Prompt to send to the tool or assistant. This may be instructions or question.",
"type": "string"
}
},
"required": [
"defaultPromptParameter"
],
"type": "object"
}
}
},
{
"function": {
"toolID": "testdata/TestExportContext/test.gpt:14",
"name": "sampletool",
"description": "sample",
"parameters": {
"properties": {
"foo": {
"description": "foo description",
"type": "string"
}
},
"type": "object"
}
}
}
],
"Messages": [
{
"role": "system",
"content": [
{
"text": "this is from external context\n\nthis is from context\n\nThis is from tool"
}
]
}
],
"MaxTokens": 0,
"Temperature": null,
"JSONResponse": false,
"Grammar": "",
"Cache": null
}`
30 changes: 30 additions & 0 deletions pkg/tests/testdata/TestExportContext/test.gpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
tools: subtool
context: fromcontext

This is from tool

---
name: fromcontext
export: sampletool

#!/bin/bash
echo this is from context

---
name: sampletool
description: sample
args: foo: foo description

Dummy body

---
name: subtool
export context: fromexportcontext

Dummy body

---
name: fromexportcontext

#!/bin/bash
echo this is from external context
Loading