Skip to content

Commit 8116d58

Browse files
authored
Implement !! command (#125)
* update PR pipelines * implement exec command * remove unused import * fix bash args * fix linux command args
1 parent a99126e commit 8116d58

File tree

7 files changed

+124
-12
lines changed

7 files changed

+124
-12
lines changed

.github/workflows/golangci-lint.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ jobs:
99
name: lint-pr-changes
1010
runs-on: ubuntu-latest
1111
steps:
12-
- uses: actions/checkout@v2
12+
- uses: actions/setup-go@v3
13+
with:
14+
go-version: 1.18
15+
- uses: actions/checkout@v3
1316
- name: golangci-lint
14-
uses: golangci/golangci-lint-action@v2
17+
uses: golangci/golangci-lint-action@v3
1518
with:
16-
version: v1.42.0
19+
version: latest
1720
only-new-issues: true

.pipelines/TestSql2017.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,5 @@ steps:
4949
env:
5050
disable.coverage.autogenerate: 'true'
5151

52+
- task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
53+
displayName: ‘Component Detection’

pkg/sqlcmd/commands.go

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ func newCommands() Commands {
8888
action: connectCommand,
8989
name: "CONNECT",
9090
},
91+
"EXEC": {
92+
regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`),
93+
action: execCommand,
94+
name: "EXEC",
95+
},
9196
}
9297

9398
}
@@ -172,6 +177,9 @@ func goCommand(s *Sqlcmd, args []string, line uint) error {
172177
if len(args) > 0 {
173178
cnt := strings.TrimSpace(args[0])
174179
if cnt != "" {
180+
if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil {
181+
return err
182+
}
175183
_, err = fmt.Sscanf(cnt, "%d", &n)
176184
}
177185
}
@@ -245,7 +253,8 @@ func readFileCommand(s *Sqlcmd, args []string, line uint) error {
245253
if args == nil || len(args) != 1 {
246254
return InvalidCommandError(":R", line)
247255
}
248-
return s.IncludeFile(resolveArgumentVariables(s, []rune(args[0])), false)
256+
fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false)
257+
return s.IncludeFile(fileName, false)
249258
}
250259

251260
// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables
@@ -345,10 +354,10 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
345354
}
346355

347356
connect := s.Connect
348-
connect.UserName = resolveArgumentVariables(s, []rune(arguments.Username))
349-
connect.Password = resolveArgumentVariables(s, []rune(arguments.Password))
350-
connect.ServerName = resolveArgumentVariables(s, []rune(arguments.Server))
351-
timeout := resolveArgumentVariables(s, []rune(arguments.LoginTimeout))
357+
connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false)
358+
connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false)
359+
connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false)
360+
timeout, _ := resolveArgumentVariables(s, []rune(arguments.LoginTimeout), false)
352361
if timeout != "" {
353362
if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil {
354363
if timeoutSeconds < 0 {
@@ -364,7 +373,26 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
364373
return nil
365374
}
366375

367-
func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
376+
func execCommand(s *Sqlcmd, args []string, line uint) error {
377+
if len(args) == 0 {
378+
return InvalidCommandError("EXEC", line)
379+
}
380+
cmdLine := strings.TrimSpace(args[0])
381+
if cmdLine == "" {
382+
return InvalidCommandError("EXEC", line)
383+
}
384+
if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil {
385+
return err
386+
} else {
387+
cmd := sysCommand(cmdLine)
388+
cmd.Stderr = s.GetError()
389+
cmd.Stdout = s.GetOutput()
390+
_ = cmd.Run()
391+
}
392+
return nil
393+
}
394+
395+
func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
368396
var b *strings.Builder
369397
end := len(arg)
370398
for i := 0; i < end; {
@@ -383,6 +411,9 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
383411
}
384412
b.WriteString(val)
385413
} else {
414+
if failOnUnresolved {
415+
return "", UndefinedVariable(varName)
416+
}
386417
_, _ = s.GetError().Write([]byte(UndefinedVariable(varName).Error() + SqlcmdEol))
387418
if b != nil {
388419
b.WriteString(string(arg[i : vl+1]))
@@ -403,7 +434,7 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
403434
}
404435
}
405436
if b == nil {
406-
return string(arg)
437+
return string(arg), nil
407438
}
408-
return b.String()
439+
return b.String(), nil
409440
}

pkg/sqlcmd/commands_test.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ func TestCommandParsing(t *testing.T) {
4646
{`EXIT `, "EXIT", []string{""}},
4747
{`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}},
4848
{`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}},
49+
{`:!! notepad`, "EXEC", []string{"notepad"}},
50+
{` !! dir c:\`, "EXEC", []string{`dir c:\`}},
4951
}
5052

5153
for _, test := range commands {
@@ -242,9 +244,26 @@ func TestResolveArgumentVariables(t *testing.T) {
242244
defer buf.Close()
243245
s.SetError(buf)
244246
for _, test := range args {
245-
actual := resolveArgumentVariables(s, []rune(test.arg))
247+
actual, _ := resolveArgumentVariables(s, []rune(test.arg), false)
246248
assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg)
247249
assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg)
248250
buf.buf.Reset()
249251
}
252+
actual, err := resolveArgumentVariables(s, []rune("$(var1)$(var2)"), true)
253+
if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") {
254+
assert.Empty(t, actual, "fail on unresolved variable")
255+
}
256+
}
257+
258+
func TestExecCommand(t *testing.T) {
259+
vars := InitializeVariables(false)
260+
s := New(nil, "", vars)
261+
s.vars.Set("var1", "hello")
262+
buf := &memoryBuffer{buf: new(bytes.Buffer)}
263+
defer buf.Close()
264+
s.SetOutput(buf)
265+
err := execCommand(s, []string{`echo $(var1)`}, 1)
266+
if assert.NoError(t, err, "execCommand with valid arguments") {
267+
assert.Equal(t, buf.buf.String(), "hello"+SqlcmdEol, "echo output should be in sqlcmd output")
268+
}
250269
}

pkg/sqlcmd/exec_darwin.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package sqlcmd
2+
3+
import (
4+
"os/exec"
5+
)
6+
7+
func sysCommand(arg string) *exec.Cmd {
8+
cmd := exec.Command(comSpec(), "-c", arg)
9+
return cmd
10+
}
11+
12+
// comSpec returns the path of the command shell executable
13+
func comSpec() string {
14+
// /bin/sh will be a link to the shell
15+
return `/bin/sh`
16+
}

pkg/sqlcmd/exec_linux.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package sqlcmd
2+
3+
import (
4+
"os/exec"
5+
)
6+
7+
func sysCommand(arg string) *exec.Cmd {
8+
cmd := exec.Command(comSpec(), "-c", arg)
9+
return cmd
10+
}
11+
12+
// comSpec returns the path of the command shell executable
13+
func comSpec() string {
14+
// /bin/sh will be a link to the shell
15+
return `/bin/sh`
16+
}

pkg/sqlcmd/exec_windows.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package sqlcmd
2+
3+
import (
4+
"os"
5+
"os/exec"
6+
"syscall"
7+
)
8+
9+
func sysCommand(arg string) *exec.Cmd {
10+
cmd := exec.Command(comSpec())
11+
cmd.SysProcAttr = &syscall.SysProcAttr{CmdLine: cmd.Path + " " + comArgs(arg)}
12+
return cmd
13+
}
14+
15+
// comSpec returns the path of the command shell executable
16+
func comSpec() string {
17+
if cmd, ok := os.LookupEnv("COMSPEC"); ok {
18+
return cmd
19+
}
20+
return `C:\Windows\System32\cmd.exe`
21+
}
22+
23+
func comArgs(args string) string {
24+
return `/c ` + args
25+
}

0 commit comments

Comments
 (0)