@@ -134,21 +134,25 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
134
134
}()
135
135
136
136
callCtx := engine .NewContext (ctx , & prg )
137
- if state == nil {
138
- startResult , err := r .start (callCtx , monitor , env , input )
137
+ if state == nil || state .StartContinuation {
138
+ if state != nil {
139
+ state = state .WithResumeInput (& input )
140
+ input = state .InputContextContinuationInput
141
+ }
142
+ state , err = r .start (callCtx , state , monitor , env , input )
139
143
if err != nil {
140
144
return resp , err
141
145
}
142
- state = & State {
143
- Continuation : startResult ,
144
- }
145
146
} else {
147
+ state = state .WithResumeInput (& input )
146
148
state .ResumeInput = & input
147
149
}
148
150
149
- state , err = r .resume (callCtx , monitor , env , state )
150
- if err != nil {
151
- return resp , err
151
+ if ! state .StartContinuation {
152
+ state , err = r .resume (callCtx , monitor , env , state )
153
+ if err != nil {
154
+ return resp , err
155
+ }
152
156
}
153
157
154
158
if state .Result != nil {
@@ -286,44 +290,79 @@ func getContextInput(prg *types.Program, ref types.ToolReference, input string)
286
290
return string (output ), err
287
291
}
288
292
289
- func (r * Runner ) getContext (callCtx engine.Context , monitor Monitor , env []string , input string ) (result []engine.InputContext , _ error ) {
293
+ func (r * Runner ) getContext (callCtx engine.Context , state * State , monitor Monitor , env []string , input string ) (result []engine.InputContext , _ * State , _ error ) {
290
294
toolRefs , err := callCtx .Program .GetContextToolRefs (callCtx .Tool .ID )
291
295
if err != nil {
292
- return nil , err
296
+ return nil , nil , err
293
297
}
294
298
295
- for _ , toolRef := range toolRefs {
299
+ var newState * State
300
+ if state != nil {
301
+ cp := * state
302
+ newState = & cp
303
+ if newState .InputContextContinuation != nil {
304
+ newState .InputContexts = nil
305
+ newState .InputContextContinuation = nil
306
+ newState .InputContextContinuationInput = ""
307
+ newState .ResumeInput = state .InputContextContinuationResumeInput
308
+
309
+ input = state .InputContextContinuationInput
310
+ }
311
+ }
312
+
313
+ for i , toolRef := range toolRefs {
314
+ if state != nil && i < len (state .InputContexts ) {
315
+ result = append (result , state .InputContexts [i ])
316
+ continue
317
+ }
318
+
296
319
contextInput , err := getContextInput (callCtx .Program , toolRef , input )
297
320
if err != nil {
298
- return nil , err
321
+ return nil , nil , err
299
322
}
300
323
301
- content , err := r .subCall (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , contextInput , "" , engine .ContextToolCategory )
324
+ var content * State
325
+ if state != nil && state .InputContextContinuation != nil {
326
+ content , err = r .subCallResume (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , "" , state .InputContextContinuation .WithResumeInput (state .ResumeInput ), engine .ContextToolCategory )
327
+ } else {
328
+ content , err = r .subCall (callCtx .Ctx , callCtx , monitor , env , toolRef .ToolID , contextInput , "" , engine .ContextToolCategory )
329
+ }
302
330
if err != nil {
303
- return nil , err
331
+ return nil , nil , err
304
332
}
305
- if content .Result == nil {
306
- return nil , fmt .Errorf ("context tool can not result in a chat continuation" )
333
+ if content .Continuation != nil {
334
+ if newState == nil {
335
+ newState = & State {}
336
+ }
337
+ newState .InputContexts = result
338
+ newState .InputContextContinuation = content
339
+ newState .InputContextContinuationInput = input
340
+ if state != nil {
341
+ newState .InputContextContinuationResumeInput = state .ResumeInput
342
+ }
343
+ return nil , newState , nil
307
344
}
308
345
result = append (result , engine.InputContext {
309
346
ToolID : toolRef .ToolID ,
310
347
Content : * content .Result ,
311
348
})
312
349
}
313
- return result , nil
350
+
351
+ return result , newState , nil
314
352
}
315
353
316
354
func (r * Runner ) call (callCtx engine.Context , monitor Monitor , env []string , input string ) (* State , error ) {
317
- result , err := r .start (callCtx , monitor , env , input )
355
+ result , err := r .start (callCtx , nil , monitor , env , input )
318
356
if err != nil {
319
357
return nil , err
320
358
}
321
- return r .resume (callCtx , monitor , env , & State {
322
- Continuation : result ,
323
- })
359
+ if result .StartContinuation {
360
+ return result , nil
361
+ }
362
+ return r .resume (callCtx , monitor , env , result )
324
363
}
325
364
326
- func (r * Runner ) start (callCtx engine.Context , monitor Monitor , env []string , input string ) (* engine. Return , error ) {
365
+ func (r * Runner ) start (callCtx engine.Context , state * State , monitor Monitor , env []string , input string ) (* State , error ) {
327
366
progress , progressClose := streamProgress (& callCtx , monitor )
328
367
defer progressClose ()
329
368
@@ -335,11 +374,18 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
335
374
}
336
375
}
337
376
338
- var err error
339
- callCtx .InputContext , err = r .getContext (callCtx , monitor , env , input )
377
+ var (
378
+ err error
379
+ newState * State
380
+ )
381
+ callCtx .InputContext , newState , err = r .getContext (callCtx , state , monitor , env , input )
340
382
if err != nil {
341
383
return nil , err
342
384
}
385
+ if newState != nil && newState .InputContextContinuation != nil {
386
+ newState .StartContinuation = true
387
+ return newState , nil
388
+ }
343
389
344
390
e := engine.Engine {
345
391
Model : r .c ,
@@ -358,7 +404,14 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
358
404
359
405
callCtx .Ctx = context2 .AddPauseFuncToCtx (callCtx .Ctx , monitor .Pause )
360
406
361
- return e .Start (callCtx , input )
407
+ ret , err := e .Start (callCtx , input )
408
+ if err != nil {
409
+ return nil , err
410
+ }
411
+
412
+ return & State {
413
+ Continuation : ret ,
414
+ }, nil
362
415
}
363
416
364
417
type State struct {
@@ -369,18 +422,28 @@ type State struct {
369
422
ResumeInput * string `json:"resumeInput,omitempty"`
370
423
SubCalls []SubCallResult `json:"subCalls,omitempty"`
371
424
SubCallID string `json:"subCallID,omitempty"`
425
+
426
+ InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
427
+ InputContextContinuation * State `json:"inputContextContinuation,omitempty"`
428
+ InputContextContinuationInput string `json:"inputContextContinuationInput,omitempty"`
429
+ InputContextContinuationResumeInput * string `json:"inputContextContinuationResumeInput,omitempty"`
430
+ StartContinuation bool `json:"startContinuation,omitempty"`
372
431
}
373
432
374
- func (s State ) WithInput (input * string ) * State {
433
+ func (s State ) WithResumeInput (input * string ) * State {
375
434
s .ResumeInput = input
376
435
return & s
377
436
}
378
437
379
438
func (s State ) ContinuationContentToolID () (string , error ) {
380
- if s .Continuation .Result != nil {
439
+ if s .Continuation != nil && s . Continuation .Result != nil {
381
440
return s .ContinuationToolID , nil
382
441
}
383
442
443
+ if s .InputContextContinuation != nil {
444
+ return s .InputContextContinuation .ContinuationContentToolID ()
445
+ }
446
+
384
447
for _ , subCall := range s .SubCalls {
385
448
if s .SubCallID == subCall .CallID {
386
449
return subCall .State .ContinuationContentToolID ()
@@ -390,10 +453,14 @@ func (s State) ContinuationContentToolID() (string, error) {
390
453
}
391
454
392
455
func (s State ) ContinuationContent () (string , error ) {
393
- if s .Continuation .Result != nil {
456
+ if s .Continuation != nil && s . Continuation .Result != nil {
394
457
return * s .Continuation .Result , nil
395
458
}
396
459
460
+ if s .InputContextContinuation != nil {
461
+ return s .InputContextContinuation .ContinuationContent ()
462
+ }
463
+
397
464
for _ , subCall := range s .SubCalls {
398
465
if s .SubCallID == subCall .CallID {
399
466
return subCall .State .ContinuationContent ()
@@ -408,6 +475,10 @@ type Needed struct {
408
475
}
409
476
410
477
func (r * Runner ) resume (callCtx engine.Context , monitor Monitor , env []string , state * State ) (* State , error ) {
478
+ if state .StartContinuation {
479
+ return nil , fmt .Errorf ("invalid state, resume should not have StartContinuation set to true" )
480
+ }
481
+
411
482
progress , progressClose := streamProgress (& callCtx , monitor )
412
483
defer progressClose ()
413
484
@@ -451,7 +522,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
451
522
err error
452
523
)
453
524
454
- state , callResults , err = r .subCalls (callCtx , monitor , env , state )
525
+ state , callResults , err = r .subCalls (callCtx , monitor , env , state , engine . NoCategory )
455
526
if errMessage := (* builtin .ErrChatFinish )(nil ); errors .As (err , & errMessage ) && callCtx .Tool .Chat {
456
527
return & State {
457
528
Result : & errMessage .Message ,
@@ -477,12 +548,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
477
548
}
478
549
}
479
550
480
- if state .ResumeInput != nil {
481
- engineResults = append (engineResults , engine.CallResult {
482
- User : * state .ResumeInput ,
483
- })
484
- }
485
-
486
551
monitor .Event (Event {
487
552
Time : time .Now (),
488
553
CallContext : callCtx .GetCallContext (),
@@ -506,9 +571,15 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
506
571
contentInput = state .Continuation .State .Input
507
572
}
508
573
509
- callCtx .InputContext , err = r .getContext (callCtx , monitor , env , contentInput )
510
- if err != nil {
511
- return nil , err
574
+ callCtx .InputContext , state , err = r .getContext (callCtx , state , monitor , env , contentInput )
575
+ if err != nil || state .InputContextContinuation != nil {
576
+ return state , err
577
+ }
578
+
579
+ if state .ResumeInput != nil {
580
+ engineResults = append (engineResults , engine.CallResult {
581
+ User : * state .ResumeInput ,
582
+ })
512
583
}
513
584
514
585
nextContinuation , err := e .Continue (callCtx , state .Continuation .State , engineResults ... )
@@ -571,8 +642,8 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni
571
642
return r .call (callCtx , monitor , env , input )
572
643
}
573
644
574
- func (r * Runner ) subCallResume (ctx context.Context , parentContext engine.Context , monitor Monitor , env []string , toolID , callID string , state * State ) (* State , error ) {
575
- callCtx , err := parentContext .SubCall (ctx , toolID , callID , engine . NoCategory )
645
+ func (r * Runner ) subCallResume (ctx context.Context , parentContext engine.Context , monitor Monitor , env []string , toolID , callID string , state * State , toolCategory engine. ToolCategory ) (* State , error ) {
646
+ callCtx , err := parentContext .SubCall (ctx , toolID , callID , toolCategory )
576
647
if err != nil {
577
648
return nil , err
578
649
}
@@ -593,11 +664,15 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
593
664
return newParallelDispatcher (ctx )
594
665
}
595
666
596
- func (r * Runner ) subCalls (callCtx engine.Context , monitor Monitor , env []string , state * State ) (_ * State , callResults []SubCallResult , _ error ) {
667
+ func (r * Runner ) subCalls (callCtx engine.Context , monitor Monitor , env []string , state * State , toolCategory engine. ToolCategory ) (_ * State , callResults []SubCallResult , _ error ) {
597
668
var (
598
669
resultLock sync.Mutex
599
670
)
600
671
672
+ if state .InputContextContinuation != nil {
673
+ return state , nil , nil
674
+ }
675
+
601
676
if state .SubCallID != "" {
602
677
if state .ResumeInput == nil {
603
678
return nil , nil , fmt .Errorf ("invalid state, input must be set for sub call continuation on callID [%s]" , state .SubCallID )
@@ -608,7 +683,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
608
683
found = true
609
684
subState := * subCall .State
610
685
subState .ResumeInput = state .ResumeInput
611
- result , err := r .subCallResume (callCtx .Ctx , callCtx , monitor , env , subCall .ToolID , subCall .CallID , subCall .State .WithInput (state .ResumeInput ))
686
+ result , err := r .subCallResume (callCtx .Ctx , callCtx , monitor , env , subCall .ToolID , subCall .CallID , subCall .State .WithResumeInput (state .ResumeInput ), toolCategory )
612
687
if err != nil {
613
688
return nil , nil , err
614
689
}
@@ -618,7 +693,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
618
693
State : result ,
619
694
})
620
695
// Clear the input, we have already processed it
621
- state = state .WithInput (nil )
696
+ state = state .WithResumeInput (nil )
622
697
} else {
623
698
callResults = append (callResults , subCall )
624
699
}
0 commit comments