Skip to content

Commit d9190b5

Browse files
committed
enhance: ask user for OpenAI key and store it in the cred store
Signed-off-by: Grant Linville <[email protected]>
1 parent e65513e commit d9190b5

File tree

9 files changed

+185
-62
lines changed

9 files changed

+185
-62
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ require (
1515
github.com/getkin/kin-openapi v0.123.0
1616
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1717
github.com/google/uuid v1.6.0
18-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9
18+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
1919
github.com/hexops/autogold/v2 v2.2.1
2020
github.com/hexops/valast v1.4.4
2121
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+
125125
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
126126
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
127127
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
128-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9 h1:s6nL/aokB1sJTqVXEjN0zFI5CJa66ubw9g68VTMzEw0=
129-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
128+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
129+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
130130
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
131131
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
132132
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=

pkg/builtin/builtin.go

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ import (
1717
"strings"
1818
"time"
1919

20-
"github.com/AlecAivazis/survey/v2"
2120
"github.com/BurntSushi/locker"
2221
"github.com/google/shlex"
2322
"github.com/gptscript-ai/gptscript/pkg/engine"
23+
"github.com/gptscript-ai/gptscript/pkg/prompt"
2424
"github.com/gptscript-ai/gptscript/pkg/types"
2525
"github.com/jaytaylor/html2text"
2626
)
@@ -215,7 +215,7 @@ var tools = map[string]types.Tool{
215215
"sensitive", "(true or false) Whether the input should be hidden",
216216
),
217217
},
218-
BuiltinFunc: SysPrompt,
218+
BuiltinFunc: prompt.SysPrompt,
219219
},
220220
},
221221
"sys.chat.history": {
@@ -774,42 +774,6 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
774774
return params.Location, nil
775775
}
776776

777-
func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error) {
778-
var params struct {
779-
Message string `json:"message,omitempty"`
780-
Fields string `json:"fields,omitempty"`
781-
Sensitive string `json:"sensitive,omitempty"`
782-
}
783-
if err := json.Unmarshal([]byte(input), &params); err != nil {
784-
return "", err
785-
}
786-
787-
if params.Message != "" {
788-
_, _ = fmt.Fprintln(os.Stderr, params.Message)
789-
}
790-
791-
results := map[string]string{}
792-
for _, f := range strings.Split(params.Fields, ",") {
793-
var value string
794-
if params.Sensitive == "true" {
795-
err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
796-
} else {
797-
err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
798-
}
799-
if err != nil {
800-
return "", err
801-
}
802-
results[f] = value
803-
}
804-
805-
resultsStr, err := json.Marshal(results)
806-
if err != nil {
807-
return "", err
808-
}
809-
810-
return string(resultsStr), nil
811-
}
812-
813777
func SysTimeNow(ctx context.Context, env []string, input string) (string, error) {
814778
return time.Now().Format(time.RFC3339), nil
815779
}

pkg/gptscript/gptscript.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/gptscript-ai/gptscript/pkg/builtin"
1010
"github.com/gptscript-ai/gptscript/pkg/cache"
11+
"github.com/gptscript-ai/gptscript/pkg/config"
1112
"github.com/gptscript-ai/gptscript/pkg/engine"
1213
"github.com/gptscript-ai/gptscript/pkg/hash"
1314
"github.com/gptscript-ai/gptscript/pkg/llm"
@@ -65,15 +66,20 @@ func New(opts *Options) (*GPTScript, error) {
6566
return nil, err
6667
}
6768

