diff --git a/AUTHORS b/AUTHORS index 52a0d12e2..7191cfac6 100644 --- a/AUTHORS +++ b/AUTHORS @@ -23,6 +23,7 @@ Luke Scott Michael Woolnough Nicola Peduzzi Xiaobing Jiang +Xiuming Chen # Organizations diff --git a/infile.go b/infile.go index 518946d0e..952bd3d6d 100644 --- a/infile.go +++ b/infile.go @@ -114,10 +114,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { for err == nil && ioErr == nil { n, err = rdr.Read(data[4:]) if n > 0 { - data[0] = byte(n) - data[1] = byte(n >> 8) - data[2] = byte(n >> 16) - data[3] = mc.sequence ioErr = mc.writePacket(data[:4+n]) } } diff --git a/packets.go b/packets.go index 8ee32cae6..c6a38d36d 100644 --- a/packets.go +++ b/packets.go @@ -75,48 +75,37 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // Write packet buffer 'data' -// The packet header must be already included func (mc *mysqlConn) writePacket(data []byte) error { - if len(data)-4 <= mc.maxWriteSize { // Can send data at once - // Write packet - n, err := mc.netConn.Write(data) - if err == nil && n == len(data) { - mc.sequence++ - return nil - } - - // Handle error - if err == nil { // n != len(data) - errLog.Print(errMalformPkt.Error()) - } else { - errLog.Print(err.Error()) - } - return driver.ErrBadConn - } - - // Must split packet - return mc.splitPacket(data) -} - -func (mc *mysqlConn) splitPacket(data []byte) error { pktLen := len(data) - 4 if pktLen > mc.maxPacketAllowed { return errPktTooLarge } - for pktLen >= maxPacketSize { - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff + for { + var size int + if pktLen >= maxPacketSize { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = maxPacketSize + } else { + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + size = pktLen + } data[3] = mc.sequence // Write packet - n, err := mc.netConn.Write(data[:4+maxPacketSize]) - if err == nil && n == 4+maxPacketSize { + n, err := mc.netConn.Write(data[:4+size]) + if err == nil && n == 4+size { mc.sequence++ - data = data[maxPacketSize:] - pktLen -= maxPacketSize + if size != maxPacketSize { + return nil + } + pktLen -= size + data = data[size:] continue } @@ -128,12 +117,6 @@ func (mc *mysqlConn) splitPacket(data []byte) error { } return driver.ErrBadConn } - - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = mc.sequence - return mc.writePacket(data) } /****************************************************************************** @@ -265,12 +248,6 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { - // Packet header [24bit length + 1 byte sequence] - data[0] = byte((4 + 4 + 1 + 23)) - data[1] = byte((4 + 4 + 1 + 23) >> 8) - data[2] = byte((4 + 4 + 1 + 23) >> 16) - data[3] = mc.sequence - // Send TLS / SSL request packet if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { return err @@ -285,12 +262,6 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { mc.buf.rd = tlsConn } - // Add the packet header [24bit length + 1 byte sequence] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = mc.sequence - // Filler [23 bytes] (all 0x00) pos := 13 + 23 @@ -330,12 +301,6 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { return driver.ErrBadConn } - // Add the packet header [24bit length + 1 byte sequence] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = mc.sequence - // Add the scrambled password [null terminated string] copy(data[4:], scrambleBuff) @@ -357,12 +322,6 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { return driver.ErrBadConn } - // Add the packet header [24bit length + 1 byte sequence] - data[0] = 0x01 // 1 byte long - data[1] = 0x00 - data[2] = 0x00 - data[3] = 0x00 // new command, sequence id is always 0 - // Add command byte data[4] = command @@ -382,12 +341,6 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { return driver.ErrBadConn } - // Add the packet header [24bit length + 1 byte sequence] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = 0x00 // new command, sequence id is always 0 - // Add command byte data[4] = command @@ -409,12 +362,6 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { return driver.ErrBadConn } - // Add the packet header [24bit length + 1 byte sequence] - data[0] = 0x05 // 5 bytes long - data[1] = 0x00 - data[2] = 0x00 - data[3] = 0x00 // new command, sequence id is always 0 - // Add command byte data[4] = command @@ -748,12 +695,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - // Add the packet header [24bit length + 1 byte sequence] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = 0x00 // mc.sequence - + stmt.mc.sequence = 0 // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -801,28 +743,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { var data []byte if len(args) == 0 { - const pktLen = 1 + 4 + 1 + 4 - data = mc.buf.takeBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print("Busy buffer") - return driver.ErrBadConn - } - - // packet header [4 bytes] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = 0x00 // new command, sequence id is always 0 + data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4) } else { data = mc.buf.takeCompleteBuffer() - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print("Busy buffer") - return driver.ErrBadConn - } - - // header (bytes 0-3) is added after we know the packet size + } + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print("Busy buffer") + return driver.ErrBadConn } // command [1 byte] @@ -984,14 +912,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos += len(paramValues) data = data[:pos] - pktLen := pos - 4 - - // packet header [4 bytes] - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - data[3] = mc.sequence - // Convert nullMask to bytes for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ { data[i+14] = byte(nullMask >> uint(i<<3))