Skip to content

Commit 58a71d7

Browse files
Merge pull request #232 from ibuildthecloud/context
feat: add context and export context tool fields
2 parents f1fd439 + eb701e5 commit 58a71d7

File tree

12 files changed

+303
-43
lines changed

12 files changed

+303
-43
lines changed

pkg/engine/engine.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"strings"
78
"sync"
89
"sync/atomic"
910

@@ -53,11 +54,17 @@ type CallResult struct {
5354
}
5455

5556
type Context struct {
56-
ID string
57-
Ctx context.Context
58-
Parent *Context
59-
Program *types.Program
60-
Tool types.Tool
57+
ID string
58+
Ctx context.Context
59+
Parent *Context
60+
Program *types.Program
61+
Tool types.Tool
62+
InputContext []InputContext
63+
}
64+
65+
type InputContext struct {
66+
ToolID string `json:"toolID,omitempty"`
67+
Content string `json:"content,omitempty"`
6168
}
6269

6370
func (c *Context) ParentID() string {
@@ -77,9 +84,10 @@ func (c *Context) MarshalJSON() ([]byte, error) {
7784
parentID = c.Parent.ID
7885
}
7986
return json.Marshal(map[string]any{
80-
"id": c.ID,
81-
"parentID": parentID,
82-
"tool": c.Tool,
87+
"id": c.ID,
88+
"parentID": parentID,
89+
"tool": c.Tool,
90+
"inputContext": c.InputContext,
8391
})
8492
}
8593

@@ -155,10 +163,20 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
155163
return nil, err
156164
}
157165

166+
var instructions []string
167+
168+
for _, context := range ctx.InputContext {
169+
instructions = append(instructions, context.Content)
170+
}
171+
158172
if tool.Instructions != "" {
173+
instructions = append(instructions, tool.Instructions)
174+
}
175+
176+
if len(instructions) > 0 {
159177
completion.Messages = append(completion.Messages, types.CompletionMessage{
160178
Role: types.CompletionMessageRoleTypeSystem,
161-
Content: types.Text(tool.Instructions),
179+
Content: types.Text(strings.Join(instructions, "\n")),
162180
})
163181
}
164182

pkg/loader/loader.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"io/fs"
1111
"os"
1212
"path/filepath"
13+
"slices"
1314
"strings"
1415

