diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 2c7c729d..5cbc87a8 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -165,8 +165,13 @@ func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Co return context.WithValue(ctx, toolCategoryKey{}, toolCategory) } -func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) { +func ToolCategoryFromContext(ctx context.Context) ToolCategory { category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory) + return category +} + +func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) { + category := ToolCategoryFromContext(ctx) callCtx := Context{ commonContext: commonContext{ diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index 167498c7..73a15006 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -38,9 +38,7 @@ type Console struct { callLock sync.Mutex } -var ( - prettyIDCounter int64 -) +var prettyIDCounter int64 func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) { id := counter.Next() @@ -290,7 +288,7 @@ func (d *display) Event(event runner.Event) { d.dump.Calls[currentIndex] = currentCall } -func (d *display) Stop(output string, err error) { +func (d *display) Stop(_ context.Context, output string, err error) { d.callLock.Lock() defer d.callLock.Unlock() diff --git a/pkg/monitor/fd.go b/pkg/monitor/fd.go index 8cfeeede..9136936a 100644 --- a/pkg/monitor/fd.go +++ b/pkg/monitor/fd.go @@ -139,7 +139,7 @@ func (f *fd) event(event Event) { } } -func (f *fd) Stop(output string, err error) { +func (f *fd) Stop(_ context.Context, output string, err error) { e := Event{ Event: runner.Event{ Time: time.Now(), diff --git a/pkg/runner/monitor.go b/pkg/runner/monitor.go index 48e64b0b..363ed62b 100644 --- a/pkg/runner/monitor.go +++ b/pkg/runner/monitor.go @@ -6,8 +6,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/types" ) -type noopFactory struct { -} +type noopFactory struct{} func (n noopFactory) Start(context.Context, *types.Program, []string, string) (Monitor, error) { return noopMonitor{}, nil @@ -17,13 +16,12 @@ func (n noopFactory) Pause() func() { return func() {} } -type noopMonitor struct { -} +type noopMonitor struct{} func (n noopMonitor) Event(Event) { } -func (n noopMonitor) Stop(string, error) {} +func (n noopMonitor) Stop(context.Context, string, error) {} func (n noopMonitor) Pause() func() { return func() {} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e4794535..3ca79abc 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -26,7 +26,7 @@ type MonitorFactory interface { type Monitor interface { Event(event Event) Pause() func() - Stop(output string, err error) + Stop(ctx context.Context, output string, err error) } type Options struct { @@ -162,7 +162,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra return resp, err } defer func() { - monitor.Stop(resp.Content, err) + monitor.Stop(ctx, resp.Content, err) }() callCtx, err := engine.NewContext(ctx, &prg, input) @@ -425,9 +425,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en } } - var ( - newState *State - ) + var newState *State callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input) if err != nil { return nil, err @@ -632,9 +630,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s Env: env, } - var ( - contentInput string - ) + var contentInput string if state.Continuation != nil && state.Continuation.State != nil { contentInput = state.Continuation.State.Input @@ -745,9 +741,7 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher { } func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) { - var ( - resultLock sync.Mutex - ) + var resultLock sync.Mutex if state.Continuation != nil { callCtx.LastReturn = state.Continuation diff --git a/pkg/sdkserver/monitor.go b/pkg/sdkserver/monitor.go index 3a9a1014..a5b0236b 100644 --- a/pkg/sdkserver/monitor.go +++ b/pkg/sdkserver/monitor.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gptscript-ai/broadcaster" + "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/runner" gserver "github.com/gptscript-ai/gptscript/pkg/server" "github.com/gptscript-ai/gptscript/pkg/types" @@ -23,16 +24,19 @@ func NewSessionFactory(events *broadcaster.Broadcaster[event]) *SessionFactory { func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []string, input string) (runner.Monitor, error) { id := gserver.RunIDFromContext(ctx) + category := engine.ToolCategoryFromContext(ctx) - s.events.C <- event{ - Event: gserver.Event{ - Event: runner.Event{ - Time: time.Now(), - Type: runner.EventTypeRunStart, + if category == engine.NoCategory { + s.events.C <- event{ + Event: gserver.Event{ + Event: runner.Event{ + Time: time.Now(), + Type: runner.EventTypeRunStart, + }, + RunID: id, + Program: prg, }, - RunID: id, - Program: prg, - }, + } } return &Session{ @@ -69,7 +73,13 @@ func (s *Session) Event(e runner.Event) { } } -func (s *Session) Stop(output string, err error) { +func (s *Session) Stop(ctx context.Context, output string, err error) { + category := engine.ToolCategoryFromContext(ctx) + + if category != engine.NoCategory { + return + } + e := event{ Event: gserver.Event{ Event: runner.Event{