diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 09c037a1..e8e02b20 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -53,16 +53,20 @@ type CallResult struct { Result string `json:"result,omitempty"` } +type commonContext struct { + ID string `json:"id"` + Tool types.Tool `json:"tool"` + InputContext []InputContext `json:"inputContext"` + // IsCredential indicates that the current call is for a credential tool + IsCredential bool `json:"isCredential"` +} + type Context struct { - ID string + commonContext Ctx context.Context Parent *Context Program *types.Program - Tool types.Tool - InputContext []InputContext CredentialContext string - // IsCredential indicates that the current call is for a credential tool - IsCredential bool } type InputContext struct { @@ -70,6 +74,12 @@ type InputContext struct { Content string `json:"content,omitempty"` } +type BasicContext struct { + commonContext `json:",inline"` + ToolName string `json:"toolName,omitempty"` + ParentID string `json:"parentID,omitempty"` +} + func (c *Context) ParentID() string { if c.Parent == nil { return "" @@ -77,31 +87,42 @@ func (c *Context) ParentID() string { return c.Parent.ID } +func (c *Context) ToBasicContext() *BasicContext { + var toolName string + if c.Parent != nil { + for name, id := range c.Parent.Tool.ToolMapping { + if id == c.Tool.ID { + toolName = name + break + } + } + } + + return &BasicContext{ + commonContext: c.commonContext, + ParentID: c.ParentID(), + ToolName: toolName, + } +} + func (c *Context) UnmarshalJSON([]byte) error { panic("this data struct is circular by design and can not be read from json") } func (c *Context) MarshalJSON() ([]byte, error) { - var parentID string - if c.Parent != nil { - parentID = c.Parent.ID - } - return json.Marshal(map[string]any{ - "id": c.ID, - "parentID": parentID, - "tool": c.Tool, - "inputContext": c.InputContext, - }) + return json.Marshal(c.ToBasicContext()) } var execID int32 func NewContext(ctx context.Context, prg *types.Program) Context { callCtx := Context{ - ID: fmt.Sprint(atomic.AddInt32(&execID, 1)), + commonContext: commonContext{ + ID: fmt.Sprint(atomic.AddInt32(&execID, 1)), + Tool: prg.ToolSet[prg.EntryToolID], + }, Ctx: ctx, Program: prg, - Tool: prg.ToolSet[prg.EntryToolID], } return callCtx } @@ -117,12 +138,14 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string, isCredenti } return Context{ - ID: callID, - Ctx: ctx, - Parent: c, - Program: c.Program, - Tool: tool, - IsCredential: isCredentialTool, // disallow calls to the LLM if this is a credential tool + commonContext: commonContext{ + ID: callID, + Tool: tool, + IsCredential: isCredentialTool, + }, + Ctx: ctx, + Parent: c, + Program: c.Program, }, nil } diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index 429b60d2..621d18db 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -194,7 +194,7 @@ func (d *display) Event(event runner.Event) { currentIndex = len(d.dump.Calls) currentCall = call{ ID: event.CallContext.ID, - ParentID: event.CallContext.ParentID(), + ParentID: event.CallContext.ParentID, ToolID: event.CallContext.Tool.ID, } d.dump.Calls = append(d.dump.Calls, currentCall) @@ -213,20 +213,12 @@ func (d *display) Event(event runner.Event) { } callName := callName{ - prettyIDMap: d.callIDMap, - call: ¤tCall, - prg: d.dump.Program, - calls: d.dump.Calls, - credential: event.CallContext.IsCredential, - } - - if event.CallContext.Parent != nil { - for name, id := range event.CallContext.Parent.Tool.ToolMapping { - if id == event.CallContext.Tool.ID { - callName.userSpecifiedToolName = name - break - } - } + prettyIDMap: d.callIDMap, + call: ¤tCall, + prg: d.dump.Program, + calls: d.dump.Calls, + credential: event.CallContext.IsCredential, + userSpecifiedToolName: event.CallContext.ToolName, } switch event.Type { diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 4353a6c5..ef4c6ddc 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -108,7 +108,7 @@ func (r *Runner) Run(ctx context.Context, prg types.Program, env []string, input type Event struct { Time time.Time `json:"time,omitempty"` - CallContext *engine.Context `json:"callContext,omitempty"` + CallContext *engine.BasicContext `json:"callContext,omitempty"` ToolSubCalls map[string]engine.Call `json:"toolSubCalls,omitempty"` ToolResults int `json:"toolResults,omitempty"` Type EventType `json:"type,omitempty"` @@ -177,7 +177,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp monitor.Event(Event{ Time: time.Now(), - CallContext: &callCtx, + CallContext: callCtx.ToBasicContext(), Type: EventTypeCallStart, Content: input, }) @@ -197,7 +197,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp progressClose() monitor.Event(Event{ Time: time.Now(), - CallContext: &callCtx, + CallContext: callCtx.ToBasicContext(), Type: EventTypeCallFinish, Content: *result.Result, }) @@ -210,7 +210,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp monitor.Event(Event{ Time: time.Now(), - CallContext: &callCtx, + CallContext: callCtx.ToBasicContext(), Type: EventTypeCallSubCalls, ToolSubCalls: result.Calls, }) @@ -222,7 +222,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp monitor.Event(Event{ Time: time.Now(), - CallContext: &callCtx, + CallContext: callCtx.ToBasicContext(), Type: EventTypeCallContinue, ToolResults: len(callResults), }) @@ -245,7 +245,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp if message := status.PartialResponse; message != nil { monitor.Event(Event{ Time: time.Now(), - CallContext: callCtx, + CallContext: callCtx.ToBasicContext(), Type: EventTypeCallProgress, ChatCompletionID: status.CompletionID, Content: message.String(), @@ -253,7 +253,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp } else { monitor.Event(Event{ Time: time.Now(), - CallContext: callCtx, + CallContext: callCtx.ToBasicContext(), Type: EventTypeChat, ChatCompletionID: status.CompletionID, ChatRequest: status.Request,