diff --git a/pkg/credentials/factory.go b/pkg/credentials/factory.go index ca6f1d18..42295fc8 100644 --- a/pkg/credentials/factory.go +++ b/pkg/credentials/factory.go @@ -2,6 +2,7 @@ package credentials import ( "context" + "strings" "github.com/docker/docker-credential-helpers/client" "github.com/gptscript-ai/gptscript/pkg/config" @@ -13,12 +14,32 @@ type ProgramLoaderRunner interface { Run(ctx context.Context, prg types.Program, input string) (output string, err error) } -func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRunner) (StoreFactory, error) { +func NewFactory(ctx context.Context, cfg *config.CLIConfig, overrides []string, plr ProgramLoaderRunner) (StoreFactory, error) { + creds, err := ParseCredentialOverrides(overrides) + if err != nil { + return StoreFactory{}, err + } + + overrideMap := make(map[string]map[string]map[string]string) + for k, v := range creds { + contextName, toolName, ok := strings.Cut(k, ctxSeparator) + if !ok { + continue + } + toolMap, ok := overrideMap[contextName] + if !ok { + toolMap = make(map[string]map[string]string) + } + toolMap[toolName] = v + overrideMap[contextName] = toolMap + } + toolName := translateToolName(cfg.CredentialsStore) if toolName == config.FileCredHelper { return StoreFactory{ - file: true, - cfg: cfg, + file: true, + cfg: cfg, + overrides: overrideMap, }, nil } @@ -28,10 +49,11 @@ func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRun } return StoreFactory{ - ctx: ctx, - prg: prg, - runner: plr, - cfg: cfg, + ctx: ctx, + prg: prg, + runner: plr, + cfg: cfg, + overrides: overrideMap, }, nil } @@ -41,6 +63,8 @@ type StoreFactory struct { file bool runner ProgramLoaderRunner cfg *config.CLIConfig + // That's a lot of maps: context -> toolName -> key -> value + overrides map[string]map[string]map[string]string } func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) { @@ -48,15 +72,23 @@ func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) { return nil, err } if s.file { - return Store{ - credCtxs: credCtxs, - cfg: s.cfg, + return withOverride{ + target: Store{ + credCtxs: credCtxs, + cfg: s.cfg, + }, + overrides: s.overrides, + credContext: credCtxs, }, nil } - return Store{ - credCtxs: credCtxs, - cfg: s.cfg, - program: s.program, + return withOverride{ + target: Store{ + credCtxs: credCtxs, + cfg: s.cfg, + program: s.program, + }, + overrides: s.overrides, + credContext: credCtxs, }, nil } diff --git a/pkg/credentials/overrides.go b/pkg/credentials/overrides.go new file mode 100644 index 00000000..0911cac5 --- /dev/null +++ b/pkg/credentials/overrides.go @@ -0,0 +1,149 @@ +package credentials + +import ( + "context" + "fmt" + "maps" + "os" + "strings" +) + +// ParseCredentialOverrides parses a string of credential overrides that the user provided as a command line arg. +// The format of credential overrides can be one of two things: +// cred1:ENV1,ENV2 (direct mapping of environment variables) +// cred1:ENV1=VALUE1,ENV2=VALUE2 (key-value pairs) +// +// This function turns it into a map[string]map[string]string like this: +// +// { +// "cred1": { +// "ENV1": "VALUE1", +// "ENV2": "VALUE2", +// } +// } +func ParseCredentialOverrides(overrides []string) (map[string]map[string]string, error) { + credentialOverrides := make(map[string]map[string]string) + + for _, o := range overrides { + credName, envs, found := strings.Cut(o, ":") + if !found { + return nil, fmt.Errorf("invalid credential override: %s", o) + } + envMap, ok := credentialOverrides[credName] + if !ok { + envMap = make(map[string]string) + } + for _, env := range strings.Split(envs, ",") { + for _, env := range strings.Split(env, "|") { + key, value, found := strings.Cut(env, "=") + if !found { + // User just passed an env var name as the key, so look up the value. + value = os.Getenv(key) + } + envMap[key] = value + } + } + credentialOverrides[credName] = envMap + } + + return credentialOverrides, nil +} + +type withOverride struct { + target CredentialStore + credContext []string + overrides map[string]map[string]map[string]string +} + +func (w withOverride) Get(ctx context.Context, toolName string) (*Credential, bool, error) { + for _, credCtx := range w.credContext { + overrides, ok := w.overrides[credCtx] + if !ok { + continue + } + override, ok := overrides[toolName] + if !ok { + continue + } + + return &Credential{ + Context: credCtx, + ToolName: toolName, + Type: CredentialTypeTool, + Env: maps.Clone(override), + }, true, nil + } + + return w.target.Get(ctx, toolName) +} + +func (w withOverride) Add(ctx context.Context, cred Credential) error { + for _, credCtx := range w.credContext { + if override, ok := w.overrides[credCtx]; ok { + if _, ok := override[cred.ToolName]; ok { + return fmt.Errorf("cannot add credential with context %q and tool %q because it is statically configure", cred.Context, cred.ToolName) + } + } + } + return w.target.Add(ctx, cred) +} + +func (w withOverride) Refresh(ctx context.Context, cred Credential) error { + if override, ok := w.overrides[cred.Context]; ok { + if _, ok := override[cred.ToolName]; ok { + return nil + } + } + return w.target.Refresh(ctx, cred) +} + +func (w withOverride) Remove(ctx context.Context, toolName string) error { + for _, credCtx := range w.credContext { + if override, ok := w.overrides[credCtx]; ok { + if _, ok := override[toolName]; ok { + return fmt.Errorf("cannot remove credential with context %q and tool %q because it is statically configure", credCtx, toolName) + } + } + } + return w.target.Remove(ctx, toolName) +} + +func (w withOverride) List(ctx context.Context) ([]Credential, error) { + creds, err := w.target.List(ctx) + if err != nil { + return nil, err + } + + added := make(map[string]map[string]bool) + for i, cred := range creds { + if override, ok := w.overrides[cred.Context]; ok { + if _, ok := override[cred.ToolName]; ok { + creds[i].Type = CredentialTypeTool + creds[i].Env = maps.Clone(override[cred.ToolName]) + } + } + tools, ok := added[cred.Context] + if !ok { + tools = make(map[string]bool) + } + tools[cred.ToolName] = true + added[cred.Context] = tools + } + + for _, credCtx := range w.credContext { + tools := w.overrides[credCtx] + for toolName := range tools { + if _, ok := added[credCtx][toolName]; ok { + continue + } + creds = append(creds, Credential{ + Context: credCtx, + ToolName: toolName, + Type: CredentialTypeTool, + Env: maps.Clone(tools[toolName]), + }) + } + } + + return creds, nil +} diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 6b7ceb04..4669e5ab 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -124,7 +124,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) { return nil, err } - storeFactory, err := credentials.NewFactory(ctx, cliCfg, simplerRunner) + storeFactory, err := credentials.NewFactory(ctx, cliCfg, opts.Runner.CredentialOverrides, simplerRunner) if err != nil { return nil, err } diff --git a/pkg/runner/credentials.go b/pkg/runner/credentials.go deleted file mode 100644 index d2fbb00e..00000000 --- a/pkg/runner/credentials.go +++ /dev/null @@ -1,43 +0,0 @@ -package runner - -import ( - "fmt" - "os" - "strings" -) - -// parseCredentialOverrides parses a string of credential overrides that the user provided as a command line arg. -// The format of credential overrides can be one of two things: -// cred1:ENV1,ENV2 (direct mapping of environment variables) -// cred1:ENV1=VALUE1,ENV2=VALUE2 (key-value pairs) -// -// This function turns it into a map[string]map[string]string like this: -// -// { -// "cred1": { -// "ENV1": "VALUE1", -// "ENV2": "VALUE2", -// } -// } -func parseCredentialOverrides(overrides []string) (map[string]map[string]string, error) { - credentialOverrides := make(map[string]map[string]string) - - for _, o := range overrides { - credName, envs, found := strings.Cut(o, ":") - if !found { - return nil, fmt.Errorf("invalid credential override: %s", o) - } - envMap := make(map[string]string) - for _, env := range strings.Split(envs, ",") { - key, value, found := strings.Cut(env, "=") - if !found { - // User just passed an env var name as the key, so look up the value. - value = os.Getenv(key) - } - envMap[key] = value - } - credentialOverrides[credName] = envMap - } - - return credentialOverrides, nil -} diff --git a/pkg/runner/credentials_test.go b/pkg/runner/credentials_test.go index c568d6be..74fa9353 100644 --- a/pkg/runner/credentials_test.go +++ b/pkg/runner/credentials_test.go @@ -4,6 +4,7 @@ import ( "os" "testing" + "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/stretchr/testify/require" ) @@ -119,7 +120,7 @@ func TestParseCredentialOverrides(t *testing.T) { _ = os.Setenv(k, v) } - out, err := parseCredentialOverrides(tc.in) + out, err := credentials.ParseCredentialOverrides(tc.in) if tc.expectErr { require.Error(t, err) return diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 09b242f7..e2699cf6 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -754,7 +754,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env err error ) if r.credOverrides != nil { - credOverrides, err = parseCredentialOverrides(r.credOverrides) + credOverrides, err = credentials.ParseCredentialOverrides(r.credOverrides) if err != nil { return nil, fmt.Errorf("failed to parse credential overrides: %w", err) }