diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 4a9550a3..c1b693be 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -75,6 +76,17 @@ func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { defer context2.GetPauseFuncFromCtx(ctx)()() + if req.Message != "" && len(req.Fields) == 1 && strings.TrimSpace(req.Fields[0]) == "" { + var errs []error + _, err := fmt.Fprintln(os.Stderr, req.Message) + errs = append(errs, err) + _, err = fmt.Fprintln(os.Stderr, "Press enter to continue...") + errs = append(errs, err) + _, err = fmt.Fscanln(os.Stdin) + errs = append(errs, err) + return "", errors.Join(errs...) + } + if req.Message != "" && len(req.Fields) != 1 { _, _ = fmt.Fprintln(os.Stderr, req.Message) } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e4794535..8472a8c0 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -567,7 +567,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s Time: time.Now(), CallContext: callCtx.GetCallContext(), Type: EventTypeCallFinish, - Content: getFinishEventContent(*state, callCtx), + Content: getEventContent(*state.Continuation.Result, callCtx), }) if callCtx.Tool.Chat { return &State{ @@ -681,7 +681,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp CallContext: callCtx.GetCallContext(), Type: EventTypeCallProgress, ChatCompletionID: status.CompletionID, - Content: message.String(), + Content: getEventContent(message.String(), *callCtx), }) } else { monitor.Event(Event{ @@ -821,13 +821,13 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, return state, callResults, nil } -func getFinishEventContent(state State, callCtx engine.Context) string { - // If it is a credential tool, the finish event contains its output, which is sensitive, so we don't return it. +func getEventContent(content string, callCtx engine.Context) string { + // If it is a credential tool, the progress and finish events may contain its output, which is sensitive, so we don't return it. if callCtx.ToolCategory == engine.CredentialToolCategory { return "" } - return *state.Continuation.Result + return content } func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) { diff --git a/pkg/types/credential_test.go b/pkg/types/credential_test.go index 74e74e55..b6f70ee3 100644 --- a/pkg/types/credential_test.go +++ b/pkg/types/credential_test.go @@ -125,10 +125,12 @@ func TestParseCredentialArgs(t *testing.T) { wantErr: true, }, { - name: "invalid input", - toolName: "myCredentialTool", - input: `{"asdf":"asdf"`, - wantErr: true, + name: "invalid input", + toolName: "myCredentialTool", + input: `{"asdf":"asdf"`, + expectedName: "myCredentialTool", + expectedAlias: "", + wantErr: false, }, } diff --git a/pkg/types/tool.go b/pkg/types/tool.go index a5124796..b0af5183 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -255,10 +255,10 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str inputMap := make(map[string]any) if input != "" { - err := json.Unmarshal([]byte(input), &inputMap) - if err != nil { - return "", "", nil, fmt.Errorf("failed to unmarshal input: %w", err) - } + // Sometimes this function can be called with input that is not a JSON string. + // This typically happens during chat mode. + // That's why we ignore the error if this fails to unmarshal. + _ = json.Unmarshal([]byte(input), &inputMap) } fields, err := shlex.Split(toolName)