diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 311c743a..14b41183 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -11,7 +11,6 @@ import ( "os" "os/exec" "path" - "path/filepath" "runtime" "sort" "strings" @@ -121,7 +120,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate var extraEnv = []string{ strings.TrimSpace("GPTSCRIPT_CONTEXT=" + strings.Join(instructions, "\n")), } - cmd, stop, err := e.newCommand(ctx.Ctx, extraEnv, tool, input) + cmd, stop, err := e.newCommand(ctx.Ctx, extraEnv, tool, input, true) if err != nil { if toolCategory == NoCategory { return fmt.Sprintf("ERROR: got (%v) while parsing command", err), nil @@ -244,7 +243,11 @@ func appendInputAsEnv(env []string, input string) []string { return env } -func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.Tool, input string) (*exec.Cmd, func(), error) { +func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.Tool, input string, useShell bool) (*exec.Cmd, func(), error) { + if runtime.GOOS == "windows" { + useShell = false + } + envvars := append(e.Env[:], extraEnv...) envvars = appendInputAsEnv(envvars, input) if log.IsDebug() { @@ -254,9 +257,17 @@ func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.T interpreter, rest, _ := strings.Cut(tool.Instructions, "\n") interpreter = strings.TrimSpace(interpreter)[2:] - args, err := shlex.Split(interpreter) - if err != nil { - return nil, nil, err + var ( + args []string + err error + ) + if useShell { + args = strings.Fields(interpreter) + } else { + args, err = shlex.Split(interpreter) + if err != nil { + return nil, nil, err + } } envvars, err = e.getRuntimeEnv(ctx, tool, args, envvars) @@ -265,17 +276,6 @@ func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.T } envvars, envMap := envAsMapAndDeDup(envvars) - for i, arg := range args { - args[i] = os.Expand(arg, func(s string) string { - return envMap[s] - }) - } - - // After we determined the interpreter we again interpret the args by env vars - args, err = replaceVariablesForInterpreter(interpreter, envMap) - if err != nil { - return nil, nil, err - } if runtime.GOOS == "windows" && (args[0] == "/bin/bash" || args[0] == "/bin/sh") { args[0] = path.Base(args[0]) @@ -286,8 +286,7 @@ func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.T } var ( - cmdArgs = args[1:] - stop = func() {} + stop = func() {} ) if strings.TrimSpace(rest) != "" { @@ -305,105 +304,33 @@ func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.T stop() return nil, nil, err } - cmdArgs = append(cmdArgs, f.Name()) - } - - // This is a workaround for Windows, where the command interpreter is constructed with unix style paths - // It converts unix style paths to windows style paths - if runtime.GOOS == "windows" { - parts := strings.Split(args[0], "/") - if parts[len(parts)-1] == "gptscript-go-tool" { - parts[len(parts)-1] = "gptscript-go-tool.exe" - } - - args[0] = filepath.Join(parts...) + args = append(args, f.Name()) } - cmd := exec.CommandContext(ctx, env.Lookup(envvars, args[0]), cmdArgs...) - cmd.Env = compressEnv(envvars) - return cmd, stop, nil -} - -func replaceVariablesForInterpreter(interpreter string, envMap map[string]string) ([]string, error) { - var parts []string - for i, part := range splitByQuotes(interpreter) { - if i%2 == 0 { - part = os.Expand(part, func(s string) string { + // Expand and/or normalize env references + for i, arg := range args { + args[i] = os.Expand(arg, func(s string) string { + if strings.HasPrefix(s, "!") { + return envMap[s[1:]] + } + if !useShell { return envMap[s] - }) - // We protect newly resolved env vars from getting replaced when we do the second Expand - // after shlex. Yeah, crazy. I'm guessing this isn't secure, but just trying to avoid a foot gun. - part = os.Expand(part, func(s string) string { - return "${__" + s + "}" - }) - } - parts = append(parts, part) - } - - parts, err := shlex.Split(strings.Join(parts, "")) - if err != nil { - return nil, err - } - - for i, part := range parts { - parts[i] = os.Expand(part, func(s string) string { - if strings.HasPrefix(s, "__") { - return "${" + s[2:] + "}" } - return envMap[s] + return "${" + s + "}" }) } - return parts, nil -} - -// splitByQuotes will split a string by parsing matching double quotes (with \ as the escape character). -// The return value conforms to the following properties -// 1. s == strings.Join(result, "") -// 2. Even indexes are strings that were not in quotes. -// 3. Odd indexes are strings that were quoted. -// -// Example: s = `In a "quoted string" quotes can be escaped with \"` -// -// result = [`In a `, `"quoted string"`, ` quotes can be escaped with \"`] -func splitByQuotes(s string) (result []string) { - var ( - buf strings.Builder - inEscape, inQuote bool - ) - - for _, c := range s { - if inEscape { - buf.WriteRune(c) - inEscape = false - continue - } - - switch c { - case '"': - if inQuote { - buf.WriteRune(c) - } - result = append(result, buf.String()) - buf.Reset() - if !inQuote { - buf.WriteRune(c) - } - inQuote = !inQuote - case '\\': - inEscape = true - buf.WriteRune(c) - default: - buf.WriteRune(c) - } + if runtime.GOOS == "windows" { + args[0] = strings.ReplaceAll(args[0], "/", "\\") } - if buf.Len() > 0 { - if inQuote { - result = append(result, "") - } - result = append(result, buf.String()) + if useShell { + args = append([]string{"/bin/sh", "-c"}, strings.Join(args, " ")) + } else { + args[0] = env.Lookup(envvars, args[0]) } - return + cmd := exec.CommandContext(ctx, args[0], args[1:]...) + cmd.Env = compressEnv(envvars) + return cmd, stop, nil } diff --git a/pkg/engine/cmd_test.go b/pkg/engine/cmd_test.go deleted file mode 100644 index 15f72036..00000000 --- a/pkg/engine/cmd_test.go +++ /dev/null @@ -1,135 +0,0 @@ -// File: cmd_test.go -package engine - -import "testing" - -func TestSplitByQuotes(t *testing.T) { - tests := []struct { - name string - input string - expected []string - }{ - { - name: "NoQuotes", - input: "Hello World", - expected: []string{"Hello World"}, - }, - { - name: "ValidQuote", - input: `"Hello" "World"`, - expected: []string{``, `"Hello"`, ` `, `"World"`}, - }, - { - name: "ValidQuoteWithEscape", - input: `"Hello\" World"`, - expected: []string{``, `"Hello\" World"`}, - }, - { - name: "Nothing", - input: "", - expected: []string{}, - }, - { - name: "SpaceInsideQuote", - input: `"Hello World"`, - expected: []string{``, `"Hello World"`}, - }, - { - name: "SingleChar", - input: "H", - expected: []string{"H"}, - }, - { - name: "SingleQuote", - input: `"Hello`, - expected: []string{``, ``, `"Hello`}, - }, - { - name: "ThreeQuotes", - input: `Test "Hello "World" End\"`, - expected: []string{`Test `, `"Hello "`, `World`, ``, `" End\"`}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := splitByQuotes(tt.input) - if !equal(got, tt.expected) { - t.Errorf("splitByQuotes() = %v, want %v", got, tt.expected) - } - }) - } -} - -// Helper function to assert equality of two string slices. -func equal(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} - -// Testing for replaceVariablesForInterpreter -func TestReplaceVariablesForInterpreter(t *testing.T) { - tests := []struct { - name string - interpreter string - envMap map[string]string - expected []string - shouldFail bool - }{ - { - name: "No quotes", - interpreter: "/bin/bash -c ${COMMAND} tail", - envMap: map[string]string{"COMMAND": "echo Hello!"}, - expected: []string{"/bin/bash", "-c", "echo", "Hello!", "tail"}, - }, - { - name: "Quotes Variables", - interpreter: `/bin/bash -c "${COMMAND}" tail`, - envMap: map[string]string{"COMMAND": "Hello, World!"}, - expected: []string{"/bin/bash", "-c", "Hello, World!", "tail"}, - }, - { - name: "Double escape", - interpreter: `/bin/bash -c "${COMMAND}" ${TWO} tail`, - envMap: map[string]string{ - "COMMAND": "Hello, World!", - "TWO": "${COMMAND}", - }, - expected: []string{"/bin/bash", "-c", "Hello, World!", "${COMMAND}", "tail"}, - }, - { - name: "aws cli issue", - interpreter: "aws ${ARGS}", - envMap: map[string]string{ - "ARGS": `ec2 describe-instances --region us-east-1 --query 'Reservations[*].Instances[*].{Instance:InstanceId,State:State.Name}'`, - }, - expected: []string{ - `aws`, - `ec2`, - `describe-instances`, - `--region`, `us-east-1`, - `--query`, `Reservations[*].Instances[*].{Instance:InstanceId,State:State.Name}`, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := replaceVariablesForInterpreter(tt.interpreter, tt.envMap) - if (err != nil) != tt.shouldFail { - t.Errorf("replaceVariablesForInterpreter() error = %v, want %v", err, tt.shouldFail) - return - } - if !equal(got, tt.expected) { - t.Errorf("replaceVariablesForInterpreter() = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 4cdab995..113aa1ba 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -133,6 +133,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { }, tool, "{}", + false, ) if err != nil { return url, err