From 97c37f1a3f84d135812450973866164e66210951 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Thu, 24 Apr 2025 14:59:29 -0400 Subject: [PATCH] chore: add credential checkParam field If the check param changes, then the credential will be re-prompted and not used nor refreshed. Signed-off-by: Donnie Adams --- pkg/credentials/credential.go | 18 +++++++----- pkg/runner/runner.go | 7 +++-- pkg/types/credential_test.go | 52 +++++++++++++++++++++++++++++------ pkg/types/tool.go | 38 +++++++++++++++++-------- 4 files changed, 85 insertions(+), 30 deletions(-) diff --git a/pkg/credentials/credential.go b/pkg/credentials/credential.go index e458cb9f..9d314a70 100644 --- a/pkg/credentials/credential.go +++ b/pkg/credentials/credential.go @@ -20,13 +20,16 @@ const ( ) type Credential struct { - Context string `json:"context"` - ToolName string `json:"toolName"` - Type CredentialType `json:"type"` - Env map[string]string `json:"env"` - Ephemeral bool `json:"ephemeral,omitempty"` - ExpiresAt *time.Time `json:"expiresAt"` - RefreshToken string `json:"refreshToken"` + Context string `json:"context"` + ToolName string `json:"toolName"` + Type CredentialType `json:"type"` + Env map[string]string `json:"env"` + // If the CheckParam that is stored is different from the param on the tool, + // then the credential will be re-authed as if it does not exist. + CheckParam string `json:"checkParam"` + Ephemeral bool `json:"ephemeral,omitempty"` + ExpiresAt *time.Time `json:"expiresAt"` + RefreshToken string `json:"refreshToken"` } func (c Credential) IsExpired() bool { @@ -82,6 +85,7 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error Context: ctx, ToolName: tool, Type: CredentialType(credType), + CheckParam: cred.CheckParam, Env: cred.Env, ExpiresAt: cred.ExpiresAt, RefreshToken: cred.RefreshToken, diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index aea91b34..6d4e7598 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -780,7 +780,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env var nearestExpiration *time.Time for _, ref := range credToolRefs { - toolName, credentialAlias, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input) + toolName, credentialAlias, checkParam, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input) if err != nil { return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err) } @@ -830,9 +830,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env // If the credential doesn't already exist in the store, run the credential tool in order to get the value, // and save it in the store. - if !exists || c.IsExpired() { + if !exists || c.IsExpired() || checkParam != c.CheckParam { // If the existing credential is expired, we need to provide it to the cred tool through the environment. - if exists && c.IsExpired() { + // If the check parameter is different, then we don't refresh. We should re-auth below. + if exists && c.IsExpired() && checkParam == c.CheckParam { refresh = true credJSON, err := json.Marshal(c) if err != nil { diff --git a/pkg/types/credential_test.go b/pkg/types/credential_test.go index b6f70ee3..530b23f8 100644 --- a/pkg/types/credential_test.go +++ b/pkg/types/credential_test.go @@ -9,13 +9,14 @@ import ( func TestParseCredentialArgs(t *testing.T) { tests := []struct { - name string - toolName string - input string - expectedName string - expectedAlias string - expectedArgs map[string]string - wantErr bool + name string + toolName string + input string + expectedName string + expectedAlias string + expectedCheckParam string + expectedArgs map[string]string + wantErr bool }{ { name: "empty", @@ -94,6 +95,40 @@ func TestParseCredentialArgs(t *testing.T) { "arg2": "value2", }, }, + { + name: "tool name with check parameter", + toolName: `myCredentialTool checked with myCheckParam`, + expectedName: "myCredentialTool", + expectedCheckParam: "myCheckParam", + }, + { + name: "tool name with alias and check parameter", + toolName: `myCredentialTool as myAlias checked with myCheckParam`, + expectedName: "myCredentialTool", + expectedAlias: "myAlias", + expectedCheckParam: "myCheckParam", + }, + { + name: "tool name with alias, check parameter, and args", + toolName: `myCredentialTool as myAlias checked with myCheckParam with value1 as arg1 and value2 as arg2`, + expectedName: "myCredentialTool", + expectedAlias: "myAlias", + expectedCheckParam: "myCheckParam", + expectedArgs: map[string]string{ + "arg1": "value1", + "arg2": "value2", + }, + }, + { + name: "check parameter without with", + toolName: `myCredentialTool checked myCheckParam`, + wantErr: true, + }, + { + name: "invalid check parameter", + toolName: `myCredentialTool checked with`, + wantErr: true, + }, { name: "tool name with alias but no 'as' (invalid)", toolName: "myCredentialTool myAlias", @@ -136,7 +171,7 @@ func TestParseCredentialArgs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - originalName, alias, args, err := ParseCredentialArgs(tt.toolName, tt.input) + originalName, alias, checkParam, args, err := ParseCredentialArgs(tt.toolName, tt.input) if tt.wantErr { require.Error(t, err, "expected an error but got none") return @@ -145,6 +180,7 @@ func TestParseCredentialArgs(t *testing.T) { require.NoError(t, err, "did not expect an error but got one") require.Equal(t, tt.expectedName, originalName, "unexpected original name") require.Equal(t, tt.expectedAlias, alias, "unexpected alias") + require.Equal(t, tt.expectedCheckParam, checkParam, "unexpected checkParam") require.Equal(t, len(tt.expectedArgs), len(args), "unexpected number of args") for k, v := range tt.expectedArgs { diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 54780278..3d48c6e1 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -272,9 +272,9 @@ func SplitArg(hasArg string) (prefix, arg string) { // - toolName: "toolName with ${var1} as arg1 and ${var2} as arg2" // - input: `{"var1": "value1", "var2": "value2"}` // result: toolName, "", map[string]any{"arg1": "value1", "arg2": "value2"}, nil -func ParseCredentialArgs(toolName string, input string) (string, string, map[string]any, error) { +func ParseCredentialArgs(toolName string, input string) (string, string, string, map[string]any, error) { if toolName == "" { - return "", "", nil, nil + return "", "", "", nil, nil } inputMap := make(map[string]any) @@ -287,12 +287,12 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str fields, err := shlex.Split(toolName) if err != nil { - return "", "", nil, err + return "", "", "", nil, err } // If it's just the tool name, return it if len(fields) == 1 { - return toolName, "", nil, nil + return toolName, "", "", nil, nil } // Next field is "as" if there is an alias, otherwise it should be "with" @@ -301,25 +301,39 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str fields = fields[1:] if fields[0] == "as" { if len(fields) < 2 { - return "", "", nil, fmt.Errorf("expected alias after 'as'") + return "", "", "", nil, fmt.Errorf("expected alias after 'as'") } alias = fields[1] fields = fields[2:] } if len(fields) == 0 { // Nothing left, so just return - return originalName, alias, nil, nil + return originalName, alias, "", nil, nil + } + + var checkParam string + if fields[0] == "checked" { + if len(fields) < 3 || fields[1] != "with" { + return "", "", "", nil, fmt.Errorf("expected 'checked with some_value' but got %v", fields) + } + + checkParam = fields[2] + fields = fields[3:] + } + + if len(fields) == 0 { // Nothing left, so just return + return originalName, alias, checkParam, nil, nil } // Next we should have "with" followed by the args if fields[0] != "with" { - return "", "", nil, fmt.Errorf("expected 'with' but got %s", fields[0]) + return "", "", "", nil, fmt.Errorf("expected 'with' but got %s", fields[0]) } fields = fields[1:] // If there are no args, return an error if len(fields) == 0 { - return "", "", nil, fmt.Errorf("expected args after 'with'") + return "", "", "", nil, fmt.Errorf("expected args after 'with'") } args := make(map[string]any) @@ -332,7 +346,7 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str prev = "value" case "value": if field != "as" { - return "", "", nil, fmt.Errorf("expected 'as' but got %s", field) + return "", "", "", nil, fmt.Errorf("expected 'as' but got %s", field) } prev = "as" case "as": @@ -340,14 +354,14 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str prev = "name" case "name": if field != "and" { - return "", "", nil, fmt.Errorf("expected 'and' but got %s", field) + return "", "", "", nil, fmt.Errorf("expected 'and' but got %s", field) } prev = "and" } } if prev == "and" { - return "", "", nil, fmt.Errorf("expected arg name after 'and'") + return "", "", "", nil, fmt.Errorf("expected arg name after 'and'") } // Check and see if any of the arg values are references to an input @@ -360,7 +374,7 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str } } - return originalName, alias, args, nil + return originalName, alias, checkParam, args, nil } func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ error) {