Skip to content

change: allow chat continuation in context tools #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,9 @@ func SysChatFinish(ctx context.Context, env []string, input string) (string, err
Message string `json:"message,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
return "", &ErrChatFinish{
Message: input,
}
}
return "", &ErrChatFinish{
Message: params.Message,
Expand Down
161 changes: 118 additions & 43 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,25 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
}()

callCtx := engine.NewContext(ctx, &prg)
if state == nil {
startResult, err := r.start(callCtx, monitor, env, input)
if state == nil || state.StartContinuation {
if state != nil {
state = state.WithResumeInput(&input)
input = state.InputContextContinuationInput
}
state, err = r.start(callCtx, state, monitor, env, input)
if err != nil {
return resp, err
}
state = &State{
Continuation: startResult,
}
} else {
state = state.WithResumeInput(&input)
state.ResumeInput = &input
}

state, err = r.resume(callCtx, monitor, env, state)
if err != nil {
return resp, err
if !state.StartContinuation {
state, err = r.resume(callCtx, monitor, env, state)
if err != nil {
return resp, err
}
}

if state.Result != nil {
Expand Down Expand Up @@ -286,44 +290,79 @@ func getContextInput(prg *types.Program, ref types.ToolReference, input string)
return string(output), err
}

func (r *Runner) getContext(callCtx engine.Context, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ *State, _ error) {
toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID)
if err != nil {
return nil, err
return nil, nil, err
}

for _, toolRef := range toolRefs {
var newState *State
if state != nil {
cp := *state
newState = &cp
if newState.InputContextContinuation != nil {
newState.InputContexts = nil
newState.InputContextContinuation = nil
newState.InputContextContinuationInput = ""
newState.ResumeInput = state.InputContextContinuationResumeInput

input = state.InputContextContinuationInput
}
}

for i, toolRef := range toolRefs {
if state != nil && i < len(state.InputContexts) {
result = append(result, state.InputContexts[i])
continue
}

contextInput, err := getContextInput(callCtx.Program, toolRef, input)
if err != nil {
return nil, err
return nil, nil, err
}

content, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
var content *State
if state != nil && state.InputContextContinuation != nil {
content, err = r.subCallResume(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, "", state.InputContextContinuation.WithResumeInput(state.ResumeInput), engine.ContextToolCategory)
} else {
content, err = r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
}
if err != nil {
return nil, err
return nil, nil, err
}
if content.Result == nil {
return nil, fmt.Errorf("context tool can not result in a chat continuation")
if content.Continuation != nil {
if newState == nil {
newState = &State{}
}
newState.InputContexts = result
newState.InputContextContinuation = content
newState.InputContextContinuationInput = input
if state != nil {
newState.InputContextContinuationResumeInput = state.ResumeInput
}
return nil, newState, nil
}
result = append(result, engine.InputContext{
ToolID: toolRef.ToolID,
Content: *content.Result,
})
}
return result, nil

return result, newState, nil
}

func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, input string) (*State, error) {
result, err := r.start(callCtx, monitor, env, input)
result, err := r.start(callCtx, nil, monitor, env, input)
if err != nil {
return nil, err
}
return r.resume(callCtx, monitor, env, &State{
Continuation: result,
})
if result.StartContinuation {
return result, nil
}
return r.resume(callCtx, monitor, env, result)
}

func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, input string) (*engine.Return, error) {
func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (*State, error) {
progress, progressClose := streamProgress(&callCtx, monitor)
defer progressClose()

Expand All @@ -335,11 +374,18 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in
}
}

var err error
callCtx.InputContext, err = r.getContext(callCtx, monitor, env, input)
var (
err error
newState *State
)
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
if err != nil {
return nil, err
}
if newState != nil && newState.InputContextContinuation != nil {
newState.StartContinuation = true
return newState, nil
}

e := engine.Engine{
Model: r.c,
Expand All @@ -358,7 +404,14 @@ func (r *Runner) start(callCtx engine.Context, monitor Monitor, env []string, in

callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)

return e.Start(callCtx, input)
ret, err := e.Start(callCtx, input)
if err != nil {
return nil, err
}

return &State{
Continuation: ret,
}, nil
}

type State struct {
Expand All @@ -369,18 +422,28 @@ type State struct {
ResumeInput *string `json:"resumeInput,omitempty"`
SubCalls []SubCallResult `json:"subCalls,omitempty"`
SubCallID string `json:"subCallID,omitempty"`

InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
InputContextContinuation *State `json:"inputContextContinuation,omitempty"`
InputContextContinuationInput string `json:"inputContextContinuationInput,omitempty"`
InputContextContinuationResumeInput *string `json:"inputContextContinuationResumeInput,omitempty"`
StartContinuation bool `json:"startContinuation,omitempty"`
}

func (s State) WithInput(input *string) *State {
func (s State) WithResumeInput(input *string) *State {
s.ResumeInput = input
return &s
}

func (s State) ContinuationContentToolID() (string, error) {
if s.Continuation.Result != nil {
if s.Continuation != nil && s.Continuation.Result != nil {
return s.ContinuationToolID, nil
}

if s.InputContextContinuation != nil {
return s.InputContextContinuation.ContinuationContentToolID()
}

for _, subCall := range s.SubCalls {
if s.SubCallID == subCall.CallID {
return subCall.State.ContinuationContentToolID()
Expand All @@ -390,10 +453,14 @@ func (s State) ContinuationContentToolID() (string, error) {
}

func (s State) ContinuationContent() (string, error) {
if s.Continuation.Result != nil {
if s.Continuation != nil && s.Continuation.Result != nil {
return *s.Continuation.Result, nil
}

if s.InputContextContinuation != nil {
return s.InputContextContinuation.ContinuationContent()
}

for _, subCall := range s.SubCalls {
if s.SubCallID == subCall.CallID {
return subCall.State.ContinuationContent()
Expand All @@ -408,6 +475,10 @@ type Needed struct {
}

func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (*State, error) {
if state.StartContinuation {
return nil, fmt.Errorf("invalid state, resume should not have StartContinuation set to true")
}

progress, progressClose := streamProgress(&callCtx, monitor)
defer progressClose()

Expand Down Expand Up @@ -451,7 +522,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
err error
)

state, callResults, err = r.subCalls(callCtx, monitor, env, state)
state, callResults, err = r.subCalls(callCtx, monitor, env, state, engine.NoCategory)
if errMessage := (*builtin.ErrChatFinish)(nil); errors.As(err, &errMessage) && callCtx.Tool.Chat {
return &State{
Result: &errMessage.Message,
Expand All @@ -477,12 +548,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
}
}

if state.ResumeInput != nil {
engineResults = append(engineResults, engine.CallResult{
User: *state.ResumeInput,
})
}

monitor.Event(Event{
Time: time.Now(),
CallContext: callCtx.GetCallContext(),
Expand All @@ -506,9 +571,15 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
contentInput = state.Continuation.State.Input
}

callCtx.InputContext, err = r.getContext(callCtx, monitor, env, contentInput)
if err != nil {
return nil, err
callCtx.InputContext, state, err = r.getContext(callCtx, state, monitor, env, contentInput)
if err != nil || state.InputContextContinuation != nil {
return state, err
}

if state.ResumeInput != nil {
engineResults = append(engineResults, engine.CallResult{
User: *state.ResumeInput,
})
}

nextContinuation, err := e.Continue(callCtx, state.Continuation.State, engineResults...)
Expand Down Expand Up @@ -571,8 +642,8 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni
return r.call(callCtx, monitor, env, input)
}

func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, toolID, callID, engine.NoCategory)
func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory)
if err != nil {
return nil, err
}
Expand All @@ -593,11 +664,15 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
return newParallelDispatcher(ctx)
}

func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State) (_ *State, callResults []SubCallResult, _ error) {
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) {
var (
resultLock sync.Mutex
)

if state.InputContextContinuation != nil {
return state, nil, nil
}

if state.SubCallID != "" {
if state.ResumeInput == nil {
return nil, nil, fmt.Errorf("invalid state, input must be set for sub call continuation on callID [%s]", state.SubCallID)
Expand All @@ -608,7 +683,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
found = true
subState := *subCall.State
subState.ResumeInput = state.ResumeInput
result, err := r.subCallResume(callCtx.Ctx, callCtx, monitor, env, subCall.ToolID, subCall.CallID, subCall.State.WithInput(state.ResumeInput))
result, err := r.subCallResume(callCtx.Ctx, callCtx, monitor, env, subCall.ToolID, subCall.CallID, subCall.State.WithResumeInput(state.ResumeInput), toolCategory)
if err != nil {
return nil, nil, err
}
Expand All @@ -618,7 +693,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
State: result,
})
// Clear the input, we have already processed it
state = state.WithInput(nil)
state = state.WithResumeInput(nil)
} else {
callResults = append(callResults, subCall)
}
Expand Down
Loading