From 47f097e1f9141bdc9a767cfdab6c502ac9bfb5d4 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 24 Jun 2024 15:04:35 -0400 Subject: [PATCH 1/4] fix: misc improvements related to creds and prompting Signed-off-by: Grant Linville --- pkg/prompt/prompt.go | 7 +++++++ pkg/runner/runner.go | 10 +++++----- pkg/types/tool.go | 8 ++++---- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 4a9550a3..b0ddaa4d 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -75,6 +75,13 @@ func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { defer context2.GetPauseFuncFromCtx(ctx)()() + if req.Message != "" && len(req.Fields) == 1 && strings.TrimSpace(req.Fields[0]) == "" { + _, _ = fmt.Fprintln(os.Stderr, req.Message) + _, _ = fmt.Fprintln(os.Stderr, "Press enter to continue...") + _, _ = fmt.Fscanln(os.Stdin) + return "", nil + } + if req.Message != "" && len(req.Fields) != 1 { _, _ = fmt.Fprintln(os.Stderr, req.Message) } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e4794535..8472a8c0 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -567,7 +567,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s Time: time.Now(), CallContext: callCtx.GetCallContext(), Type: EventTypeCallFinish, - Content: getFinishEventContent(*state, callCtx), + Content: getEventContent(*state.Continuation.Result, callCtx), }) if callCtx.Tool.Chat { return &State{ @@ -681,7 +681,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp CallContext: callCtx.GetCallContext(), Type: EventTypeCallProgress, ChatCompletionID: status.CompletionID, - Content: message.String(), + Content: getEventContent(message.String(), *callCtx), }) } else { monitor.Event(Event{ @@ -821,13 +821,13 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, return state, callResults, nil } -func getFinishEventContent(state State, callCtx engine.Context) string { - // If it is a credential tool, the finish event contains its output, which is sensitive, so we don't return it. +func getEventContent(content string, callCtx engine.Context) string { + // If it is a credential tool, the progress and finish events may contain its output, which is sensitive, so we don't return it. if callCtx.ToolCategory == engine.CredentialToolCategory { return "" } - return *state.Continuation.Result + return content } func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) { diff --git a/pkg/types/tool.go b/pkg/types/tool.go index a5124796..b0af5183 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -255,10 +255,10 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str inputMap := make(map[string]any) if input != "" { - err := json.Unmarshal([]byte(input), &inputMap) - if err != nil { - return "", "", nil, fmt.Errorf("failed to unmarshal input: %w", err) - } + // Sometimes this function can be called with input that is not a JSON string. + // This typically happens during chat mode. + // That's why we ignore the error if this fails to unmarshal. + _ = json.Unmarshal([]byte(input), &inputMap) } fields, err := shlex.Split(toolName) From 8276aa366404285eb2f63ad46cd3840bcb877c00 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Mon, 24 Jun 2024 15:07:37 -0400 Subject: [PATCH 2/4] fix test Signed-off-by: Grant Linville --- pkg/types/credential_test.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/types/credential_test.go b/pkg/types/credential_test.go index 74e74e55..b6f70ee3 100644 --- a/pkg/types/credential_test.go +++ b/pkg/types/credential_test.go @@ -125,10 +125,12 @@ func TestParseCredentialArgs(t *testing.T) { wantErr: true, }, { - name: "invalid input", - toolName: "myCredentialTool", - input: `{"asdf":"asdf"`, - wantErr: true, + name: "invalid input", + toolName: "myCredentialTool", + input: `{"asdf":"asdf"`, + expectedName: "myCredentialTool", + expectedAlias: "", + wantErr: false, }, } From 6b12f5416c4c084c557f22890b8c13beeb1a5940 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Tue, 25 Jun 2024 14:18:10 -0400 Subject: [PATCH 3/4] Update pkg/prompt/prompt.go Co-authored-by: Donnie Adams Signed-off-by: Grant Linville --- pkg/prompt/prompt.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index b0ddaa4d..4a80d706 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -76,10 +76,14 @@ func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { defer context2.GetPauseFuncFromCtx(ctx)()() if req.Message != "" && len(req.Fields) == 1 && strings.TrimSpace(req.Fields[0]) == "" { - _, _ = fmt.Fprintln(os.Stderr, req.Message) - _, _ = fmt.Fprintln(os.Stderr, "Press enter to continue...") - _, _ = fmt.Fscanln(os.Stdin) - return "", nil + var errs []error + _, err := fmt.Fprintln(os.Stderr, req.Message) + errs = append(errs, err) + _, err = fmt.Fprintln(os.Stderr, "Press enter to continue...") + errs = append(errs, err) + _, err = fmt.Fscanln(os.Stdin) + errs = append(errs, err) + return "", errors.Join(errs...) } if req.Message != "" && len(req.Fields) != 1 { From 534e466957f3f0079456f44fd8d91e6999a58a45 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Tue, 25 Jun 2024 14:34:21 -0400 Subject: [PATCH 4/4] fix import Signed-off-by: Grant Linville --- pkg/prompt/prompt.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index 4a80d706..c1b693be 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -77,7 +78,7 @@ func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { if req.Message != "" && len(req.Fields) == 1 && strings.TrimSpace(req.Fields[0]) == "" { var errs []error - _, err := fmt.Fprintln(os.Stderr, req.Message) + _, err := fmt.Fprintln(os.Stderr, req.Message) errs = append(errs, err) _, err = fmt.Fprintln(os.Stderr, "Press enter to continue...") errs = append(errs, err)