diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index d0ebb491..7c2b8d3b 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -68,8 +68,9 @@ type Sqlcmd struct { Query string Cmd Commands // PrintError allows the host to redirect errors away from the default output. Returns false if the error is not redirected by the host. - PrintError func(msg string, severity uint8) bool - UnicodeOutputFile bool + PrintError func(msg string, severity uint8) bool + UnicodeOutputFile bool + IsInteractiveSession bool } // New creates a new Sqlcmd instance @@ -86,10 +87,13 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd { s.PrintError = func(msg string, severity uint8) bool { return false } + s.SetOutput(os.Stdout) + s.SetError(os.Stderr) return s } func (s *Sqlcmd) scanNext() (string, error) { + s.IsInteractiveSession = true return s.lineIo.Readline() } @@ -134,7 +138,11 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { args = make([]string, 0) once = true } else { - _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + if iactive && s.IsInteractiveSession { + _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + } else { + _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) + } } } if cmd != nil { @@ -144,7 +152,11 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { break } if err != nil { - _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + if iactive && s.IsInteractiveSession { + _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + } else { + _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) + } lastError = err } } diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index f7909334..accaa965 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -109,6 +109,44 @@ func TestSqlCmdQueryAndExit(t *testing.T) { } } +func TestSqlCmdOutputAndError(t *testing.T) { + s, outfile, errfile := setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + s.Query = "select $(X" + err := s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution") + } + } + s.Query = "select '1'" + err = s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for query execution") + } + } + + s, outfile, errfile = setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + dataPath := "testdata" + string(os.PathSeparator) + err = s.IncludeFile(dataPath+"teststdouterr.sql", false) + if assert.NoError(t, err, "IncludeFile teststdouterr.sql false") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile outfile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for sql file execution in outfile") + } + bytes, err = os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile errfile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile") + } + } +} + // Simulate :r command func TestIncludeFileNoExecutions(t *testing.T) { s, file := setupSqlcmdWithFileOutput(t) @@ -476,6 +514,7 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { s.Format = NewSQLCmdDefaultFormatter(true) buf := &memoryBuffer{buf: new(bytes.Buffer)} s.SetOutput(buf) + s.SetError(buf) err := s.ConnectDb(nil, true) assert.NoError(t, err, "s.ConnectDB") return s, buf @@ -491,6 +530,7 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { file, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") s.SetOutput(file) + s.SetError(file) err = s.ConnectDb(nil, true) if err != nil { os.Remove(file.Name()) @@ -499,6 +539,28 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { return s, file } +func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) { + t.Helper() + v := InitializeVariables(true) + v.Set(SQLCMDMAXVARTYPEWIDTH, "0") + s := New(nil, "", v) + s.Connect = newConnect(t) + s.Format = NewSQLCmdDefaultFormatter(true) + outfile, err := os.CreateTemp("", "sqlcmdout") + assert.NoError(t, err, "os.CreateTemp") + errfile, err := os.CreateTemp("", "sqlcmderr") + assert.NoError(t, err, "os.CreateTemp") + s.SetOutput(outfile) + s.SetError(errfile) + err = s.ConnectDb(nil, true) + if err != nil { + os.Remove(outfile.Name()) + os.Remove(errfile.Name()) + } + assert.NoError(t, err, "s.ConnectDB") + return s, outfile, errfile +} + // Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set func canTestAzureAuth() bool { server := os.Getenv(SQLCMDSERVER) diff --git a/pkg/sqlcmd/testdata/teststdouterr.sql b/pkg/sqlcmd/testdata/teststdouterr.sql new file mode 100644 index 00000000..d09b4113 --- /dev/null +++ b/pkg/sqlcmd/testdata/teststdouterr.sql @@ -0,0 +1,4 @@ +select $(X +go +select '1' +go \ No newline at end of file