From ab3240455a00c07842eff95286f5219123e0a606 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Tue, 4 Feb 2025 09:58:59 -0500 Subject: [PATCH] enhance: add field-level sensitivity for prompts Additionally, each field can now also have a description. This change is made such that all existing tools will work. However, existing code will need to be updated to support the new types. Signed-off-by: Donnie Adams --- go.mod | 4 +- go.sum | 8 +-- pkg/cli/gptscript.go | 1 - pkg/engine/call.go | 10 +-- pkg/prompt/prompt.go | 25 ++++--- pkg/types/prompt.go | 63 ++++++++++++++++- pkg/types/prompt_test.go | 142 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 227 insertions(+), 26 deletions(-) create mode 100644 pkg/types/prompt_test.go diff --git a/go.mod b/go.mod index 72e43a50..bc80c47e 100644 --- a/go.mod +++ b/go.mod @@ -17,8 +17,8 @@ require ( github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1 github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb - github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e - github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 + github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 + github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 diff --git a/go.sum b/go.sum index ff7e21e9..7ed757bd 100644 --- a/go.sum +++ b/go.sum @@ -201,10 +201,10 @@ github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f github.com/gptscript-ai/chat-completion-client v0.0.0-20250128181713-57857b74f9f1/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= -github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA= -github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= -github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 h1:vkgNZVWQgbE33VD3z9WKDwuu7B/eJVVMMPM62ixfCR8= -github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6/go.mod h1:frrl/B+ZH3VSs3Tqk2qxEIIWTONExX3tuUa4JsVnqx4= +github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI= +github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= +github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee h1:70PHW6Xw70yNNZ5aX936XqcMLwNmfMZpCV3FCOGKpxE= +github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index a3454dd5..4bd04509 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -494,7 +494,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) { DisableCache: r.DisableCache, CredentialOverrides: r.CredentialOverride, Input: toolInput, - CacheDir: r.CacheDir, SubTool: r.SubTool, Workspace: r.Workspace, SaveChatStateFile: r.SaveChatStateFile, diff --git a/pkg/engine/call.go b/pkg/engine/call.go index d116d0fa..4a3b70b5 100644 --- a/pkg/engine/call.go +++ b/pkg/engine/call.go @@ -76,13 +76,13 @@ func mergeInputs(base, overlay string) (string, error) { return base, nil } - err := json.Unmarshal([]byte(base), &baseMap) - if err != nil { - return "", fmt.Errorf("failed to unmarshal base input: %w", err) + if base != "" { + if err := json.Unmarshal([]byte(base), &baseMap); err != nil { + return "", fmt.Errorf("failed to unmarshal base input: %w", err) + } } - err = json.Unmarshal([]byte(overlay), &overlayMap) - if err != nil { + if err := json.Unmarshal([]byte(overlay), &overlayMap); err != nil { return "", fmt.Errorf("failed to unmarshal overlay input: %w", err) } diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go index f91a04b6..fa4beeb6 100644 --- a/pkg/prompt/prompt.go +++ b/pkg/prompt/prompt.go @@ -52,7 +52,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types. func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) { var params struct { Message string `json:"message,omitempty"` - Fields string `json:"fields,omitempty"` + Fields types.Fields `json:"fields,omitempty"` Sensitive string `json:"sensitive,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` } @@ -60,16 +60,11 @@ func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string return "", err } - var fields []string for _, env := range envs { if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok { - if params.Fields != "" { - fields = strings.Split(params.Fields, ",") - } - httpPrompt := types.Prompt{ Message: params.Message, - Fields: fields, + Fields: params.Fields, Sensitive: params.Sensitive == "true", Metadata: params.Metadata, } @@ -102,21 +97,25 @@ func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) { results := map[string]string{} for _, f := range req.Fields { var ( - value string - msg = f + value string + msg = f.Name + sensitive = req.Sensitive ) + if f.Sensitive != nil { + sensitive = *f.Sensitive + } if len(req.Fields) == 1 && req.Message != "" { msg = req.Message } - if req.Sensitive { - err = survey.AskOne(&survey.Password{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + if sensitive { + err = survey.AskOne(&survey.Password{Message: msg, Help: f.Description}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) } else { - err = survey.AskOne(&survey.Input{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) + err = survey.AskOne(&survey.Input{Message: msg, Help: f.Description}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr)) } if err != nil { return "", err } - results[f] = value + results[f.Name] = value } resultsStr, err := json.Marshal(results) diff --git a/pkg/types/prompt.go b/pkg/types/prompt.go index 653ad066..3da40a2e 100644 --- a/pkg/types/prompt.go +++ b/pkg/types/prompt.go @@ -1,5 +1,10 @@ package types +import ( + "encoding/json" + "strings" +) + const ( PromptURLEnvVar = "GPTSCRIPT_PROMPT_URL" PromptTokenEnvVar = "GPTSCRIPT_PROMPT_TOKEN" @@ -7,7 +12,63 @@ const ( type Prompt struct { Message string `json:"message,omitempty"` - Fields []string `json:"fields,omitempty"` + Fields Fields `json:"fields,omitempty"` Sensitive bool `json:"sensitive,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` } + +type Field struct { + Name string `json:"name,omitempty"` + Sensitive *bool `json:"sensitive,omitempty"` + Description string `json:"description,omitempty"` +} + +type Fields []Field + +// UnmarshalJSON will unmarshal the corresponding JSON object for Fields, +// or a comma-separated strings (for backwards compatibility). +func (f *Fields) UnmarshalJSON(b []byte) error { + if len(b) == 0 || f == nil { + return nil + } + + if b[0] == '[' { + var arr []Field + if err := json.Unmarshal(b, &arr); err != nil { + return err + } + *f = arr + return nil + } + + var fields string + if err := json.Unmarshal(b, &fields); err != nil { + return err + } + + if fields != "" { + fieldsArr := strings.Split(fields, ",") + *f = make([]Field, 0, len(fieldsArr)) + for _, field := range fieldsArr { + *f = append(*f, Field{Name: strings.TrimSpace(field)}) + } + } + + return nil +} + +type field *Field + +// UnmarshalJSON will unmarshal the corresponding JSON object for a Field, +// or a string (for backwards compatibility). +func (f *Field) UnmarshalJSON(b []byte) error { + if len(b) == 0 || f == nil { + return nil + } + + if b[0] == '{' { + return json.Unmarshal(b, field(f)) + } + + return json.Unmarshal(b, &f.Name) +} diff --git a/pkg/types/prompt_test.go b/pkg/types/prompt_test.go new file mode 100644 index 00000000..f2d911ef --- /dev/null +++ b/pkg/types/prompt_test.go @@ -0,0 +1,142 @@ +package types + +import ( + "reflect" + "testing" +) + +func TestFieldUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input []byte + expected Field + expectErr bool + }{ + { + name: "valid single Field object JSON", + input: []byte(`{"name":"field1","sensitive":true,"description":"A test field"}`), + expected: Field{Name: "field1", Sensitive: boolPtr(true), Description: "A test field"}, + expectErr: false, + }, + { + name: "valid Field name as string", + input: []byte(`"field1"`), + expected: Field{Name: "field1"}, + expectErr: false, + }, + { + name: "empty input", + input: []byte(``), + expected: Field{}, + expectErr: false, + }, + { + name: "invalid JSON object", + input: []byte(`{"name":"field1","sensitive":"not_boolean"}`), + expected: Field{Name: "field1", Sensitive: new(bool)}, + expectErr: true, + }, + { + name: "extra unknown fields in JSON object", + input: []byte(`{"name":"field1","unknown":"field","sensitive":false}`), + expected: Field{Name: "field1", Sensitive: boolPtr(false)}, + expectErr: false, + }, + { + name: "malformed JSON", + input: []byte(`{"name":"field1","sensitive":true`), + expected: Field{}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var field Field + err := field.UnmarshalJSON(tt.input) + if (err != nil) != tt.expectErr { + t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr) + } + if !reflect.DeepEqual(field, tt.expected) { + t.Errorf("UnmarshalJSON() = %v, expected %v", field, tt.expected) + } + }) + } +} + +func TestFieldsUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input []byte + expected Fields + expectErr bool + }{ + { + name: "empty input", + input: nil, + expected: nil, + expectErr: false, + }, + { + name: "nil pointer", + input: nil, + expected: nil, + expectErr: false, + }, + { + name: "valid JSON array", + input: []byte(`[{"Name":"field1"},{"Name":"field2"}]`), + expected: Fields{{Name: "field1"}, {Name: "field2"}}, + expectErr: false, + }, + { + name: "single string input", + input: []byte(`"field1,field2,field3"`), + expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}}, + expectErr: false, + }, + { + name: "trim spaces in single string input", + input: []byte(`"field1, field2 , field3 "`), + expected: Fields{{Name: "field1"}, {Name: "field2"}, {Name: "field3"}}, + expectErr: false, + }, + { + name: "invalid JSON array", + input: []byte(`[{"Name":"field1"},{"Name":field2}]`), + expected: nil, + expectErr: true, + }, + { + name: "invalid single string", + input: []byte(`1234`), + expected: nil, + expectErr: true, + }, + { + name: "empty array", + input: []byte(`[]`), + expected: Fields{}, + expectErr: false, + }, + { + name: "empty string", + input: []byte(`""`), + expected: nil, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var fields Fields + err := fields.UnmarshalJSON(tt.input) + if (err != nil) != tt.expectErr { + t.Errorf("UnmarshalJSON() error = %v, expectErr %v", err, tt.expectErr) + } + if !reflect.DeepEqual(fields, tt.expected) { + t.Errorf("UnmarshalJSON() = %v, expected %v", fields, tt.expected) + } + }) + } +}