diff --git a/pkg/confirm/confirm.go b/pkg/confirm/confirm.go index be45a18f..a494aded 100644 --- a/pkg/confirm/confirm.go +++ b/pkg/confirm/confirm.go @@ -9,7 +9,7 @@ import ( ) type Confirm interface { - Confirm(prompt string) error + Confirm(ctx context.Context, prompt string) error } type confirmer struct{} @@ -23,13 +23,13 @@ func Promptf(ctx context.Context, fmtString string, args ...any) error { if !ok { return nil } - return c.Confirm(fmt.Sprintf(fmtString, args...)) + return c.Confirm(ctx, fmt.Sprintf(fmtString, args...)) } type TextPrompt struct { } -func (t TextPrompt) Confirm(prompt string) error { +func (t TextPrompt) Confirm(_ context.Context, prompt string) error { var result bool err := survey.AskOne(&survey.Confirm{ Message: prompt, diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 24ee9302..7bad5222 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -109,6 +109,17 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context, }, nil } +type engineContext struct{} + +func FromContext(ctx context.Context) (*Context, bool) { + c, ok := ctx.Value(engineContext{}).(*Context) + return c, ok +} + +func (c *Context) WrappedContext() context.Context { + return context.WithValue(c.Ctx, engineContext{}, c) +} + func (e *Engine) Start(ctx Context, input string) (*Return, error) { tool := ctx.Tool @@ -120,7 +131,7 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) { } else if tool.IsOpenAPI() { return e.runOpenAPI(tool, input) } - s, err := e.runCommand(ctx.Ctx, tool, input) + s, err := e.runCommand(ctx.WrappedContext(), tool, input) if err != nil { return nil, err } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index c717c78b..10990954 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -182,7 +182,7 @@ func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, inp } } -func streamProgress(callCtx *engine.Context, monitor Monitor) (chan types.CompletionStatus, func()) { +func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.CompletionStatus, func()) { progress := make(chan types.CompletionStatus) wg := sync.WaitGroup{}