Skip to content

Reuse buffers when parsing #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
85 changes: 0 additions & 85 deletions buffer.go

This file was deleted.

7 changes: 4 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package mysql

import (
"bufio"
"database/sql/driver"
"errors"
"net"
Expand All @@ -22,7 +23,7 @@ type mysqlConn struct {
charset byte
cipher []byte
netConn net.Conn
buf *buffer
buf *bufio.Reader
protocol uint8
sequence uint8
affectedRows uint64
Expand Down Expand Up @@ -182,7 +183,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := &mysqlRows{mc, false, nil, false}
rows := &mysqlRows{mc, false, nil, false, nil}

if resLen > 0 {
// Columns
Expand All @@ -208,7 +209,7 @@ func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := &mysqlRows{mc, false, nil, false}
rows := &mysqlRows{mc, false, nil, false, nil}

if resLen > 0 {
// Columns
Expand Down
3 changes: 2 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package mysql

import (
"bufio"
"database/sql"
"database/sql/driver"
"net"
Expand Down Expand Up @@ -43,7 +44,7 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
mc.buf = newBuffer(mc.netConn)
mc.buf = bufio.NewReader(mc.netConn)

// Reading Handshake Initialization Packet
err = mc.readInitPacket()
Expand Down
2 changes: 1 addition & 1 deletion infile.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if err == nil {
return mc.readResultOK()
} else {
mc.readPacket()
mc.readPacket(nil)
}
return err
}
127 changes: 90 additions & 37 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,50 +24,75 @@ import (
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html

// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() (data []byte, err error) {
func (mc *mysqlConn) readPacket(reuseBuf *bytes.Buffer) (*bytes.Buffer, error) {
// Read packet header
data = make([]byte, 4)
err = mc.buf.read(data)
var header [4]byte
_, err := io.ReadAtLeast(mc.buf, header[:], len(header))
if err != nil {
errLog.Print(err.Error())
return nil, driver.ErrBadConn
}

// Packet Length [24 bit]
pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16
pktLen := uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16

if pktLen < 1 {
errLog.Print(errMalformPkt.Error())
return nil, driver.ErrBadConn
}

// Check Packet Sync [8 bit]
if data[3] != mc.sequence {
if data[3] > mc.sequence {
if header[3] != mc.sequence {
if header[3] > mc.sequence {
return nil, errPktSyncMul
} else {
return nil, errPktSync
}
}
mc.sequence++

// Setup buffer with space for atleast pktLen bytes
var dataBuf *bytes.Buffer
if reuseBuf != nil {
dataBuf = reuseBuf
// TODO: in Go 1.1, bytes.Buffer has a Grow method that could reduce
// number of allocations even further. We could support go1.0.3 and
// go1.1 with some build-tags, but I'm leaving it for now.
//
// go1.1:
// dataBuf.Grow(int(pktLen))
} else {
dataBuf = bytes.NewBuffer(make([]byte, 0, int(pktLen)))
}

// Read packet body [pktLen bytes]
data = make([]byte, pktLen)
err = mc.buf.read(data)
if err == nil {
if pktLen < maxPacketSize {
return data, nil
}
_, err = io.CopyN(dataBuf, mc.buf, int64(pktLen))
if err != nil {
errLog.Print(err.Error())
return nil, driver.ErrBadConn
}

// More data
var data2 []byte
data2, err = mc.readPacket()
if err == nil {
return append(data, data2...), nil
}
if pktLen < maxPacketSize {
return dataBuf, nil
}

// TODO: convert this recursion into iteration

// pktLen == maxPacketSize is MySQL signalling that more data is in
// the next packet. We write the next packet directly into our data buffer.
dataBuf, err = mc.readPacket(dataBuf)
if err != nil {
return nil, err
}
return dataBuf, err
}

func (mc *mysqlConn) readPacketBytes() ([]byte, error) {
dataBuf, err := mc.readPacket(nil)
if err != nil {
return nil, err
}
errLog.Print(err.Error())
return nil, driver.ErrBadConn
return dataBuf.Bytes(), nil
}

// Write packet buffer 'data'
Expand Down Expand Up @@ -139,7 +164,7 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) {
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() (err error) {
data, err := mc.readPacket()
data, err := mc.readPacketBytes()
if err != nil {
return
}
Expand Down Expand Up @@ -346,7 +371,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {

// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error {
data, err := mc.readPacket()
data, err := mc.readPacketBytes()
if err == nil {
// packet indicator
switch data[0] {
Expand All @@ -368,7 +393,7 @@ func (mc *mysqlConn) readResultOK() error {
// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket()
data, err := mc.readPacketBytes()
if err == nil {
switch data[0] {

Expand Down Expand Up @@ -439,18 +464,25 @@ func (mc *mysqlConn) handleOkPacket(data []byte) {
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
var data []byte
var dataBuf *bytes.Buffer
var i, pos, n int
var name []byte

columns = make([]mysqlField, count)

for {
data, err = mc.readPacket()
// If we're reusing a buffer, reset it
if dataBuf != nil {
dataBuf.Reset()
}

dataBuf, err = mc.readPacket(dataBuf)
if err != nil {
return
}

data := dataBuf.Bytes()

// EOF Packet
if data[0] == iEOF && len(data) == 5 {
if i != count {
Expand Down Expand Up @@ -487,6 +519,9 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
pos += n

// Name [len coded string]
//
// NOTE: We must take a copy of name, because we reuse the underlying
// storage
name, _, n, err = readLengthEnodedString(data[pos:])
if err != nil {
return
Expand Down Expand Up @@ -530,10 +565,14 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
data, err := rows.mc.readPacket()
if rows.buf != nil {
rows.buf.Reset()
}
rows.buf, err = rows.mc.readPacket(rows.buf)
if err != nil {
return
}
data := rows.buf.Bytes()

// EOF Packet
if data[0] == iEOF && len(data) == 5 {
Expand Down Expand Up @@ -563,20 +602,30 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
return
}

// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
func (mc *mysqlConn) readUntilEOF() (err error) {
var data []byte
// Reads Packets until EOF-Packet or an Error appears.
func (mc *mysqlConn) readUntilEOF() error {
var (
dataBuf *bytes.Buffer
err error
)

for {
data, err = mc.readPacket()
if dataBuf != nil {
dataBuf.Reset()
}
dataBuf, err = mc.readPacket(dataBuf)
if err != nil {
return err
}
data := dataBuf.Bytes()

// No Err and no EOF Packet
if err == nil && (data[0] != iEOF || len(data) != 5) {
continue
// If we found an EOF packet, then we're done.
if data[0] == iEOF && len(data) == 5 {
break
}
return // Err or EOF
}
return

return nil
}

/******************************************************************************
Expand All @@ -586,7 +635,7 @@ func (mc *mysqlConn) readUntilEOF() (err error) {
// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
data, err := stmt.mc.readPacket()
data, err := stmt.mc.readPacketBytes()
if err == nil {
// Position
pos := 0
Expand Down Expand Up @@ -816,10 +865,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {

// http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
data, err := rc.mc.readPacket()
if rc.buf != nil {
rc.buf.Reset()
}
rc.buf, err = rc.mc.readPacket(rc.buf)
if err != nil {
return
}
data := rc.buf.Bytes()

// packet indicator [1 byte]
if data[0] != iOK {
Expand Down
Loading