68-
oAIClient, err := openai.NewClient(opts.OpenAI, openai.Options{
69+
cliCfg, err := config.ReadCLIConfig(opts.OpenAI.ConfigFile)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
oaiClient, err := openai.NewClient(cliCfg, opts.CredentialContext, opts.OpenAI, openai.Options{
6975
Cache: cacheClient,
7076
SetSeed: true,
7177
})
7278
if err != nil {
7379
return nil, err
7480
}
7581

76-
if err := registry.AddClient(oAIClient); err != nil {
82+
if err := registry.AddClient(oaiClient); err != nil {
7783
return nil, err
7884
}
7985

@@ -90,7 +96,7 @@ func New(opts *Options) (*GPTScript, error) {
9096
return nil, err
9197
}
9298

93-
remoteClient := remote.New(runner, opts.Env, cacheClient)
99+
remoteClient := remote.New(runner, opts.Env, cacheClient, cliCfg, opts.CredentialContext)
94100

95101
if err := registry.AddClient(remoteClient); err != nil {
96102
return nil, err

pkg/llm/registry.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"sort"
88

9+
"github.com/gptscript-ai/gptscript/pkg/openai"
910
"github.com/gptscript-ai/gptscript/pkg/types"
1011
)
1112

@@ -32,7 +33,21 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result
3233
for _, v := range r.clients {
3334
models, err := v.ListModels(ctx, providers...)
3435
if err != nil {
35-
return nil, err
36+
// If we got back an InvalidAuthError, then we know it came from the OpenAI client, and we can
37+
// try to get the credential from the cred store.
38+
if errors.Is(err, openai.InvalidAuthError{}) {
39+
if err := v.(*openai.Client).RetrieveAPIKey(); err != nil {
40+
return nil, err
41+
}
42+
43+
// Now that the API key has been retrieved, try to list models again.
44+
models, err = v.ListModels(ctx, providers...)
45+
if err != nil {
46+
return nil, err
47+
}
48+
} else {
49+
return nil, err
50+
}
3651
}
3752
result = append(result, models...)
3853
}

pkg/openai/client.go

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@ import (
1212

1313
openai "github.com/gptscript-ai/chat-completion-client"
1414
"github.com/gptscript-ai/gptscript/pkg/cache"
15+
"github.com/gptscript-ai/gptscript/pkg/config"
1516
"github.com/gptscript-ai/gptscript/pkg/counter"
17+
"github.com/gptscript-ai/gptscript/pkg/credentials"
1618
"github.com/gptscript-ai/gptscript/pkg/hash"
19+
"github.com/gptscript-ai/gptscript/pkg/prompt"
1720
"github.com/gptscript-ai/gptscript/pkg/system"
1821
"github.com/gptscript-ai/gptscript/pkg/types"
22+
"github.com/tidwall/gjson"
1923
)
2024

2125
const (
22-
DefaultModel = openai.GPT4o
26+
DefaultModel = openai.GPT4o
27+
BuiltinCredName = "sys.openai"
2328
)
2429

2530
var (
@@ -28,13 +33,21 @@ var (
2833
azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT")
2934
)
3035

36+
type InvalidAuthError struct{}
37+
38+
func (InvalidAuthError) Error() string {
39+
return "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable"
40+
}
41+
3142
type Client struct {
3243
defaultModel string
3344
c *openai.Client
3445
cache *cache.Client
3546
invalidAuth bool
3647
cacheKeyBase string
3748
setSeed bool
49+
cliCfg *config.CLIConfig
50+
credCtx string
3851
}
3952

4053
type Options struct {
@@ -93,12 +106,28 @@ func GetAzureMapperFunction(defaultModel, azureModel string) func(string) string
93106
}
94107
}
95108

96-
func NewClient(opts ...Options) (*Client, error) {
109+
func NewClient(cliCfg *config.CLIConfig, credCtx string, opts ...Options) (*Client, error) {
97110
opt, err := complete(opts...)
98111
if err != nil {
99112
return nil, err
100113
}
101114

115+
// If the API key is not set, try to get it from the cred store
116+
if opt.APIKey == "" && opt.BaseURL == "" {
117+
store, err := credentials.NewStore(cliCfg, credCtx)
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
cred, exists, err := store.Get(BuiltinCredName)
123+
if err != nil {
124+
return nil, err
125+
}
126+
if exists {
127+
opt.APIKey = cred.Env["OPENAI_API_KEY"]
128+
}
129+
}
130+
102131
cfg := openai.DefaultConfig(opt.APIKey)
103132
if strings.Contains(string(opt.APIType), "AZURE") {
104133
cfg = openai.DefaultAzureConfig(key, url)
@@ -122,12 +151,14 @@ func NewClient(opts ...Options) (*Client, error) {
122151
cacheKeyBase: cacheKeyBase,
123152
invalidAuth: opt.APIKey == "" && opt.BaseURL == "",
124153
setSeed: opt.SetSeed,
154+
cliCfg: cliCfg,
155+
credCtx: credCtx,
125156
}, nil
126157
}
127158

128159
func (c *Client) ValidAuth() error {
129160
if c.invalidAuth {
130-
return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
161+
return InvalidAuthError{}
131162
}
132163
return nil
133164
}
@@ -276,7 +307,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
276307

277308
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
278309
if err := c.ValidAuth(); err != nil {
279-
return nil, err
310+
if err := c.RetrieveAPIKey(); err != nil {
311+
return nil, err
312+
}
280313
}
281314

282315
if messageRequest.Model == "" {
@@ -525,6 +558,44 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
525558
}
526559
}
527560

561+
func (c *Client) RetrieveAPIKey() error {
562+
store, err := credentials.NewStore(c.cliCfg, c.credCtx)
563+
if err != nil {
564+
return err
565+
}
566+
567+
cred, exists, err := store.Get(BuiltinCredName)
568+
if err != nil {
569+
return err
570+
}
571+
572+
var k string
573+
if exists {
574+
k = cred.Env[key]
575+
} else {
576+
// SysPrompt doesn't use its first two arguments, so we can safely pass nil to them
577+
result, err := prompt.SysPrompt(nil, nil, `{"message":"Please provide your OpenAI API key:","fields":"key","sensitive":"true"}`)
578+
if err != nil {
579+
return err
580+
}
581+
582+
k = gjson.Get(result, "key").String()
583+
if err := store.Add(credentials.Credential{
584+
ToolName: BuiltinCredName,
585+
Env: map[string]string{
586+
"OPENAI_API_KEY": k,
587+
},
588+
}); err != nil {
589+
return err
590+
}
591+
log.Infof("Saved API key as credential %s", BuiltinCredName)
592+
}
593+
594+
c.c.SetAPIKey(k)
595+
c.invalidAuth = false
596+
return nil
597+
}
598+
528599
func ptr[T any](v T) *T {
529600
return &v
530601
}

pkg/openai/log.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package openai
2+
3+
import "github.com/gptscript-ai/gptscript/pkg/mvl"
4+
5+
var log = mvl.Package()

pkg/prompt/prompt.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package prompt
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"os"
8+
"strings"
9+
10+
"github.com/AlecAivazis/survey/v2"
11+
)
12+
13+
func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error) {
14+
var params struct {
15+
Message string `json:"message,omitempty"`
16+
Fields string `json:"fields,omitempty"`
17+
Sensitive string `json:"sensitive,omitempty"`
18+
}
19+
if err := json.Unmarshal([]byte(input), &params); err != nil {
20+
return "", err
21+
}
22+
23+
if params.Message != "" {
24+
_, _ = fmt.Fprintln(os.Stderr, params.Message)
25+
}
26+
27+
results := map[string]string{}
28+
for _, f := range strings.Split(params.Fields, ",") {
29+
var value string
30+
if params.Sensitive == "true" {
31+
err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
32+
} else {
33+
err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
34+
}
35+
if err != nil {
36+
return "", err
37+
}
38+
results[f] = value
39+
}
40+
41+
resultsStr, err := json.Marshal(results)
42+
if err != nil {
43+
return "", err
44+
}
45+
46+
return string(resultsStr), nil
47+
}

0 commit comments

Comments
 (0)