1516
"github.com/getkin/kin-openapi/openapi3"
@@ -195,7 +196,10 @@ func link(ctx context.Context, prg *types.Program, base *source, tool types.Tool
195196
// The below is done in two loops so that local names stay as the tool names
196197
// and don't get mangled by external references
197198

198-
for _, targetToolName := range append(tool.Parameters.Tools, tool.Parameters.Export...) {
199+
for _, targetToolName := range slices.Concat(tool.Parameters.Tools,
200+
tool.Parameters.Export,
201+
tool.Parameters.ExportContext,
202+
tool.Parameters.Context) {
199203
localTool, ok := localTools[targetToolName]
200204
if ok {
201205
var linkedTool types.Tool
@@ -301,15 +305,17 @@ func input(ctx context.Context, base *source, name string) (*source, error) {
301305
}
302306

303307
func SplitToolRef(targetToolName string) (toolName, subTool string) {
304-
subTool, toolName, ok := strings.Cut(strings.ReplaceAll(targetToolName, "\t", " "), " from ")
305-
if ok {
306-
toolName = strings.TrimSpace(toolName)
307-
subTool = strings.TrimSpace(subTool)
308-
} else {
309-
toolName = targetToolName
310-
subTool = ""
308+
var (
309+
fields = strings.Fields(targetToolName)
310+
idx = slices.Index(fields, "from")
311+
)
312+
313+
if idx == -1 {
314+
return strings.TrimSpace(targetToolName), ""
311315
}
312-
return
316+
317+
return strings.Join(fields[idx+1:], " "),
318+
strings.Join(fields[:idx], " ")
313319
}
314320

315321
func isOpenAPI(data []byte) bool {

pkg/monitor/display.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,17 @@ func (d *display) Event(event runner.Event) {
204204
"parentID", currentCall.ParentID,
205205
"toolID", currentCall.ToolID)
206206

207-
prettyID, ok := d.callIDMap[currentCall.ID]
207+
_, ok := d.callIDMap[currentCall.ID]
208208
if !ok {
209-
prettyID = fmt.Sprint(atomic.AddInt64(&prettyIDCounter, 1))
209+
prettyID := fmt.Sprint(atomic.AddInt64(&prettyIDCounter, 1))
210210
d.callIDMap[currentCall.ID] = prettyID
211211
}
212212

213213
callName := callName{
214-
prettyID: prettyID,
215-
call: &currentCall,
216-
prg: d.dump.Program,
217-
calls: d.dump.Calls,
214+
prettyIDMap: d.callIDMap,
215+
call: &currentCall,
216+
prg: d.dump.Program,
217+
calls: d.dump.Calls,
218218
}
219219

220220
switch event.Type {
@@ -327,10 +327,10 @@ func (j jsonDump) String() string {
327327
}
328328

329329
type callName struct {
330-
prettyID string
331-
call *call
332-
prg *types.Program
333-
calls []call
330+
prettyIDMap map[string]string
331+
call *call
332+
prg *types.Program
333+
calls []call
334334
}
335335

336336
func (c callName) String() string {
@@ -346,7 +346,7 @@ func (c callName) String() string {
346346
name = tool.Source.Location
347347
}
348348
if currentCall.ID != "1" {
349-
name += "(" + c.prettyID + ")"
349+
name += "(" + c.prettyIDMap[currentCall.ID] + ")"
350350
}
351351
msg = append(msg, name)
352352
found := false

pkg/parser/parser.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
9393
tool.Parameters.Export = append(tool.Parameters.Export, csv(strings.ToLower(value))...)
9494
case "tool", "tools":
9595
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(strings.ToLower(value))...)
96+
case "exportcontext":
97+
tool.Parameters.ExportContext = append(tool.Parameters.ExportContext, csv(strings.ToLower(value))...)
98+
case "context":
99+
tool.Parameters.Context = append(tool.Parameters.Context, csv(strings.ToLower(value))...)
96100
case "args", "arg", "param", "params", "parameters", "parameter":
97101
if err := addArg(value, tool); err != nil {
98102
return false, err

pkg/runner/runner.go

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,35 @@ var (
116116
EventTypeCallFinish = EventType("callFinish")
117117
)
118118

119+
func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string) (result []engine.InputContext, _ error) {
120+
toolIDs, err := callCtx.Program.GetContextToolIDs(callCtx.Tool.ID)
121+
if err != nil {
122+
return nil, err
123+
}
124+
125+
for _, toolID := range toolIDs {
126+
content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolID, "", "")
127+
if err != nil {
128+
return nil, err
129+
}
130+
result = append(result, engine.InputContext{
131+
ToolID: toolID,
132+
Content: content,
133+
})
134+
}
135+
return result, nil
136+
}
137+
119138
func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
120139
progress, progressClose := streamProgress(&callCtx, monitor)
121140
defer progressClose()
122141

142+
var err error
143+
callCtx.InputContext, err = r.getContext(callCtx, monitor, env)
144+
if err != nil {
145+
return "", err
146+
}
147+
123148
e := engine.Engine{
124149
Model: r.c,
125150
RuntimeManager: r.runtimeManager,
@@ -221,6 +246,15 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
221246
}
222247
}
223248

249+
func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string) (string, error) {
250+
callCtx, err := parentContext.SubCall(ctx, toolID, callID)
251+
if err != nil {
252+
return "", err
253+
}
254+
255+
return r.call(callCtx, monitor, env, input)
256+
}
257+
224258
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, lastReturn *engine.Return) (callResults []engine.CallResult, _ error) {
225259
var (
226260
resultLock sync.Mutex
@@ -229,12 +263,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
229263
eg, subCtx := errgroup.WithContext(callCtx.Ctx)
230264
for id, call := range lastReturn.Calls {
231265
eg.Go(func() error {
232-
callCtx, err := callCtx.SubCall(subCtx, call.ToolID, id)
233-
if err != nil {
234-
return err
235-
}
236-
237-
result, err := r.call(callCtx, monitor, env, call.Input)
266+
result, err := r.subCall(subCtx, callCtx, monitor, env, call.ToolID, call.Input, id)
238267
if err != nil {
239268
return err
240269
}

pkg/server/server.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,6 @@ var (
8787

8888
type execKey struct{}
8989

90-
func ContextWithNewID(ctx context.Context) context.Context {
91-
return context.WithValue(ctx, execKey{}, fmt.Sprint(atomic.AddInt64(&execID, 1)))
92-
}
93-
94-
func IDFromContext(ctx context.Context) string {
95-
return ctx.Value(execKey{}).(string)
96-
}
97-
9890
func (s *Server) Close() {
9991
s.runner.Close()
10092
}

pkg/tests/runner_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ import (
99
"github.com/stretchr/testify/require"
1010
)
1111

12+
func TestExportContext(t *testing.T) {
13+
runner := tester.NewRunner(t)
14+
x := runner.RunDefault()
15+
assert.Equal(t, "TEST RESULT CALL: 1", x)
16+
}
17+
18+
func TestContext(t *testing.T) {
19+
runner := tester.NewRunner(t)
20+
x := runner.RunDefault()
21+
assert.Equal(t, "TEST RESULT CALL: 1", x)
22+
}
23+
1224
func TestCwd(t *testing.T) {
1325
runner := tester.NewRunner(t)
1426

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
`{
2+
"Model": "gpt-4-turbo-preview",
3+
"InternalSystemPrompt": null,
4+
"Tools": null,
5+
"Messages": [
6+
{
7+
"role": "system",
8+
"content": [
9+
{
10+
"text": "this is from context\n\nThis is from tool"
11+
}
12+
]
13+
}
14+
],
15+
"MaxTokens": 0,
16+
"Temperature": null,
17+
"JSONResponse": false,
18+
"Grammar": "",
19+
"Cache": null
20+
}`
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
context: fromcontext
2+
3+
This is from tool
4+
---
5+
name: fromcontext
6+
7+
#!/bin/bash
8+
echo this is from context
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
`{
2+
"Model": "gpt-4-turbo-preview",
3+
"InternalSystemPrompt": null,
4+
"Tools": [
5+
{
6+
"function": {
7+
"toolID": "testdata/TestExportContext/test.gpt:21",
8+
"name": "subtool",
9+
"parameters": {
10+
"properties": {
11+
"defaultPromptParameter": {
12+
"description": "Prompt to send to the tool or assistant. This may be instructions or question.",
13+
"type": "string"
14+
}
15+
},
16+
"required": [
17+
"defaultPromptParameter"
18+
],
19+
"type": "object"
20+
}
21+
}
22+
},
23+
{
24+
"function": {
25+
"toolID": "testdata/TestExportContext/test.gpt:14",
26+
"name": "sampletool",
27+
"description": "sample",
28+
"parameters": {
29+
"properties": {
30+
"foo": {
31+
"description": "foo description",
32+
"type": "string"
33+
}
34+
},
35+
"type": "object"
36+
}
37+
}
38+
}
39+
],
40+
"Messages": [
41+
{
42+
"role": "system",
43+
"content": [
44+
{
45+
"text": "this is from external context\n\nthis is from context\n\nThis is from tool"
46+
}
47+
]
48+
}
49+
],
50+
"MaxTokens": 0,
51+
"Temperature": null,
52+
"JSONResponse": false,
53+
"Grammar": "",
54+
"Cache": null
55+
}`

0 commit comments

Comments
 (0)