diff --git a/Makefile b/Makefile index e9f079f6..284c94c9 100644 --- a/Makefile +++ b/Makefile @@ -53,7 +53,7 @@ init-docs: # Ensure docs build without errors. Makes sure generated docs are in-sync with CLI. validate-docs: gen-docs - docker run --rm --workdir=/docs -v $${PWD}/docs:/docs node:18-buster yarn build + docker run --rm --workdir=/docs -v $${PWD}/docs:/docs node:18-buster npm run build if [ -n "$$(git status --porcelain --untracked-files=no)" ]; then \ git status --porcelain --untracked-files=no; \ echo "Encountered dirty repo!"; \ diff --git a/pkg/credentials/store.go b/pkg/credentials/store.go index def6ff89..6e5d24ca 100644 --- a/pkg/credentials/store.go +++ b/pkg/credentials/store.go @@ -269,8 +269,9 @@ func (s *Store) recreateCredential(store credentials.Store, serverAddress string func (s *Store) getStore() (credentials.Store, error) { if s.program != nil { return &toolCredentialStore{ - file: credentials.NewFileStore(s.cfg), - program: s.program, + file: credentials.NewFileStore(s.cfg), + program: s.program, + contexts: s.credCtxs, }, nil } return credentials.NewFileStore(s.cfg), nil diff --git a/pkg/credentials/toolstore.go b/pkg/credentials/toolstore.go index 536e2aa1..66d4d38d 100644 --- a/pkg/credentials/toolstore.go +++ b/pkg/credentials/toolstore.go @@ -1,7 +1,10 @@ package credentials import ( + "bytes" + "encoding/json" "errors" + "fmt" "net/url" "regexp" "strings" @@ -13,8 +16,9 @@ import ( ) type toolCredentialStore struct { - file credentials.Store - program client.ProgramFunc + file credentials.Store + program client.ProgramFunc + contexts []string } func (h *toolCredentialStore) Erase(serverAddress string) error { @@ -42,8 +46,21 @@ func (h *toolCredentialStore) Get(serverAddress string) (types.AuthConfig, error }, nil } +// GetAll will list all credentials in the credential store. +// It MAY (but is not required to) filter the credentials based on the contexts provided. +// This is only supported by some credential stores, while others will ignore it and return all credentials. +// The caller of this function is still required to filter the output to only include the contexts requested. func (h *toolCredentialStore) GetAll() (map[string]types.AuthConfig, error) { - serverAddresses, err := client.List(h.program) + var ( + serverAddresses map[string]string + err error + ) + if len(h.contexts) == 0 { + serverAddresses, err = client.List(h.program) + } else { + serverAddresses, err = listWithContexts(h.program, h.contexts) + } + if err != nil { return nil, err } @@ -94,3 +111,44 @@ func (h *toolCredentialStore) Store(authConfig types.AuthConfig) error { Secret: authConfig.Password, }) } + +// listWithContexts is almost an exact copy of the List function in Docker's libraries, +// the only difference being that we pass the context through as input to the program. +// This will allow some credential stores, like Postgres, to do an optimized list. +func listWithContexts(program client.ProgramFunc, contexts []string) (map[string]string, error) { + cmd := program(credentials2.ActionList) + + contextsJSON, err := json.Marshal(contexts) + if err != nil { + return nil, err + } + + cmd.Input(bytes.NewReader(contextsJSON)) + out, err := cmd.Output() + if err != nil { + t := strings.TrimSpace(string(out)) + + if isValidErr := isValidCredsMessage(t); isValidErr != nil { + err = isValidErr + } + + return nil, fmt.Errorf("error listing credentials - err: %v, out: `%s`", err, t) + } + + var resp map[string]string + if err = json.NewDecoder(bytes.NewReader(out)).Decode(&resp); err != nil { + return nil, err + } + + return resp, nil +} + +func isValidCredsMessage(msg string) error { + if credentials2.IsCredentialsMissingServerURLMessage(msg) { + return credentials2.NewErrCredentialsMissingServerURL() + } + if credentials2.IsCredentialsMissingUsernameMessage(msg) { + return credentials2.NewErrCredentialsMissingUsername() + } + return nil +} diff --git a/pkg/credentials/toolstore_test.go b/pkg/credentials/toolstore_test.go new file mode 100644 index 00000000..fdbae25c --- /dev/null +++ b/pkg/credentials/toolstore_test.go @@ -0,0 +1,130 @@ +package credentials + +import ( + "encoding/json" + "fmt" + "io" + "testing" + + "github.com/docker/cli/cli/config/types" + "github.com/docker/docker-credential-helpers/client" + "github.com/docker/docker-credential-helpers/credentials" +) + +type mockProgram struct { + // mode is either "db" or "normal" + // db mode will honor contexts, normal mode will not + mode string + action string + contexts []string +} + +func (m *mockProgram) Input(in io.Reader) { + switch m.action { + case credentials.ActionList: + var contexts []string + if err := json.NewDecoder(in).Decode(&contexts); err == nil && len(contexts) > 0 { + m.contexts = contexts + } + } + // TODO: add other cases here as needed +} + +func (m *mockProgram) Output() ([]byte, error) { + switch m.action { + case credentials.ActionList: + switch m.mode { + case "db": + // Return only credentials that are in the list of contexts. + creds := make(map[string]string) + for _, context := range m.contexts { + creds[fmt.Sprintf("https://example///%s", context)] = "username" + } + return json.Marshal(creds) + case "normal": + // Return credentials in the list of contexts, plus some made up extras. + creds := make(map[string]string) + for _, context := range m.contexts { + creds[fmt.Sprintf("https://example///%s", context)] = "username" + } + creds[fmt.Sprintf("https://example///%s", "otherContext1")] = "username" + creds[fmt.Sprintf("https://example///%s", "otherContext2")] = "username" + return json.Marshal(creds) + } + } + return nil, nil +} + +func newMockProgram(t *testing.T, mode string) client.ProgramFunc { + t.Helper() + return func(args ...string) client.Program { + p := &mockProgram{ + mode: mode, + } + if len(args) > 0 { + p.action = args[0] + } + return p + } +} + +func TestGetAll(t *testing.T) { + dbProgram := newMockProgram(t, "db") + normalProgram := newMockProgram(t, "normal") + + tests := []struct { + name string + program client.ProgramFunc + wantErr bool + contexts []string + expected map[string]types.AuthConfig + }{ + {name: "db", program: dbProgram, wantErr: false, contexts: []string{"credctx"}, expected: map[string]types.AuthConfig{ + "https://example///credctx": { + Username: "username", + ServerAddress: "https://example///credctx", + }, + }}, + {name: "normal", program: normalProgram, wantErr: false, contexts: []string{"credctx"}, expected: map[string]types.AuthConfig{ + "https://example///credctx": { + Username: "username", + ServerAddress: "https://example///credctx", + }, + "https://example///otherContext1": { + Username: "username", + ServerAddress: "https://example///otherContext1", + }, + "https://example///otherContext2": { + Username: "username", + ServerAddress: "https://example///otherContext2", + }, + }}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + store := &toolCredentialStore{ + program: test.program, + contexts: test.contexts, + } + got, err := store.GetAll() + if (err != nil) != test.wantErr { + t.Errorf("GetAll() error = %v, wantErr %v", err, test.wantErr) + } + if len(got) != len(test.expected) { + t.Errorf("GetAll() got %d credentials, want %d", len(got), len(test.expected)) + } + for name, cred := range got { + if _, ok := test.expected[name]; !ok { + t.Errorf("GetAll() got unexpected credential: %s", name) + } + if got[name].Username != test.expected[name].Username { + t.Errorf("GetAll() got unexpected username for %s", cred.ServerAddress) + } + if got[name].Username != test.expected[name].Username { + t.Errorf("GetAll() got unexpected username for %s", name) + } + } + }) + } +}