Skip to content

chore: make credential overrides cred context aware #936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions pkg/credentials/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"context"
"strings"

"github.com/docker/docker-credential-helpers/client"
"github.com/gptscript-ai/gptscript/pkg/config"
Expand All @@ -13,12 +14,32 @@ type ProgramLoaderRunner interface {
Run(ctx context.Context, prg types.Program, input string) (output string, err error)
}

func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRunner) (StoreFactory, error) {
func NewFactory(ctx context.Context, cfg *config.CLIConfig, overrides []string, plr ProgramLoaderRunner) (StoreFactory, error) {
creds, err := ParseCredentialOverrides(overrides)
if err != nil {
return StoreFactory{}, err
}

overrideMap := make(map[string]map[string]map[string]string)
for k, v := range creds {
contextName, toolName, ok := strings.Cut(k, ctxSeparator)
if !ok {
continue
}
toolMap, ok := overrideMap[contextName]
if !ok {
toolMap = make(map[string]map[string]string)
}
toolMap[toolName] = v
overrideMap[contextName] = toolMap
}

toolName := translateToolName(cfg.CredentialsStore)
if toolName == config.FileCredHelper {
return StoreFactory{
file: true,
cfg: cfg,
file: true,
cfg: cfg,
overrides: overrideMap,
}, nil
}

Expand All @@ -28,10 +49,11 @@ func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRun
}

return StoreFactory{
ctx: ctx,
prg: prg,
runner: plr,
cfg: cfg,
ctx: ctx,
prg: prg,
runner: plr,
cfg: cfg,
overrides: overrideMap,
}, nil
}

Expand All @@ -41,22 +63,32 @@ type StoreFactory struct {
file bool
runner ProgramLoaderRunner
cfg *config.CLIConfig
// That's a lot of maps: context -> toolName -> key -> value
overrides map[string]map[string]map[string]string
}

func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) {
if err := validateCredentialCtx(credCtxs); err != nil {
return nil, err
}
if s.file {
return Store{
credCtxs: credCtxs,
cfg: s.cfg,
return withOverride{
target: Store{
credCtxs: credCtxs,
cfg: s.cfg,
},
overrides: s.overrides,
credContext: credCtxs,
}, nil
}
return Store{
credCtxs: credCtxs,
cfg: s.cfg,
program: s.program,
return withOverride{
target: Store{
credCtxs: credCtxs,
cfg: s.cfg,
program: s.program,
},
overrides: s.overrides,
credContext: credCtxs,
}, nil
}

Expand Down
149 changes: 149 additions & 0 deletions pkg/credentials/overrides.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package credentials

import (
"context"
"fmt"
"maps"
"os"
"strings"
)

// ParseCredentialOverrides parses a string of credential overrides that the user provided as a command line arg.
// The format of credential overrides can be one of two things:
// cred1:ENV1,ENV2 (direct mapping of environment variables)
// cred1:ENV1=VALUE1,ENV2=VALUE2 (key-value pairs)
//
// This function turns it into a map[string]map[string]string like this:
//
// {
// "cred1": {
// "ENV1": "VALUE1",
// "ENV2": "VALUE2",
// }
// }
func ParseCredentialOverrides(overrides []string) (map[string]map[string]string, error) {
credentialOverrides := make(map[string]map[string]string)

for _, o := range overrides {
credName, envs, found := strings.Cut(o, ":")
if !found {
return nil, fmt.Errorf("invalid credential override: %s", o)
}
envMap, ok := credentialOverrides[credName]
if !ok {
envMap = make(map[string]string)
}
for _, env := range strings.Split(envs, ",") {
for _, env := range strings.Split(env, "|") {
key, value, found := strings.Cut(env, "=")
if !found {
// User just passed an env var name as the key, so look up the value.
value = os.Getenv(key)
}
envMap[key] = value
}
}
credentialOverrides[credName] = envMap
}

return credentialOverrides, nil
}

type withOverride struct {
target CredentialStore
credContext []string
overrides map[string]map[string]map[string]string
}

func (w withOverride) Get(ctx context.Context, toolName string) (*Credential, bool, error) {
for _, credCtx := range w.credContext {
overrides, ok := w.overrides[credCtx]
if !ok {
continue
}
override, ok := overrides[toolName]
if !ok {
continue
}

return &Credential{
Context: credCtx,
ToolName: toolName,
Type: CredentialTypeTool,
Env: maps.Clone(override),
}, true, nil
}

return w.target.Get(ctx, toolName)
}

func (w withOverride) Add(ctx context.Context, cred Credential) error {
for _, credCtx := range w.credContext {
if override, ok := w.overrides[credCtx]; ok {
if _, ok := override[cred.ToolName]; ok {
return fmt.Errorf("cannot add credential with context %q and tool %q because it is statically configure", cred.Context, cred.ToolName)
}
}
}
return w.target.Add(ctx, cred)
}

func (w withOverride) Refresh(ctx context.Context, cred Credential) error {
if override, ok := w.overrides[cred.Context]; ok {
if _, ok := override[cred.ToolName]; ok {
return nil
}
}
return w.target.Refresh(ctx, cred)
}

func (w withOverride) Remove(ctx context.Context, toolName string) error {
for _, credCtx := range w.credContext {
if override, ok := w.overrides[credCtx]; ok {
if _, ok := override[toolName]; ok {
return fmt.Errorf("cannot remove credential with context %q and tool %q because it is statically configure", credCtx, toolName)
}
}
}
return w.target.Remove(ctx, toolName)
}

func (w withOverride) List(ctx context.Context) ([]Credential, error) {
creds, err := w.target.List(ctx)
if err != nil {
return nil, err
}

added := make(map[string]map[string]bool)
for i, cred := range creds {
if override, ok := w.overrides[cred.Context]; ok {
if _, ok := override[cred.ToolName]; ok {
creds[i].Type = CredentialTypeTool
creds[i].Env = maps.Clone(override[cred.ToolName])
}
}
tools, ok := added[cred.Context]
if !ok {
tools = make(map[string]bool)
}
tools[cred.ToolName] = true
added[cred.Context] = tools
}

for _, credCtx := range w.credContext {
tools := w.overrides[credCtx]
for toolName := range tools {
if _, ok := added[credCtx][toolName]; ok {
continue
}
creds = append(creds, Credential{
Context: credCtx,
ToolName: toolName,
Type: CredentialTypeTool,
Env: maps.Clone(tools[toolName]),
})
}
}

return creds, nil
}
2 changes: 1 addition & 1 deletion pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
return nil, err
}

storeFactory, err := credentials.NewFactory(ctx, cliCfg, simplerRunner)
storeFactory, err := credentials.NewFactory(ctx, cliCfg, opts.Runner.CredentialOverrides, simplerRunner)
if err != nil {
return nil, err
}
Expand Down
43 changes: 0 additions & 43 deletions pkg/runner/credentials.go

This file was deleted.

3 changes: 2 additions & 1 deletion pkg/runner/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"os"
"testing"

"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -119,7 +120,7 @@ func TestParseCredentialOverrides(t *testing.T) {
_ = os.Setenv(k, v)
}

out, err := parseCredentialOverrides(tc.in)
out, err := credentials.ParseCredentialOverrides(tc.in)
if tc.expectErr {
require.Error(t, err)
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
err error
)
if r.credOverrides != nil {
credOverrides, err = parseCredentialOverrides(r.credOverrides)
credOverrides, err = credentials.ParseCredentialOverrides(r.credOverrides)
if err != nil {
return nil, fmt.Errorf("failed to parse credential overrides: %w", err)
}
Expand Down
Loading