Skip to content

Commit 042673c

Browse files
change: allow chat continuation in context tools
1 parent 1b8458a commit 042673c

File tree

18 files changed

+1142
-44
lines changed

18 files changed

+1142
-44
lines changed

pkg/builtin/builtin.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,9 @@ func SysChatFinish(ctx context.Context, env []string, input string) (string, err
633633
Message string `json:"message,omitempty"`
634634
}
635635
if err := json.Unmarshal([]byte(input), &params); err != nil {
636-
return "", err
636+
return "", &ErrChatFinish{
637+
Message: input,
638+
}
637639
}
638640
return "", &ErrChatFinish{
639641
Message: params.Message,

pkg/runner/runner.go

Lines changed: 119 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,25 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
134134
}()
135135

136136
callCtx := engine.NewContext(ctx, &prg)
137-
if state == nil {
138-
startResult, err := r.start(callCtx, monitor, env, input)
137+
if state == nil || state.StartContinuation {
138+
if state != nil {
139+
state = state.WithResumeInput(&input)
140+
input = state.InputContextContinuationInput
141+
}
142+
state, err = r.start(callCtx, state, monitor, env, input)
139143
if err != nil {
140144
return resp, err
141145
}
142-
state = &State{
143-
Continuation: startResult,
144-
}
145146
} else {
147+
state = state.WithResumeInput(&input)
146148
state.ResumeInput = &input
147149
}
148150

149-
state, err = r.resume(callCtx, monitor, env, state)
150-
if err != nil {
151-
return resp, err
151+
if !state.StartContinuation {
152+
state, err = r.resume(callCtx, monitor, env, state)
153+
if err != nil {
154+
return resp, err
155+
}
152156
}
153157

154158
if state.Result != nil {
@@ -286,44 +290,79 @@ func getContextInput(prg *types.Program, ref types.ToolReference, input string)
286290
return string(output), err
287291
}
288292

289-
func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
293+
func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ *State, _ error) {
290294
toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID)
291295
if err != nil {
292-
return nil, err
296+
return nil, nil, err
293297
}
294298

295-
for _, toolRef := range toolRefs {
299+
var newState *State
300+
if state != nil {
301+
cp := *state
302+
newState = &cp
303+
if newState.InputContextContinuation != nil {
304+
newState.InputContexts = nil
305+
newState.InputContextContinuation = nil
306+
newState.InputContextContinuationInput = ""
307+
newState.ResumeInput = state.InputContextContinuationResumeInput
308+
309+
input = state.InputContextContinuationInput
310+
}
311+
}
312+
313+
for i, toolRef := range toolRefs {
314+
if state != nil && i < len(state.InputContexts) {
315+
result = append(result, state.InputContexts[i])
316+
continue
317+
}
318+
296319
contextInput, err := getContextInput(callCtx.Program, toolRef, input)
297320
if err != nil {
298-
return nil, err
321+
return nil, nil, err
299322
}
300323

301-
content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
324+
var content *State
325+
if state != nil && state.InputContextContinuation != nil {
326+
content, err = r.subCallResume(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, "", state.InputContextContinuation.WithResumeInput(state.ResumeInput), engine.ContextToolCategory)
327+
} else {
328+
content, err = r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
329+
}
302330
if err != nil {
303-
return nil, err
331+
return nil, nil, err
304332
}
305-
if content.Result == nil {
306-
return nil, fmt.Errorf("context tool can not result in a chat continuation")
333+
if content.Continuation != nil {
334+
if newState == nil {
335+
newState = &State{}
336+
}
337+
newState.InputContexts = result
338+
newState.InputContextContinuation = content
339+
newState.InputContextContinuationInput = input
340+
if state != nil {
341+
newState.InputContextContinuationResumeInput = state.ResumeInput
342+
}
343+
return nil, newState, nil
307344
}
308345
result = append(result, engine.InputContext{
309346
ToolID: toolRef.ToolID,
310347
Content: *content.Result,
311348
})
312349
}
313-
return result, nil
350+
351+
return result, newState, nil
314352
}
315353

316354
func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, input string) (*State, error) {
317-
result, err := r.start(callCtx, monitor, env, input)
355+
result, err := r.start(callCtx, nil, monitor, env, input)
318356
if err != nil {
319357
return nil, err
320358
}
321-
return r.resume(callCtx, monitor, env, &State{
322-
Continuation: result,
323-
})
359+
if result.StartContinuation {
360+
return result, nil
361+
}
362+
return r.resume(callCtx, monitor, env, result)
324363
}
325364

326-
func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, input string) (*engine.Return, error) {
365+
func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (*State, error) {
327366
progress, progressClose := streamProgress(&callCtx, monitor)
328367
defer progressClose()
329368

@@ -335,11 +374,18 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
335374
}
336375
}
337376

338-
var err error
339-
callCtx.InputContext, err = r.getContext(callCtx, monitor, env, input)
377+
var (
378+
err error
379+
newState *State
380+
)
381+
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
340382
if err != nil {
341383
return nil, err
342384
}
385+
if newState != nil && newState.InputContextContinuation != nil {
386+
newState.StartContinuation = true
387+
return newState, nil
388+
}
343389

344390
e := engine.Engine{
345391
Model: r.c,
@@ -358,7 +404,14 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
358404

359405
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
360406

361-
return e.Start(callCtx, input)
407+
ret, err := e.Start(callCtx, input)
408+
if err != nil {
409+
return nil, err
410+
}
411+
412+
return &State{
413+
Continuation: ret,
414+
}, nil
362415
}
363416

