diff --git a/buffer.go b/buffer.go deleted file mode 100644 index 191c14855..000000000 --- a/buffer.go +++ /dev/null @@ -1,85 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 Julien Schmidt. All rights reserved. -// http://www.julienschmidt.com -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -package mysql - -import ( - "io" -) - -const ( - defaultBufSize = 4096 -) - -type buffer struct { - buf []byte - rd io.Reader - idx int - length int -} - -func newBuffer(rd io.Reader) *buffer { - return &buffer{ - buf: make([]byte, defaultBufSize), - rd: rd, - } -} - -// fill reads at least _need_ bytes in the buffer -// existing data in the buffer gets lost -func (b *buffer) fill(need int) (err error) { - b.idx = 0 - b.length = 0 - - var n int - for b.length < need { - n, err = b.rd.Read(b.buf[b.length:]) - b.length += n - - if err == nil { - continue - } - return // err - } - - return -} - -// read len(p) bytes -func (b *buffer) read(p []byte) (err error) { - need := len(p) - - if b.length < need { - if b.length > 0 { - copy(p[0:b.length], b.buf[b.idx:]) - need -= b.length - p = p[b.length:] - - b.idx = 0 - b.length = 0 - } - - if need >= len(b.buf) { - var n int - has := 0 - for err == nil && need > has { - n, err = b.rd.Read(p[has:]) - has += n - } - return - } - - err = b.fill(need) // err deferred - } - - copy(p, b.buf[b.idx:]) - b.idx += need - b.length -= need - return -} diff --git a/connection.go b/connection.go index 3531595fc..85e5cacab 100644 --- a/connection.go +++ b/connection.go @@ -10,6 +10,7 @@ package mysql import ( + "bufio" "database/sql/driver" "errors" "net" @@ -22,7 +23,7 @@ type mysqlConn struct { charset byte cipher []byte netConn net.Conn - buf *buffer + buf *bufio.Reader protocol uint8 sequence uint8 affectedRows uint64 @@ -182,7 +183,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro var resLen int resLen, err = mc.readResultSetHeaderPacket() if err == nil { - rows := &mysqlRows{mc, false, nil, false} + rows := &mysqlRows{mc, false, nil, false, nil} if resLen > 0 { // Columns @@ -208,7 +209,7 @@ func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) { var resLen int resLen, err = mc.readResultSetHeaderPacket() if err == nil { - rows := &mysqlRows{mc, false, nil, false} + rows := &mysqlRows{mc, false, nil, false, nil} if resLen > 0 { // Columns diff --git a/driver.go b/driver.go index 4b699b1f0..f6f42660e 100644 --- a/driver.go +++ b/driver.go @@ -9,6 +9,7 @@ package mysql import ( + "bufio" "database/sql" "database/sql/driver" "net" @@ -43,7 +44,7 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - mc.buf = newBuffer(mc.netConn) + mc.buf = bufio.NewReader(mc.netConn) // Reading Handshake Initialization Packet err = mc.readInitPacket() diff --git a/infile.go b/infile.go index 3485032e1..ef0638ab6 100644 --- a/infile.go +++ b/infile.go @@ -119,7 +119,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if err == nil { return mc.readResultOK() } else { - mc.readPacket() + mc.readPacket(nil) } return err } diff --git a/packets.go b/packets.go index 7c56245de..94a7fa981 100644 --- a/packets.go +++ b/packets.go @@ -24,17 +24,17 @@ import ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html // Read packet to buffer 'data' -func (mc *mysqlConn) readPacket() (data []byte, err error) { +func (mc *mysqlConn) readPacket(reuseBuf *bytes.Buffer) (*bytes.Buffer, error) { // Read packet header - data = make([]byte, 4) - err = mc.buf.read(data) + var header [4]byte + _, err := io.ReadAtLeast(mc.buf, header[:], len(header)) if err != nil { errLog.Print(err.Error()) return nil, driver.ErrBadConn } // Packet Length [24 bit] - pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 + pktLen := uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16 if pktLen < 1 { errLog.Print(errMalformPkt.Error()) @@ -42,8 +42,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) { } // Check Packet Sync [8 bit] - if data[3] != mc.sequence { - if data[3] > mc.sequence { + if header[3] != mc.sequence { + if header[3] > mc.sequence { return nil, errPktSyncMul } else { return nil, errPktSync @@ -51,23 +51,48 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) { } mc.sequence++ + // Setup buffer with space for atleast pktLen bytes + var dataBuf *bytes.Buffer + if reuseBuf != nil { + dataBuf = reuseBuf + // TODO: in Go 1.1, bytes.Buffer has a Grow method that could reduce + // number of allocations even further. We could support go1.0.3 and + // go1.1 with some build-tags, but I'm leaving it for now. + // + // go1.1: + // dataBuf.Grow(int(pktLen)) + } else { + dataBuf = bytes.NewBuffer(make([]byte, 0, int(pktLen))) + } + // Read packet body [pktLen bytes] - data = make([]byte, pktLen) - err = mc.buf.read(data) - if err == nil { - if pktLen < maxPacketSize { - return data, nil - } + _, err = io.CopyN(dataBuf, mc.buf, int64(pktLen)) + if err != nil { + errLog.Print(err.Error()) + return nil, driver.ErrBadConn + } - // More data - var data2 []byte - data2, err = mc.readPacket() - if err == nil { - return append(data, data2...), nil - } + if pktLen < maxPacketSize { + return dataBuf, nil + } + + // TODO: convert this recursion into iteration + + // pktLen == maxPacketSize is MySQL signalling that more data is in + // the next packet. We write the next packet directly into our data buffer. + dataBuf, err = mc.readPacket(dataBuf) + if err != nil { + return nil, err + } + return dataBuf, err +} + +func (mc *mysqlConn) readPacketBytes() ([]byte, error) { + dataBuf, err := mc.readPacket(nil) + if err != nil { + return nil, err } - errLog.Print(err.Error()) - return nil, driver.ErrBadConn + return dataBuf.Bytes(), nil } // Write packet buffer 'data' @@ -139,7 +164,7 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake func (mc *mysqlConn) readInitPacket() (err error) { - data, err := mc.readPacket() + data, err := mc.readPacketBytes() if err != nil { return } @@ -346,7 +371,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Returns error if Packet is not an 'Result OK'-Packet func (mc *mysqlConn) readResultOK() error { - data, err := mc.readPacket() + data, err := mc.readPacketBytes() if err == nil { // packet indicator switch data[0] { @@ -368,7 +393,7 @@ func (mc *mysqlConn) readResultOK() error { // Result Set Header Packet // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { - data, err := mc.readPacket() + data, err := mc.readPacketBytes() if err == nil { switch data[0] { @@ -439,18 +464,25 @@ func (mc *mysqlConn) handleOkPacket(data []byte) { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { - var data []byte + var dataBuf *bytes.Buffer var i, pos, n int var name []byte columns = make([]mysqlField, count) for { - data, err = mc.readPacket() + // If we're reusing a buffer, reset it + if dataBuf != nil { + dataBuf.Reset() + } + + dataBuf, err = mc.readPacket(dataBuf) if err != nil { return } + data := dataBuf.Bytes() + // EOF Packet if data[0] == iEOF && len(data) == 5 { if i != count { @@ -487,6 +519,9 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { pos += n // Name [len coded string] + // + // NOTE: We must take a copy of name, because we reuse the underlying + // storage name, _, n, err = readLengthEnodedString(data[pos:]) if err != nil { return @@ -530,10 +565,14 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { - data, err := rows.mc.readPacket() + if rows.buf != nil { + rows.buf.Reset() + } + rows.buf, err = rows.mc.readPacket(rows.buf) if err != nil { return } + data := rows.buf.Bytes() // EOF Packet if data[0] == iEOF && len(data) == 5 { @@ -563,20 +602,30 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) { return } -// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() (err error) { - var data []byte +// Reads Packets until EOF-Packet or an Error appears. +func (mc *mysqlConn) readUntilEOF() error { + var ( + dataBuf *bytes.Buffer + err error + ) for { - data, err = mc.readPacket() + if dataBuf != nil { + dataBuf.Reset() + } + dataBuf, err = mc.readPacket(dataBuf) + if err != nil { + return err + } + data := dataBuf.Bytes() - // No Err and no EOF Packet - if err == nil && (data[0] != iEOF || len(data) != 5) { - continue + // If we found an EOF packet, then we're done. + if data[0] == iEOF && len(data) == 5 { + break } - return // Err or EOF } - return + + return nil } /****************************************************************************** @@ -586,7 +635,7 @@ func (mc *mysqlConn) readUntilEOF() (err error) { // Prepare Result Packets // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) { - data, err := stmt.mc.readPacket() + data, err := stmt.mc.readPacketBytes() if err == nil { // Position pos := 0 @@ -816,10 +865,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) { - data, err := rc.mc.readPacket() + if rc.buf != nil { + rc.buf.Reset() + } + rc.buf, err = rc.mc.readPacket(rc.buf) if err != nil { return } + data := rc.buf.Bytes() // packet indicator [1 byte] if data[0] != iOK { diff --git a/rows.go b/rows.go index fda75998b..3bb3aa5f6 100644 --- a/rows.go +++ b/rows.go @@ -10,6 +10,7 @@ package mysql import ( + "bytes" "database/sql/driver" "errors" "io" @@ -26,6 +27,9 @@ type mysqlRows struct { binary bool columns []mysqlField eof bool + + // We reuse this buffer when parsing rows + buf *bytes.Buffer } func (rows *mysqlRows) Columns() (columns []string) { diff --git a/statement.go b/statement.go index faa1ad032..c140514c0 100644 --- a/statement.go +++ b/statement.go @@ -79,7 +79,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, err } - rows := &mysqlRows{stmt.mc, true, nil, false} + rows := &mysqlRows{stmt.mc, true, nil, false, nil} if resLen > 0 { // Columns