Skip to content

Implement !! command #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ jobs:
name: lint-pr-changes
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-go@v3
with:
go-version: 1.18
- uses: actions/checkout@v3
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
uses: golangci/golangci-lint-action@v3
with:
version: v1.42.0
version: latest
only-new-issues: true
2 changes: 2 additions & 0 deletions .pipelines/TestSql2017.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,5 @@ steps:
env:
disable.coverage.autogenerate: 'true'

- task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0
displayName: ‘Component Detection’
47 changes: 39 additions & 8 deletions pkg/sqlcmd/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ func newCommands() Commands {
action: connectCommand,
name: "CONNECT",
},
"EXEC": {
regex: regexp.MustCompile(`(?im)^[ \t]*?:?!!(?:[ \t]+(.*$)|$)`),
action: execCommand,
name: "EXEC",
},
}

}
Expand Down Expand Up @@ -172,6 +177,9 @@ func goCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) > 0 {
cnt := strings.TrimSpace(args[0])
if cnt != "" {
if cnt, err = resolveArgumentVariables(s, []rune(cnt), true); err != nil {
return err
}
_, err = fmt.Sscanf(cnt, "%d", &n)
}
}
Expand Down Expand Up @@ -245,7 +253,8 @@ func readFileCommand(s *Sqlcmd, args []string, line uint) error {
if args == nil || len(args) != 1 {
return InvalidCommandError(":R", line)
}
return s.IncludeFile(resolveArgumentVariables(s, []rune(args[0])), false)
fileName, _ := resolveArgumentVariables(s, []rune(args[0]), false)
return s.IncludeFile(fileName, false)
}

// setVarCommand parses a variable setting and applies it to the current Sqlcmd variables
Expand Down Expand Up @@ -345,10 +354,10 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
}

connect := s.Connect
connect.UserName = resolveArgumentVariables(s, []rune(arguments.Username))
connect.Password = resolveArgumentVariables(s, []rune(arguments.Password))
connect.ServerName = resolveArgumentVariables(s, []rune(arguments.Server))
timeout := resolveArgumentVariables(s, []rune(arguments.LoginTimeout))
connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false)
connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false)
connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false)
timeout, _ := resolveArgumentVariables(s, []rune(arguments.LoginTimeout), false)
if timeout != "" {
if timeoutSeconds, err := strconv.ParseInt(timeout, 10, 32); err == nil {
if timeoutSeconds < 0 {
Expand All @@ -364,7 +373,26 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
return nil
}

func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
func execCommand(s *Sqlcmd, args []string, line uint) error {
if len(args) == 0 {
return InvalidCommandError("EXEC", line)
}
cmdLine := strings.TrimSpace(args[0])
if cmdLine == "" {
return InvalidCommandError("EXEC", line)
}
if cmdLine, err := resolveArgumentVariables(s, []rune(cmdLine), true); err != nil {
return err
} else {
cmd := sysCommand(cmdLine)
cmd.Stderr = s.GetError()
cmd.Stdout = s.GetOutput()
_ = cmd.Run()
}
return nil
}

func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (string, error) {
var b *strings.Builder
end := len(arg)
for i := 0; i < end; {
Expand All @@ -383,6 +411,9 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
}
b.WriteString(val)
} else {
if failOnUnresolved {
return "", UndefinedVariable(varName)
}
_, _ = s.GetError().Write([]byte(UndefinedVariable(varName).Error() + SqlcmdEol))
if b != nil {
b.WriteString(string(arg[i : vl+1]))
Expand All @@ -403,7 +434,7 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune) string {
}
}
if b == nil {
return string(arg)
return string(arg), nil
}
return b.String()
return b.String(), nil
}
21 changes: 20 additions & 1 deletion pkg/sqlcmd/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ func TestCommandParsing(t *testing.T) {
{`EXIT `, "EXIT", []string{""}},
{`:Connect someserver -U someuser`, "CONNECT", []string{"someserver -U someuser"}},
{`:r c:\$(var)\file.sql`, "READFILE", []string{`c:\$(var)\file.sql`}},
{`:!! notepad`, "EXEC", []string{"notepad"}},
{` !! dir c:\`, "EXEC", []string{`dir c:\`}},
}

for _, test := range commands {
Expand Down Expand Up @@ -242,9 +244,26 @@ func TestResolveArgumentVariables(t *testing.T) {
defer buf.Close()
s.SetError(buf)
for _, test := range args {
actual := resolveArgumentVariables(s, []rune(test.arg))
actual, _ := resolveArgumentVariables(s, []rune(test.arg), false)
assert.Equal(t, test.val, actual, "Incorrect argument parsing of "+test.arg)
assert.Contains(t, buf.buf.String(), test.err, "Error output mismatch for "+test.arg)
buf.buf.Reset()
}
actual, err := resolveArgumentVariables(s, []rune("$(var1)$(var2)"), true)
if assert.ErrorContains(t, err, UndefinedVariable("var2").Error(), "fail on unresolved variable") {
assert.Empty(t, actual, "fail on unresolved variable")
}
}

func TestExecCommand(t *testing.T) {
vars := InitializeVariables(false)
s := New(nil, "", vars)
s.vars.Set("var1", "hello")
buf := &memoryBuffer{buf: new(bytes.Buffer)}
defer buf.Close()
s.SetOutput(buf)
err := execCommand(s, []string{`echo $(var1)`}, 1)
if assert.NoError(t, err, "execCommand with valid arguments") {
assert.Equal(t, buf.buf.String(), "hello"+SqlcmdEol, "echo output should be in sqlcmd output")
}
}
16 changes: 16 additions & 0 deletions pkg/sqlcmd/exec_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sqlcmd

import (
"os/exec"
)

func sysCommand(arg string) *exec.Cmd {
cmd := exec.Command(comSpec(), "-c", arg)
return cmd
}

// comSpec returns the path of the command shell executable
func comSpec() string {
// /bin/sh will be a link to the shell
return `/bin/sh`
}
16 changes: 16 additions & 0 deletions pkg/sqlcmd/exec_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sqlcmd

import (
"os/exec"
)

func sysCommand(arg string) *exec.Cmd {
cmd := exec.Command(comSpec(), "-c", arg)
return cmd
}

// comSpec returns the path of the command shell executable
func comSpec() string {
// /bin/sh will be a link to the shell
return `/bin/sh`
}
25 changes: 25 additions & 0 deletions pkg/sqlcmd/exec_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package sqlcmd

import (
"os"
"os/exec"
"syscall"
)

func sysCommand(arg string) *exec.Cmd {
cmd := exec.Command(comSpec())
cmd.SysProcAttr = &syscall.SysProcAttr{CmdLine: cmd.Path + " " + comArgs(arg)}
return cmd
}

// comSpec returns the path of the command shell executable
func comSpec() string {
if cmd, ok := os.LookupEnv("COMSPEC"); ok {
return cmd
}
return `C:\Windows\System32\cmd.exe`
}

func comArgs(args string) string {
return `/c ` + args
}