364417
type State struct {
@@ -369,18 +422,28 @@ type State struct {
369422
ResumeInput *string `json:"resumeInput,omitempty"`
370423
SubCalls []SubCallResult `json:"subCalls,omitempty"`
371424
SubCallID string `json:"subCallID,omitempty"`
425+
426+
InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
427+
InputContextContinuation *State `json:"inputContextContinuation,omitempty"`
428+
InputContextContinuationInput string `json:"inputContextContinuationInput,omitempty"`
429+
InputContextContinuationResumeInput *string `json:"inputContextContinuationResumeInput,omitempty"`
430+
StartContinuation bool `json:"startContinuation,omitempty"`
372431
}
373432

374-
func (s State) WithInput(input *string) *State {
433+
func (s State) WithResumeInput(input *string) *State {
375434
s.ResumeInput = input
376435
return &s
377436
}
378437

379438
func (s State) ContinuationContentToolID() (string, error) {
380-
if s.Continuation.Result != nil {
439+
if s.Continuation != nil && s.Continuation.Result != nil {
381440
return s.ContinuationToolID, nil
382441
}
383442

443+
if s.InputContextContinuation != nil {
444+
return s.InputContextContinuation.ContinuationContentToolID()
445+
}
446+
384447
for _, subCall := range s.SubCalls {
385448
if s.SubCallID == subCall.CallID {
386449
return subCall.State.ContinuationContentToolID()
@@ -390,10 +453,14 @@ func (s State) ContinuationContentToolID() (string, error) {
390453
}
391454

392455
func (s State) ContinuationContent() (string, error) {
393-
if s.Continuation.Result != nil {
456+
if s.Continuation != nil && s.Continuation.Result != nil {
394457
return *s.Continuation.Result, nil
395458
}
396459

460+
if s.InputContextContinuation != nil {
461+
return s.InputContextContinuation.ContinuationContent()
462+
}
463+
397464
for _, subCall := range s.SubCalls {
398465
if s.SubCallID == subCall.CallID {
399466
return subCall.State.ContinuationContent()
@@ -408,6 +475,10 @@ type Needed struct {
408475
}
409476

410477
func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (*State, error) {
478+
if state.StartContinuation {
479+
return nil, fmt.Errorf("invalid state, resume should not have StartContinuation set to true")
480+
}
481+
411482
progress, progressClose := streamProgress(&callCtx, monitor)
412483
defer progressClose()
413484

@@ -433,6 +504,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
433504
Continuation: state.Continuation,
434505
ContinuationToolID: callCtx.Tool.ID,
435506
}, nil
507+
436508
}
437509
return &State{
438510
Result: state.Continuation.Result,
@@ -451,7 +523,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
451523
err error
452524
)
453525

454-
state, callResults, err = r.subCalls(callCtx, monitor, env, state)
526+
state, callResults, err = r.subCalls(callCtx, monitor, env, state, engine.NoCategory)
455527
if errMessage := (*builtin.ErrChatFinish)(nil); errors.As(err, &errMessage) && callCtx.Tool.Chat {
456528
return &State{
457529
Result: &errMessage.Message,
@@ -477,12 +549,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
477549
}
478550
}
479551

480-
if state.ResumeInput != nil {
481-
engineResults = append(engineResults, engine.CallResult{
482-
User: *state.ResumeInput,
483-
})
484-
}
485-
486552
monitor.Event(Event{
487553
Time: time.Now(),
488554
CallContext: callCtx.GetCallContext(),
@@ -506,9 +572,15 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
506572
contentInput = state.Continuation.State.Input
507573
}
508574

509-
callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput)
510-
if err != nil {
511-
return nil, err
575+
callCtx.InputContext, state, err = r.getContext(callCtx, state, monitor, env, contentInput)
576+
if err != nil || state.InputContextContinuation != nil {
577+
return state, err
578+
}
579+
580+
if state.ResumeInput != nil {
581+
engineResults = append(engineResults, engine.CallResult{
582+
User: *state.ResumeInput,
583+
})
512584
}
513585

514586
nextContinuation, err := e.Continue(callCtx, state.Continuation.State, engineResults...)
@@ -571,8 +643,8 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni
571643
return r.call(callCtx, monitor, env, input)
572644
}
573645

574-
func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State) (*State, error) {
575-
callCtx, err := parentContext.SubCall(ctx, toolID, callID, engine.NoCategory)
646+
func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
647+
callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory)
576648
if err != nil {
577649
return nil, err
578650
}
@@ -593,11 +665,15 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
593665
return newParallelDispatcher(ctx)
594666
}
595667

596-
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State) (_ *State, callResults []SubCallResult, _ error) {
668+
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) {
597669
var (
598670
resultLock sync.Mutex
599671
)
600672

673+
if state.InputContextContinuation != nil {
674+
return state, nil, nil
675+
}
676+
601677
if state.SubCallID != "" {
602678
if state.ResumeInput == nil {
603679
return nil, nil, fmt.Errorf("invalid state, input must be set for sub call continuation on callID [%s]", state.SubCallID)
@@ -608,7 +684,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
608684
found = true
609685
subState := *subCall.State
610686
subState.ResumeInput = state.ResumeInput
611-
result, err := r.subCallResume(callCtx.Ctx, callCtx, monitor, env, subCall.ToolID, subCall.CallID, subCall.State.WithInput(state.ResumeInput))
687+
result, err := r.subCallResume(callCtx.Ctx, callCtx, monitor, env, subCall.ToolID, subCall.CallID, subCall.State.WithResumeInput(state.ResumeInput), toolCategory)
612688
if err != nil {
613689
return nil, nil, err
614690
}
@@ -618,7 +694,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
618694
State: result,
619695
})
620696
// Clear the input, we have already processed it
621-
state = state.WithInput(nil)
697+
state = state.WithResumeInput(nil)
622698
} else {
623699
callResults = append(callResults, subCall)
624700
}

0 commit comments

Comments
 (0)