diff --git a/compress.go b/compress.go index e247a65a..38bfa000 100644 --- a/compress.go +++ b/compress.go @@ -113,12 +113,11 @@ func (c *compIO) readCompressedPacket() error { // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) // before receiving all packets from client. In this case, seqnr is younger than expected. // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it. - if debug && compressionSequence != c.mc.sequence { + if debug && compressionSequence != c.mc.compressSequence { fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", - c.mc.sequence, compressionSequence) + c.mc.compressSequence, compressionSequence) } - c.mc.sequence = compressionSequence + 1 - c.mc.compressSequence = c.mc.sequence + c.mc.compressSequence = compressionSequence + 1 comprData, err := c.mc.readNext(comprLength) if err != nil { @@ -200,7 +199,7 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e comprLength := len(data) - 7 if debug { fmt.Printf( - "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v\n", comprLength, uncompressedLen, mc.compressSequence) } diff --git a/packets.go b/packets.go index 15b000d6..831fca6c 100644 --- a/packets.go +++ b/packets.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "math" + "os" "strconv" "time" ) @@ -62,17 +63,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { pktLen := getUint24(data[:3]) seq := data[3] - if mc.compress { + // check packet sync [8 bit] + if seq != mc.sequence { + mc.log(fmt.Sprintf("[warn] unexpected sequence nr: expected %v, got %v", mc.sequence, seq)) // MySQL and MariaDB doesn't check packet nr in compressed packet. - if debug && seq != mc.compressSequence { - fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v", - mc.compressSequence, seq) - } - mc.compressSequence = seq + 1 - } else { - // check packet sync [8 bit] - if seq != mc.sequence { - mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq)) + if !mc.compress { // For large packets, we stop reading as soon as sync error. if len(prevData) > 0 { mc.close() @@ -80,8 +75,8 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } invalidSequence = true } - mc.sequence++ } + mc.sequence = seq + 1 // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long @@ -146,7 +141,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Write packet if debug { - fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) + fmt.Fprintf(os.Stderr, "writePacket: size=%v seq=%v\n", size, mc.sequence) } n, err := writeFunc(data[:4+size]) @@ -445,7 +440,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data[4] = command // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { @@ -486,7 +483,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { binary.LittleEndian.PutUint32(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } /****************************************************************************** @@ -956,7 +955,6 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.resetSequence() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -968,6 +966,8 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Send CMD packet err := stmt.mc.writePacket(data[:4+pktLen]) + // Every COM_LONG_DATA packet reset Packet Sequence + stmt.mc.resetSequence() if err == nil { data = data[pktLen-dataOffset:] continue @@ -975,8 +975,6 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { return err } - // Reset Packet Sequence - stmt.mc.resetSequence() return nil }