From 5c06cf16b2b1f5d73641404fd80e291f5410af10 Mon Sep 17 00:00:00 2001 From: Nick Hale <4175918+njhale@users.noreply.github.com> Date: Tue, 25 Jun 2024 15:37:02 -0400 Subject: [PATCH] fix: send only one run start/finish event from sdkserver The sdkserver is sending multiple run start and finish events for provider tools which causes stream processing to terminate prematurely. To fix this, only send start and finish events when the tool category isn't set. Signed-off-by: Nick Hale <4175918+njhale@users.noreply.github.com> --- pkg/engine/engine.go | 7 ++++++- pkg/monitor/display.go | 6 ++---- pkg/monitor/fd.go | 2 +- pkg/runner/monitor.go | 8 +++----- pkg/runner/runner.go | 16 +++++----------- pkg/sdkserver/monitor.go | 28 +++++++++++++++++++--------- 6 files changed, 36 insertions(+), 31 deletions(-) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 2c7c729d..5cbc87a8 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -165,8 +165,13 @@ func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Co return context.WithValue(ctx, toolCategoryKey{}, toolCategory) } -func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) { +func ToolCategoryFromContext(ctx context.Context) ToolCategory { category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory) + return category +} + +func NewContext(ctx context.Context, prg *types.Program, input string) (Context, error) { + category := ToolCategoryFromContext(ctx) callCtx := Context{ commonContext: commonContext{ diff --git a/pkg/monitor/display.go b/pkg/monitor/display.go index 167498c7..73a15006 100644 --- a/pkg/monitor/display.go +++ b/pkg/monitor/display.go @@ -38,9 +38,7 @@ type Console struct { callLock sync.Mutex } -var ( - prettyIDCounter int64 -) +var prettyIDCounter int64 func (c *Console) Start(_ context.Context, prg *types.Program, _ []string, input string) (runner.Monitor, error) { id := counter.Next() @@ -290,7 +288,7 @@ func (d *display) Event(event runner.Event) { d.dump.Calls[currentIndex] = currentCall } -func (d *display) Stop(output string, err error) { +func (d *display) Stop(_ context.Context, output string, err error) { d.callLock.Lock() defer d.callLock.Unlock() diff --git a/pkg/monitor/fd.go b/pkg/monitor/fd.go index 8cfeeede..9136936a 100644 --- a/pkg/monitor/fd.go +++ b/pkg/monitor/fd.go @@ -139,7 +139,7 @@ func (f *fd) event(event Event) { } } -func (f *fd) Stop(output string, err error) { +func (f *fd) Stop(_ context.Context, output string, err error) { e := Event{ Event: runner.Event{ Time: time.Now(), diff --git a/pkg/runner/monitor.go b/pkg/runner/monitor.go index 48e64b0b..363ed62b 100644 --- a/pkg/runner/monitor.go +++ b/pkg/runner/monitor.go @@ -6,8 +6,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/types" ) -type noopFactory struct { -} +type noopFactory struct{} func (n noopFactory) Start(context.Context, *types.Program, []string, string) (Monitor, error) { return noopMonitor{}, nil @@ -17,13 +16,12 @@ func (n noopFactory) Pause() func() { return func() {} } -type noopMonitor struct { -} +type noopMonitor struct{} func (n noopMonitor) Event(Event) { } -func (n noopMonitor) Stop(string, error) {} +func (n noopMonitor) Stop(context.Context, string, error) {} func (n noopMonitor) Pause() func() { return func() {} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e4794535..3ca79abc 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -26,7 +26,7 @@ type MonitorFactory interface { type Monitor interface { Event(event Event) Pause() func() - Stop(output string, err error) + Stop(ctx context.Context, output string, err error) } type Options struct { @@ -162,7 +162,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra return resp, err } defer func() { - monitor.Stop(resp.Content, err) + monitor.Stop(ctx, resp.Content, err) }() callCtx, err := engine.NewContext(ctx, &prg, input) @@ -425,9 +425,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en } } - var ( - newState *State - ) + var newState *State callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input) if err != nil { return nil, err @@ -632,9 +630,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s Env: env, } - var ( - contentInput string - ) + var contentInput string if state.Continuation != nil && state.Continuation.State != nil { contentInput = state.Continuation.State.Input @@ -745,9 +741,7 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher { } func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) { - var ( - resultLock sync.Mutex - ) + var resultLock sync.Mutex if state.Continuation != nil { callCtx.LastReturn = state.Continuation diff --git a/pkg/sdkserver/monitor.go b/pkg/sdkserver/monitor.go index 3a9a1014..a5b0236b 100644 --- a/pkg/sdkserver/monitor.go +++ b/pkg/sdkserver/monitor.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gptscript-ai/broadcaster" + "github.com/gptscript-ai/gptscript/pkg/engine" "github.com/gptscript-ai/gptscript/pkg/runner" gserver "github.com/gptscript-ai/gptscript/pkg/server" "github.com/gptscript-ai/gptscript/pkg/types" @@ -23,16 +24,19 @@ func NewSessionFactory(events *broadcaster.Broadcaster[event]) *SessionFactory { func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []string, input string) (runner.Monitor, error) { id := gserver.RunIDFromContext(ctx) + category := engine.ToolCategoryFromContext(ctx) - s.events.C <- event{ - Event: gserver.Event{ - Event: runner.Event{ - Time: time.Now(), - Type: runner.EventTypeRunStart, + if category == engine.NoCategory { + s.events.C <- event{ + Event: gserver.Event{ + Event: runner.Event{ + Time: time.Now(), + Type: runner.EventTypeRunStart, + }, + RunID: id, + Program: prg, }, - RunID: id, - Program: prg, - }, + } } return &Session{ @@ -69,7 +73,13 @@ func (s *Session) Event(e runner.Event) { } } -func (s *Session) Stop(output string, err error) { +func (s *Session) Stop(ctx context.Context, output string, err error) { + category := engine.ToolCategoryFromContext(ctx) + + if category != engine.NoCategory { + return + } + e := event{ Event: gserver.Event{ Event: runner.Event{