From 4a1123b1aa45632e471ee7cc066581f231dbee2b Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Thu, 16 Mar 2017 02:09:50 +0800 Subject: [PATCH 1/9] add context support --- .travis.yml | 1 + connection.go | 39 ++++++++++++++++++++++-------- driver.go | 9 +++---- infile.go | 7 +++--- packets.go | 65 +++++++++++++++++++++++++++++++------------------- statement.go | 21 +++++++++++++--- transaction.go | 8 +++++-- 7 files changed, 104 insertions(+), 46 deletions(-) diff --git a/.travis.yml b/.travis.yml index c1cc10aaf..b89d7e11b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ go: - 1.5 - 1.6 - 1.7 + - 1.8 - tip before_script: diff --git a/connection.go b/connection.go index d82c728f3..3aabe8ea3 100644 --- a/connection.go +++ b/connection.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "database/sql/driver" "net" "strconv" @@ -41,7 +42,7 @@ func (mc *mysqlConn) handleParams() (err error) { charsets := strings.Split(val, ",") for i := range charsets { // ignore errors here - a charset may not exist - err = mc.exec("SET NAMES " + charsets[i]) + err = mc.exec(context.Background(), "SET NAMES "+charsets[i]) if err == nil { break } @@ -52,7 +53,7 @@ func (mc *mysqlConn) handleParams() (err error) { // System Vars default: - err = mc.exec("SET " + param + "=" + val + "") + err = mc.exec(context.Background(), "SET "+param+"="+val+"") if err != nil { return } @@ -62,12 +63,18 @@ func (mc *mysqlConn) handleParams() (err error) { return } +// Begin implements driver.Conn interface func (mc *mysqlConn) Begin() (driver.Tx, error) { + return mc.ConnBeginTx(context.Background(), driver.TxOptions{}) +} + +// ConnBeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) ConnBeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + err := mc.exec(ctx, "START TRANSACTION") if err == nil { return &mysqlTx{mc}, err } @@ -78,7 +85,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent if mc.netConn != nil { - err = mc.writeCommandPacket(comQuit) + err = mc.writeCommandPacket(context.Background(), comQuit) } mc.cleanup() @@ -103,12 +110,16 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + return mc.PrepareContext(context.Background(), query) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := mc.writeCommandPacketStr(comStmtPrepare, query) + err := mc.writeCommandPacketStr(ctx, comStmtPrepare, query) if err != nil { return nil, err } @@ -257,6 +268,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return mc.ExecContext(context.Background(), query, args) +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -276,7 +291,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err mc.affectedRows = 0 mc.insertId = 0 - err := mc.exec(query) + err := mc.exec(ctx, query) if err == nil { return &mysqlResult{ affectedRows: int64(mc.affectedRows), @@ -287,9 +302,9 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } // Internal function to execute commands -func (mc *mysqlConn) exec(query string) error { +func (mc *mysqlConn) exec(ctx context.Context, query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(ctx, comQuery, query) if err != nil { return err } @@ -308,6 +323,10 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return mc.QueryContext(context.Background(), query, args) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -325,7 +344,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro args = nil } // Send command - err := mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(ctx, comQuery, query) if err == nil { // Read Result var resLen int @@ -350,7 +369,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command - if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + if err := mc.writeCommandPacketStr(context.Background(), comQuery, "SELECT @@"+name); err != nil { return nil, err } diff --git a/driver.go b/driver.go index 0022d1f1e..fd27e6ce2 100644 --- a/driver.go +++ b/driver.go @@ -17,6 +17,7 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" "net" @@ -95,7 +96,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + if err = mc.writeAuthPacket(context.Background(), cipher); err != nil { mc.cleanup() return nil, err } @@ -157,7 +158,7 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { cipher = oldCipher } - if err = mc.writeOldAuthPacket(cipher); err != nil { + if err = mc.writeOldAuthPacket(context.Background(), cipher); err != nil { return err } _, err = mc.readResultOK() @@ -165,12 +166,12 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { // Retry with clear text password for // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { + if err = mc.writeClearAuthPacket(context.Background()); err != nil { return err } _, err = mc.readResultOK() } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(cipher); err != nil { + if err = mc.writeNativeAuthPacket(context.Background(), cipher); err != nil { return err } _, err = mc.readResultOK() diff --git a/infile.go b/infile.go index 547357cfa..6d90d82a3 100644 --- a/infile.go +++ b/infile.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "fmt" "io" "os" @@ -93,7 +94,7 @@ func deferredClose(err *error, closer io.Closer) { } } -func (mc *mysqlConn) handleInFileRequest(name string) (err error) { +func (mc *mysqlConn) handleInFileRequest(ctx context.Context, name string) (err error) { var rdr io.Reader var data []byte packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP @@ -153,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + if ioErr := mc.writePacket(ctx, data[:4+n]); ioErr != nil { return ioErr } } @@ -167,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if data == nil { data = make([]byte, 4) } - if ioErr := mc.writePacket(data[:4]); ioErr != nil { + if ioErr := mc.writePacket(ctx, data[:4]); ioErr != nil { return ioErr } diff --git a/packets.go b/packets.go index aafe9793e..5b8bd68be 100644 --- a/packets.go +++ b/packets.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "context" "crypto/tls" "database/sql/driver" "encoding/binary" @@ -83,7 +84,15 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // Write packet buffer 'data' -func (mc *mysqlConn) writePacket(data []byte) error { +func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error { + if ctx == nil { + panic("context cannot be nil") + } + ctxDeadline, isCtxDeadlineSet := ctx.Deadline() + if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) { + return errors.New("timeout") + } + pktLen := len(data) - 4 if pktLen > mc.maxAllowedPacket { @@ -106,8 +115,16 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet + var timeNow = time.Now() + var deadline = timeNow if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + deadline = timeNow.Add(mc.writeTimeout) + if isCtxDeadlineSet && deadline.After(ctxDeadline) { + deadline = ctxDeadline + } + } + if deadline.After(timeNow) { + if err := mc.netConn.SetWriteDeadline(deadline); err != nil { return err } } @@ -223,7 +240,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeAuthPacket(ctx context.Context, cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -292,7 +309,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(ctx, data[:(4+4+1+23)+4]); err != nil { return err } @@ -334,12 +351,12 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeOldAuthPacket(ctx context.Context, cipher []byte) error { // User password scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) @@ -356,12 +373,12 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { copy(data[4:], scrambleBuff) data[4+pktLen-1] = 0x00 - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Client clear text authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { +func (mc *mysqlConn) writeClearAuthPacket(ctx context.Context) error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) @@ -375,12 +392,12 @@ func (mc *mysqlConn) writeClearAuthPacket() error { copy(data[4:], mc.cfg.Passwd) data[4+pktLen-1] = 0x00 - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Native password authentication method // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeNativeAuthPacket(ctx context.Context, cipher []byte) error { scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 @@ -395,14 +412,14 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Add the scramble copy(data[4:], scrambleBuff) - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** * Command Packets * ******************************************************************************/ -func (mc *mysqlConn) writeCommandPacket(command byte) error { +func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error { // Reset Packet Sequence mc.sequence = 0 @@ -417,10 +434,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data[4] = command // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { +func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 @@ -439,10 +456,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { +func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 @@ -463,7 +480,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data[8] = byte(arg >> 24) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** @@ -524,7 +541,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { return 0, mc.handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, mc.handleInFileRequest(context.Background(), string(data[1:])) } // column count @@ -822,7 +839,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html -func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { +func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -859,7 +876,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(data[:4+pktLen]) + err := stmt.mc.writePacket(ctx, data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -875,7 +892,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html -func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { +func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "argument count mismatch (got: %d; has: %d)", @@ -1020,7 +1037,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, v); err != nil { + if err := stmt.writeCommandLongData(ctx, i, v); err != nil { return err } } @@ -1042,7 +1059,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + if err := stmt.writeCommandLongData(ctx, i, []byte(v)); err != nil { return err } } @@ -1079,7 +1096,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + return mc.writePacket(ctx, data) } func (mc *mysqlConn) discardResults() error { diff --git a/statement.go b/statement.go index 7f9b04585..9ad1b567f 100644 --- a/statement.go +++ b/statement.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "database/sql/driver" "fmt" "reflect" @@ -22,6 +23,7 @@ type mysqlStmt struct { columns []mysqlField // cached from the first query } +// Close implements driver.Conn interface func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.netConn == nil { // driver.Stmt.Close can be called more than once, thus this function @@ -31,7 +33,7 @@ func (stmt *mysqlStmt) Close() error { return driver.ErrBadConn } - err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(context.Background(), comStmtClose, stmt.id) stmt.mc = nil return err } @@ -44,13 +46,20 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } +// Exec implements driver.Execer and driver.Stmt interface func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.ExecContext(context.Background(), args) +} + +// ExecContent implements driver.ExecerContext interface +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (driver.Result, error) { + if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, err } @@ -84,13 +93,19 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { return nil, err } +// Query implements driver.Queryer and driver.Stmt interface func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.QueryContext(context.Background(), args) +} + +// QueryContext implements driver.QueryerContext interface +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, err } diff --git a/transaction.go b/transaction.go index 33c749b35..c7338c891 100644 --- a/transaction.go +++ b/transaction.go @@ -8,24 +8,28 @@ package mysql +import "context" + type mysqlTx struct { mc *mysqlConn } +// Commit implements driver.Tx interface func (tx *mysqlTx) Commit() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec("COMMIT") + err = tx.mc.exec(context.Background(), "COMMIT") tx.mc = nil return } +// Rollback implements driver.Tx interface func (tx *mysqlTx) Rollback() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec("ROLLBACK") + err = tx.mc.exec(context.Background(), "ROLLBACK") tx.mc = nil return } From c9f6511c81f5d475c059f1e1e3e21b70b87c58c0 Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Thu, 16 Mar 2017 17:17:15 +0800 Subject: [PATCH 2/9] implements driver.Pinger interface; add some comments for other function --- connection.go | 20 ++++++++++++++++++++ driver.go | 2 +- driver_test.go | 9 +++++++++ statement.go | 9 +++++---- 4 files changed, 35 insertions(+), 5 deletions(-) diff --git a/connection.go b/connection.go index 3aabe8ea3..c15547276 100644 --- a/connection.go +++ b/connection.go @@ -68,6 +68,24 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { return mc.ConnBeginTx(context.Background(), driver.TxOptions{}) } +// Ping implements drvier.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) error { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + if err := mc.writeCommandPacket(ctx, comPing); err != nil { + errLog.Print(err) + return err + } + + if _, err := mc.readResultOK(); err != nil { + errLog.Print(err) + return err + } + return nil +} + // ConnBeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) ConnBeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if mc.netConn == nil { @@ -322,10 +340,12 @@ func (mc *mysqlConn) exec(ctx context.Context, query string) error { return err } +// Query implements driver.Queryer interface func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { return mc.QueryContext(context.Background(), query, args) } +// QueryContext implements driver.QueryerContext interface func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) diff --git a/driver.go b/driver.go index fd27e6ce2..b61959f2f 100644 --- a/driver.go +++ b/driver.go @@ -66,7 +66,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + mc.netConn, err = nd.DialContext(context.Background(), mc.cfg.Net, mc.cfg.Addr) } if err != nil { return nil, err diff --git a/driver_test.go b/driver_test.go index 78e68f5d0..94a1224f5 100644 --- a/driver_test.go +++ b/driver_test.go @@ -171,6 +171,15 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + err := dbt.db.Ping() + if err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + func TestEmptyQuery(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // just a comment, no query diff --git a/statement.go b/statement.go index 9ad1b567f..5cdc35d7c 100644 --- a/statement.go +++ b/statement.go @@ -38,6 +38,7 @@ func (stmt *mysqlStmt) Close() error { return err } +// NumInput implements driver.Stmt interface func (stmt *mysqlStmt) NumInput() int { return stmt.paramCount } @@ -46,12 +47,12 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } -// Exec implements driver.Execer and driver.Stmt interface +// Exec implements driver.Stmt interface func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.ExecContext(context.Background(), args) } -// ExecContent implements driver.ExecerContext interface +// ExecContent implements driver.StmtExecContext interface func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (driver.Result, error) { if stmt.mc.netConn == nil { @@ -93,12 +94,12 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (dr return nil, err } -// Query implements driver.Queryer and driver.Stmt interface +// Query implements driver.Stmt interface func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.QueryContext(context.Background(), args) } -// QueryContext implements driver.QueryerContext interface +// QueryContext implements driver.StmtQueryContext interface func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) From 543132c5d33da2142d4cb37e898438eef92ad3a8 Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Thu, 16 Mar 2017 17:42:06 +0800 Subject: [PATCH 3/9] add doc on context support --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 645203850..f0990ae84 100644 --- a/README.md +++ b/README.md @@ -412,6 +412,11 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. +## Context Support +Since go1.8, context is introduced to `database/sql` for better control on timeout and cancellation. +New interfaces such as `driver.QueryerContext`, `driver.ExecerContext` are introduced. See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql"). + +In Go-MySQL-Driver, we implemented these interfaces for structs `mysqlConn`, `mysqlStmt` and `mysqlTx`. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. From 9b7ce2e2ee26c6b12c97213c1f8dce4a8a6b62cc Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Thu, 16 Mar 2017 17:57:04 +0800 Subject: [PATCH 4/9] only apply to 1.8+ --- .travis.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index b89d7e11b..52933012f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,12 +1,6 @@ sudo: false language: go go: - - 1.2 - - 1.3 - - 1.4 - - 1.5 - - 1.6 - - 1.7 - 1.8 - tip From 0f73f560a3854630f03800def451353cd7b4a08c Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Fri, 17 Mar 2017 16:14:44 +0800 Subject: [PATCH 5/9] use build tag for different go versions --- .travis.yml | 6 + connection.go | 2 + connection_deprecated.go | 379 ++++++++ driver.go | 2 + driver_deprecated.go | 186 ++++ driver_deprecated_test.go | 1906 +++++++++++++++++++++++++++++++++++++ driver_test.go | 6 +- infile.go | 2 + infile_deprecated.go | 184 ++++ packets.go | 2 + packets_deprecated.go | 1289 +++++++++++++++++++++++++ statement.go | 2 + statement_deprecated.go | 155 +++ transaction.go | 2 + transaction_deprecated.go | 33 + 15 files changed, 4154 insertions(+), 2 deletions(-) create mode 100644 connection_deprecated.go create mode 100644 driver_deprecated.go create mode 100644 driver_deprecated_test.go create mode 100644 infile_deprecated.go create mode 100644 packets_deprecated.go create mode 100644 statement_deprecated.go create mode 100644 transaction_deprecated.go diff --git a/.travis.yml b/.travis.yml index 52933012f..b89d7e11b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,12 @@ sudo: false language: go go: + - 1.2 + - 1.3 + - 1.4 + - 1.5 + - 1.6 + - 1.7 - 1.8 - tip diff --git a/connection.go b/connection.go index c15547276..3cb2724f2 100644 --- a/connection.go +++ b/connection.go @@ -6,6 +6,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +// +build go1.8 + package mysql import ( diff --git a/connection_deprecated.go b/connection_deprecated.go new file mode 100644 index 000000000..eb8c33a55 --- /dev/null +++ b/connection_deprecated.go @@ -0,0 +1,379 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.8 + +package mysql + +import ( + "database/sql/driver" + "net" + "strconv" + "strings" + "time" +) + +type mysqlConn struct { + buf buffer + netConn net.Conn + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + parseTime bool + strict bool +} + +// Handles parameters set in DSN after the connection is established +func (mc *mysqlConn) handleParams() (err error) { + for param, val := range mc.cfg.Params { + switch param { + // Charset + case "charset": + charsets := strings.Split(val, ",") + for i := range charsets { + // ignore errors here - a charset may not exist + err = mc.exec("SET NAMES " + charsets[i]) + if err == nil { + break + } + } + if err != nil { + return + } + + // System Vars + default: + err = mc.exec("SET " + param + "=" + val + "") + if err != nil { + return + } + } + } + + return +} + +func (mc *mysqlConn) Begin() (driver.Tx, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + err := mc.exec("START TRANSACTION") + if err == nil { + return &mysqlTx{mc}, err + } + + return nil, err +} + +func (mc *mysqlConn) Close() (err error) { + // Makes Close idempotent + if mc.netConn != nil { + err = mc.writeCommandPacket(comQuit) + } + + mc.cleanup() + + return +} + +// Closes the network connection and unsets internal variables. Do not call this +// function after successfully authentication, call Close instead. This function +// is called before auth or on auth failure because MySQL will have already +// closed the network connection. +func (mc *mysqlConn) cleanup() { + // Makes cleanup idempotent + if mc.netConn != nil { + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } + mc.netConn = nil + } + mc.cfg = nil + mc.buf.nc = nil +} + +func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := mc.writeCommandPacketStr(comStmtPrepare, query) + if err != nil { + return nil, err + } + + stmt := &mysqlStmt{ + mc: mc, + } + + // Read Result + columnCount, err := stmt.readPrepareResultPacket() + if err == nil { + if stmt.paramCount > 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if columnCount > 0 { + err = mc.readUntilEOF() + } + } + + return stmt, err +} + +func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { + // Number of ? should be same to len(args) + if strings.Count(query, "?") != len(args) { + return "", driver.ErrSkip + } + + buf := mc.buf.takeCompleteBuffer() + if buf == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return "", driver.ErrBadConn + } + buf = buf[:0] + argPos := 0 + + for i := 0; i < len(query); i++ { + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break + } + buf = append(buf, query[i:i+q]...) + i += q + + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + continue + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + v := v.In(mc.cfg.Loc) + v = v.Add(time.Nanosecond * 500) // To round under microsecond + year := v.Year() + year100 := year / 100 + year1 := year % 100 + month := v.Month() + day := v.Day() + hour := v.Hour() + minute := v.Minute() + second := v.Second() + micro := v.Nanosecond() / 1000 + + buf = append(buf, []byte{ + '\'', + digits10[year100], digits01[year100], + digits10[year1], digits01[year1], + '-', + digits10[month], digits01[month], + '-', + digits10[day], digits01[day], + ' ', + digits10[hour], digits01[hour], + ':', + digits10[minute], digits01[minute], + ':', + digits10[second], digits01[second], + }...) + + if micro != 0 { + micro10000 := micro / 10000 + micro100 := micro / 100 % 100 + micro1 := micro % 100 + buf = append(buf, []byte{ + '.', + digits10[micro10000], digits01[micro10000], + digits10[micro100], digits01[micro100], + digits10[micro1], digits01[micro1], + }...) + } + buf = append(buf, '\'') + } + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) + } + buf = append(buf, '\'') + default: + return "", driver.ErrSkip + } + + if len(buf)+4 > mc.maxAllowedPacket { + return "", driver.ErrSkip + } + } + if argPos != len(args) { + return "", driver.ErrSkip + } + return string(buf), nil +} + +func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + mc.affectedRows = 0 + mc.insertId = 0 + + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, err +} + +// Internal function to execute commands +func (mc *mysqlConn) exec(query string) error { + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err != nil { + return err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil && resLen > 0 { + if err = mc.readUntilEOF(); err != nil { + return err + } + + err = mc.readUntilEOF() + } + + return err +} + +func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce roundtrip + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + args = nil + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + // no columns, no more data + return emptyRows{}, nil + } + // Columns + rows.columns, err = mc.readColumns(resLen) + return rows, err + } + } + return nil, err +} + +// Gets the value of the given MySQL System Variable +// The returned byte slice is only valid until the next read +func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { + // Send command + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + return nil, err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + + if resLen > 0 { + // Columns + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + dest := make([]driver.Value, resLen) + if err = rows.readRow(dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF() + } + } + return nil, err +} diff --git a/driver.go b/driver.go index b61959f2f..ae8bb8708 100644 --- a/driver.go +++ b/driver.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public diff --git a/driver_deprecated.go b/driver_deprecated.go new file mode 100644 index 000000000..63eb546a6 --- /dev/null +++ b/driver_deprecated.go @@ -0,0 +1,186 @@ +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package mysql provides a MySQL driver for Go's database/sql package +// +// The driver should be used via the database/sql package: +// +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" +// +// db, err := sql.Open("mysql", "user:password@/dbname") +// +// See https://github.com/go-sql-driver/mysql#usage for details + +// +build !go1.8 + +package mysql + +import ( + "database/sql" + "database/sql/driver" + "net" +) + +// MySQLDriver is exported to make the driver directly accessible. +// In general the driver is used via the database/sql package. +type MySQLDriver struct{} + +// DialFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDial +type DialFunc func(addr string) (net.Conn, error) + +var dials map[string]DialFunc + +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +func RegisterDial(net string, dial DialFunc) { + if dials == nil { + dials = make(map[string]DialFunc) + } + dials[net] = dial +} + +// Open new Connection. +// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formated +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + } + mc.cfg, err = ParseDSN(dsn) + if err != nil { + return nil, err + } + mc.parseTime = mc.cfg.ParseTime + mc.strict = mc.cfg.Strict + + // Connect to Server + if dial, ok := dials[mc.cfg.Net]; ok { + mc.netConn, err = dial(mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + } + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + cipher, err := mc.readInitPacket() + if err != nil { + mc.cleanup() + return nil, err + } + + // Send Client Authentication Packet + if err = mc.writeAuthPacket(cipher); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = handleAuthResult(mc, cipher); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { + // Read Result Packet + cipher, err := mc.readResultOK() + if err == nil { + return nil // auth successful + } + + if mc.cfg == nil { + return err // auth failed and retry not possible + } + + // Retry auth if configured to do so. + if mc.cfg.AllowOldPasswords && err == ErrOldPassword { + // Retry with old authentication method. Note: there are edge cases + // where this should work but doesn't; this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. + if cipher == nil { + cipher = oldCipher + } + + if err = mc.writeOldAuthPacket(cipher); err != nil { + return err + } + _, err = mc.readResultOK() + } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { + // Retry with clear text password for + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + if err = mc.writeClearAuthPacket(); err != nil { + return err + } + _, err = mc.readResultOK() + } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { + if err = mc.writeNativeAuthPacket(cipher); err != nil { + return err + } + _, err = mc.readResultOK() + } + return err +} + +func init() { + sql.Register("mysql", &MySQLDriver{}) +} diff --git a/driver_deprecated_test.go b/driver_deprecated_test.go new file mode 100644 index 000000000..45b9e7a18 --- /dev/null +++ b/driver_deprecated_test.go @@ -0,0 +1,1906 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.8 + +package mysql + +import ( + "bytes" + "crypto/tls" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/url" + "os" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +var ( + user string + pass string + prot string + addr string + dbname string + dsn string + netAddr string + available bool +) + +var ( + tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) + sDate = "2012-06-14" + tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) + sDateTime = "2011-11-20 21:27:37" + tDate0 = time.Time{} + sDate0 = "0000-00-00" + sDateTime0 = "0000-00-00 00:00:00" +) + +// See https://github.com/go-sql-driver/mysql/wiki/Testing +func init() { + // get environment variables + env := func(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue + } + user = env("MYSQL_TEST_USER", "root") + pass = env("MYSQL_TEST_PASS", "") + prot = env("MYSQL_TEST_PROT", "tcp") + addr = env("MYSQL_TEST_ADDR", "localhost:3306") + dbname = env("MYSQL_TEST_DBNAME", "gotest") + netAddr = fmt.Sprintf("%s(%s)", prot, addr) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname) + c, err := net.Dial(prot, addr) + if err == nil { + available = true + c.Close() + } +} + +type DBTest struct { + *testing.T + db *sql.DB +} + +func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn += "&multiStatements=true" + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } +} + +func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + db.Exec("DROP TABLE IF EXISTS test") + + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db2.Close() + } + + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db3.Close() + } + + dbt := &DBTest{t, db} + dbt2 := &DBTest{t, db2} + dbt3 := &DBTest{t, db3} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } + if db3 != nil { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } + } +} + +func (dbt *DBTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) +} + +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.Exec(query, args...) + if err != nil { + dbt.fail("exec", query, err) + } + return res +} + +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.Query(query, args...) + if err != nil { + dbt.fail("query", query, err) + } + return rows +} + +func TestEmptyQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // just a comment, no query + rows := dbt.mustQuery("--") + // will hang before #255 + if rows.Next() { + dbt.Errorf("next on rows must be false") + } + }) +} + +func TestCRUD(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + + // Test for unexpected data + var out bool + rows := dbt.mustQuery("SELECT * FROM test") + if rows.Next() { + dbt.Error("unexpected data in empty table") + } + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("expected InsertId 0, got %d", id) + } + + // Read + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if true != out { + dbt.Errorf("true != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + + // Update + res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check Update + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if false != out { + dbt.Errorf("false != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + + // Delete + res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check for unexpected rows + res = dbt.mustExec("DELETE FROM test") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 0 { + dbt.Fatalf("expected 0 affected row, got %d", count) + } + }) +} + +func TestMultiQuery(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Update + res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Read + var out int + rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + if rows.Next() { + rows.Scan(&out) + if 5 != out { + dbt.Errorf("5 != %d", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + + }) +} + +func TestInt(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} + in := int64(42) + var out int64 + var rows *sql.Rows + + // SIGNED + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // UNSIGNED ZEROFILL + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s ZEROFILL: no data", v) + } + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat32(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + in := float32(42.23) + var out float32 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64Placeholder(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestString(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} + in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" + var out string + var rows *sql.Rows + + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // BLOB + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + + id := 2 + in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + + "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." + dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + + err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + if err != nil { + dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) + } else if out != in { + dbt.Errorf("BLOB: %s != %s", in, out) + } + }) +} + +type timeTests struct { + dbtype string + tlayout string + tests []timeTest +} + +type timeTest struct { + s string // leading "!": do not use t as value in queries + t time.Time +} + +type timeMode byte + +func (t timeMode) String() string { + switch t { + case binaryString: + return "binary:string" + case binaryTime: + return "binary:time.Time" + case textString: + return "text:string" + } + panic("unsupported timeMode") +} + +func (t timeMode) Binary() bool { + switch t { + case binaryString, binaryTime: + return true + } + return false +} + +const ( + binaryString timeMode = iota + binaryTime + textString +) + +func (t timeTest) genQuery(dbtype string, mode timeMode) string { + var inner string + if mode.Binary() { + inner = "?" + } else { + inner = `"%s"` + } + return `SELECT cast(` + inner + ` as ` + dbtype + `)` +} + +func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { + var rows *sql.Rows + query := t.genQuery(dbtype, mode) + switch mode { + case binaryString: + rows = dbt.mustQuery(query, t.s) + case binaryTime: + rows = dbt.mustQuery(query, t.t) + case textString: + query = fmt.Sprintf(query, t.s) + rows = dbt.mustQuery(query) + default: + panic("unsupported mode") + } + defer rows.Close() + var err error + if !rows.Next() { + err = rows.Err() + if err == nil { + err = fmt.Errorf("no data") + } + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + var dst interface{} + err = rows.Scan(&dst) + if err != nil { + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + switch val := dst.(type) { + case []uint8: + str := string(val) + if str == t.s { + return + } + if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s { + // a fix mainly for TravisCI: + // accept full microsecond resolution in result for DATETIME columns + // where the binary protocol was used + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, str, + ) + case time.Time: + if val == t.t { + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, val.Format(tlayout), + ) + default: + fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) + dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", + dbtype, mode, + val, val, + ) + } +} + +func TestDateTime(t *testing.T) { + afterTime := func(t time.Time, d string) time.Time { + dur, err := time.ParseDuration(d) + if err != nil { + panic(err) + } + return t.Add(dur) + } + // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests + format := "2006-01-02 15:04:05.999999" + t0 := time.Time{} + tstr0 := "0000-00-00 00:00:00.000000" + testcases := []timeTests{ + {"DATE", format[:10], []timeTest{ + {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, + {t: t0, s: tstr0[:10]}, + }}, + {"DATETIME", format[:19], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(0)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(1)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, + {t: t0, s: tstr0[:21]}, + }}, + {"DATETIME(6)", format, []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, + {t: t0, s: tstr0}, + }}, + {"TIME", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(0)", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(1)", format[11:21], []timeTest{ + {t: afterTime(t0, "12345600ms")}, + {s: "!-12:34:56.7"}, + {s: "!-838:59:58.9"}, + {s: "!838:59:58.9"}, + {t: t0, s: tstr0[11:21]}, + }}, + {"TIME(6)", format[11:], []timeTest{ + {t: afterTime(t0, "1234567890123000ns")}, + {s: "!-12:34:56.789012"}, + {s: "!-838:59:58.999999"}, + {s: "!838:59:58.999999"}, + {t: t0, s: tstr0[11:]}, + }}, + } + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + microsecsSupported := false + zeroDateSupported := false + var rows *sql.Rows + var err error + rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) + if err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) + if err == nil { + rows.Scan(&zeroDateSupported) + rows.Close() + } + for _, setups := range testcases { + if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" { + // skip fractional second tests if unsupported by server + continue + } + for _, setup := range setups.tests { + allowBinTime := true + if setup.s == "" { + // fill time string whereever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } else if setup.s[0] == '!' { + // skip tests using setup.t as source in queries + allowBinTime = false + // fix setup.s - remove the "!" + setup.s = setup.s[1:] + } + if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] { + // skip disallowed 0000-00-00 date + continue + } + setup.run(dbt, setups.dbtype, setups.tlayout, textString) + setup.run(dbt, setups.dbtype, setups.tlayout, binaryString) + if allowBinTime { + setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime) + } + } + } + }) + } +} + +func TestTimestampMicros(t *testing.T) { + format := "2006-01-02 15:04:05.999999" + f0 := format[:19] + f1 := format[:21] + f6 := format[:26] + runTests(t, dsn, func(dbt *DBTest) { + // check if microseconds are supported. + // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width + // and not precision. + // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html + microsecsSupported := false + if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + if !microsecsSupported { + // skip test + return + } + _, err := dbt.db.Exec(` + CREATE TABLE test ( + value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', + value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', + value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' + )`, + ) + if err != nil { + dbt.Error(err) + } + defer dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + var res0, res1, res6 string + rows := dbt.mustQuery("SELECT * FROM test") + if !rows.Next() { + dbt.Errorf("test contained no selectable values") + } + err = rows.Scan(&res0, &res1, &res6) + if err != nil { + dbt.Error(err) + } + if res0 != f0 { + dbt.Errorf("expected %q, got %q", f0, res0) + } + if res1 != f1 { + dbt.Errorf("expected %q, got %q", f1, res1) + } + if res6 != f6 { + dbt.Errorf("expected %q, got %q", f6, res6) + } + }) +} + +func TestNULL(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + nullStmt, err := dbt.db.Prepare("SELECT NULL") + if err != nil { + dbt.Fatal(err) + } + defer nullStmt.Close() + + nonNullStmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + defer nonNullStmt.Close() + + // NullBool + var nb sql.NullBool + // Invalid + if err = nullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if nb.Valid { + dbt.Error("valid NullBool which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if !nb.Valid { + dbt.Error("invalid NullBool which should be valid") + } else if nb.Bool != true { + dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) + } + + // NullFloat64 + var nf sql.NullFloat64 + // Invalid + if err = nullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if nf.Valid { + dbt.Error("valid NullFloat64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if !nf.Valid { + dbt.Error("invalid NullFloat64 which should be valid") + } else if nf.Float64 != float64(1) { + dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + } + + // NullInt64 + var ni sql.NullInt64 + // Invalid + if err = nullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if ni.Valid { + dbt.Error("valid NullInt64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if !ni.Valid { + dbt.Error("invalid NullInt64 which should be valid") + } else if ni.Int64 != int64(1) { + dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) + } + + // NullString + var ns sql.NullString + // Invalid + if err = nullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if ns.Valid { + dbt.Error("valid NullString which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if !ns.Valid { + dbt.Error("invalid NullString which should be valid") + } else if ns.String != `1` { + dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") + } + + // nil-bytes + var b []byte + // Read nil + if err = nullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil []byte wich should be nil") + } + // Read non-nil + if err = nonNullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil []byte wich should be non-nil") + } + // Insert nil + b = nil + success := false + if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + dbt.Fatal(err) + } + if !success { + dbt.Error("inserting []byte(nil) as NULL failed") + } + // Check input==output with input==nil + b = nil + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil echo from nil input") + } + // Check input==output with input!=nil + b = []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + + // Insert NULL + dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + + dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + + var out interface{} + rows := dbt.mustQuery("SELECT * FROM test") + if rows.Next() { + rows.Scan(&out) + if out != nil { + dbt.Errorf("%v != nil", out) + } + } else { + dbt.Error("no data") + } + }) +} + +func TestUint64(t *testing.T) { + const ( + u0 = uint64(0) + uall = ^u0 + uhigh = uall >> 1 + utop = ^uhigh + s0 = int64(0) + sall = ^s0 + shigh = int64(uhigh) + stop = ^shigh + ) + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + row := stmt.QueryRow( + u0, uhigh, utop, uall, + s0, shigh, stop, sall, + ) + + var ua, ub, uc, ud uint64 + var sa, sb, sc, sd int64 + + err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd) + if err != nil { + dbt.Fatal(err) + } + switch { + case ua != u0, + ub != uhigh, + uc != utop, + ud != uall, + sa != s0, + sb != shigh, + sc != stop, + sd != sall: + dbt.Fatal("unexpected result value") + } + }) +} + +func TestLongData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + var maxAllowedPacketSize int + err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) + if err != nil { + dbt.Fatal(err) + } + maxAllowedPacketSize-- + + // don't get too ambitious + if maxAllowedPacketSize > 1<<25 { + maxAllowedPacketSize = 1 << 25 + } + + dbt.mustExec("CREATE TABLE test (value LONGBLOB)") + + in := strings.Repeat(`a`, maxAllowedPacketSize+1) + var out string + var rows *sql.Rows + + // Long text data + const nonDataQueryLen = 28 // length query w/o value + inS := in[:maxAllowedPacketSize-nonDataQueryLen] + dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if inS != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + dbt.Fatalf("LONGBLOB: no data") + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Long binary data + dbt.mustExec("INSERT INTO test VALUES(?)", in) + rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + if err = rows.Err(); err != nil { + dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error()) + } else { + dbt.Fatal("LONGBLOB: no data (err: )") + } + } + }) +} + +func TestLoadData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + verifyLoadDataResult := func() { + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + dbt.Fatal(err.Error()) + } + + i := 0 + values := [4]string{ + "a string", + "a string containing a \t", + "a string containing a \n", + "a string containing both \t\n", + } + + var id int + var value string + + for rows.Next() { + i++ + err = rows.Scan(&id, &value) + if err != nil { + dbt.Fatal(err.Error()) + } + if i != id { + dbt.Fatalf("%d != %d", i, id) + } + if values[i-1] != value { + dbt.Fatalf("%q != %q", values[i-1], value) + } + } + err = rows.Err() + if err != nil { + dbt.Fatal(err.Error()) + } + + if i != 4 { + dbt.Fatalf("rows count mismatch. Got %d, want 4", i) + } + } + file, err := ioutil.TempFile("", "gotest") + defer os.Remove(file.Name()) + if err != nil { + dbt.Fatal(err) + } + file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") + file.Close() + + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") + + // Local File + RegisterLocalFile(file.Name()) + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent file didn't fail") + } else if err.Error() != "local file 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Reader + RegisterReaderHandler("test", func() io.Reader { + file, err = os.Open(file.Name()) + if err != nil { + dbt.Fatal(err) + } + return file + }) + dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent Reader didn't fail") + } else if err.Error() != "Reader 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + }) +} + +func TestFoundRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + }) + runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 matched rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 3 { + dbt.Fatalf("Expected 3 matched rows, got %d", count) + } + }) +} + +func TestStrict(t *testing.T) { + // ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors + relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'" + // make sure the MySQL version is recent enough with a separate connection + // before running the test + conn, err := MySQLDriver{}.Open(relaxedDsn) + if conn != nil { + conn.Close() + } + if me, ok := err.(*MySQLError); ok && me.Number == 1231 { + // Error 1231: Variable 'sql_mode' can't be set to the value of 'ALLOW_INVALID_DATES' + // => skip test, MySQL server version is too old + return + } + runTests(t, relaxedDsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))") + + var queries = [...]struct { + in string + codes []string + }{ + {"DROP TABLE IF EXISTS no_such_table", []string{"1051"}}, + {"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}}, + } + var err error + + var checkWarnings = func(err error, mode string, idx int) { + if err == nil { + dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in) + } + + if warnings, ok := err.(MySQLWarnings); ok { + var codes = make([]string, len(warnings)) + for i := range warnings { + codes[i] = warnings[i].Code + } + if len(codes) != len(queries[idx].codes) { + dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) + } + + for i := range warnings { + if codes[i] != queries[idx].codes[i] { + dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) + return + } + } + + } else { + dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error()) + } + } + + // text protocol + for i := range queries { + _, err = dbt.db.Exec(queries[i].in) + checkWarnings(err, "text", i) + } + + var stmt *sql.Stmt + + // binary protocol + for i := range queries { + stmt, err = dbt.db.Prepare(queries[i].in) + if err != nil { + dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error()) + } + + _, err = stmt.Exec() + checkWarnings(err, "binary", i) + + err = stmt.Close() + if err != nil { + dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error()) + } + } + }) +} + +func TestTLS(t *testing.T) { + tlsTest := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + if err == ErrNoTLS { + dbt.Skip("server does not support TLS") + } else { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") + + var variable, value *sql.RawBytes + for rows.Next() { + if err := rows.Scan(&variable, &value); err != nil { + dbt.Fatal(err.Error()) + } + + if value == nil { + dbt.Fatal("no Cipher") + } + } + } + + runTests(t, dsn+"&tls=skip-verify", tlsTest) + + // Verify that registering / using a custom cfg works + RegisterTLSConfig("custom-skip-verify", &tls.Config{ + InsecureSkipVerify: true, + }) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) +} + +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + md := &MySQLDriver{} + conn, err := md.Open(dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("error closing connection: %s", err.Error()) + } + + defer func() { + if err := recover(); err != nil { + t.Errorf("panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != driver.ErrBadConn { + t.Errorf("unexpected error '%s', expected '%s'", + err.Error(), driver.ErrBadConn.Error()) + } +} + +func TestCharset(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() + + if !rows.Next() { + dbt.Fatalf("error getting connection charset: %s", rows.Err()) + } + + var got string + rows.Scan(&got) + + if got != expected { + dbt.Fatalf("expected connection charset %s but got %s", expected, got) + } + }) + } + + // non utf8 test + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8", "utf8") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8", "ascii") + mustSetCharset("charset=utf8,ascii", "utf8") +} + +func TestFailingCharset(t *testing.T) { + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("connection must not succeed without a valid charset") + } + }) +} + +func TestCollation(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + defaultCollation := "utf8_general_ci" + testCollations := []string{ + "", // do not set + defaultCollation, // driver default + "latin1_general_ci", + "binary", + "utf8_unicode_ci", + "cp1257_bin", + } + + for _, collation := range testCollations { + var expected, tdsn string + if collation != "" { + tdsn = dsn + "&collation=" + collation + expected = collation + } else { + tdsn = dsn + expected = defaultCollation + } + + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) + } + + if got != expected { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } + }) + } +} + +func TestColumnsWithAlias(t *testing.T) { + runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1 AS A") + defer rows.Close() + cols, _ := rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A" { + t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) + } + rows.Close() + + rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + cols, _ = rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A.one" { + t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0]) + } + }) +} + +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") + } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + + // Regression test for timezone handling + tzTest := func(dbt *DBTest) { + + // Create table + dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") + + // Insert local time into database (should be converted) + usCentral, _ := time.LoadLocation("US/Central") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + dbt.mustExec("INSERT INTO test VALUE (?)", reftime) + + // Retrieve time from DB + rows := dbt.mustQuery("SELECT ts FROM test") + if !rows.Next() { + dbt.Fatal("did not get any rows out") + } + + var dbTime time.Time + err := rows.Scan(&dbTime) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if reftime.Unix() != dbTime.Unix() { + dbt.Errorf("times do not match.\n") + dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(UTC)=%v\n", dbTime) + } + } + + for _, tz := range zones { + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + } +} + +// Special cases + +func TestRowsClose(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows, err := dbt.db.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + err = rows.Close() + if err != nil { + dbt.Fatal(err) + } + + if rows.Next() { + dbt.Fatal("unexpected row after rows.Close()") + } + + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + }) +} + +// dangling statements +// http://code.google.com/p/go/issues/detail?id=3865 +func TestCloseStmtBeforeRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + rows, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + + err = stmt.Close() + if err != nil { + dbt.Fatal(err) + } + + if !rows.Next() { + dbt.Fatal("getting row failed") + } else { + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + + var out bool + err = rows.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + }) +} + +// It is valid to have multiple Rows for the same Stmt +// http://code.google.com/p/go/issues/detail?id=3734 +func TestStmtMultiRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") + if err != nil { + dbt.Fatal(err) + } + + rows1, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows1.Close() + + rows2, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows2.Close() + + var out bool + + // 1 + if !rows1.Next() { + dbt.Fatal("first rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + if !rows2.Next() { + dbt.Fatal("first rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + // 2 + if !rows1.Next() { + dbt.Fatal("second rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows1.Next() { + dbt.Fatal("unexpected row on rows1") + } + err = rows1.Close() + if err != nil { + dbt.Fatal(err) + } + } + + if !rows2.Next() { + dbt.Fatal("second rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows2.Next() { + dbt.Fatal("unexpected row on rows2") + } + err = rows2.Close() + if err != nil { + dbt.Fatal(err) + } + } + }) +} + +// Regression test for +// * more than 32 NULL parameters (issue 209) +// * more parameters than fit into the buffer (issue 201) +func TestPreparedManyCols(t *testing.T) { + const numParams = defaultBufSize + runTests(t, dsn, func(dbt *DBTest) { + query := "SELECT ?" + strings.Repeat(",?", numParams-1) + stmt, err := dbt.db.Prepare(query) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + // create more parameters than fit into the buffer + // which will take nil-values + params := make([]interface{}, numParams) + rows, err := stmt.Query(params...) + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + }) +} + +func TestConcurrent(t *testing.T) { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { + t.Skip("MYSQL_TEST_CONCURRENT env var not set") + } + + runTests(t, dsn, func(dbt *DBTest) { + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + dbt.Logf("testing up to %d concurrent connections \r\n", max) + + var remaining, succeeded int32 = int32(max), 0 + + var wg sync.WaitGroup + wg.Add(max) + + var fatalError string + var once sync.Once + fatalf := func(s string, vals ...interface{}) { + once.Do(func() { + fatalError = fmt.Sprintf(s, vals...) + }) + } + + for i := 0; i < max; i++ { + go func(id int) { + defer wg.Done() + + tx, err := dbt.db.Begin() + atomic.AddInt32(&remaining, -1) + + if err != nil { + if err.Error() != "Error 1040: Too many connections" { + fatalf("error on conn %d: %s", id, err.Error()) + } + return + } + + // keep the connection busy until all connections are open + for remaining > 0 { + if _, err = tx.Exec("DO 1"); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + } + + if err = tx.Commit(); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + + // everything went fine with this connection + atomic.AddInt32(&succeeded, 1) + }(i) + } + + // wait until all conections are open + wg.Wait() + + if fatalError != "" { + dbt.Fatal(fatalError) + } + + dbt.Logf("reached %d concurrent connections\r\n", succeeded) + }) +} + +// Tests custom dial functions +func TestCustomDial(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // our custom dial function which justs wraps net.Dial here + RegisterDial("mydial", func(addr string) (net.Conn, error) { + return net.Dial(prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } +} + +func TestSQLInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that escapeQuotes and escapeBackslash are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by escapeQuotes and escapeBackslash + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} + +func TestUnixSocketAuthFail(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Save the current logger so we can restore it. + oldLogger := errLog + + // Set a new logger so we can capture its output. + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + newLogger := log.New(buffer, "prefix: ", 0) + SetLogger(newLogger) + + // Restore the logger. + defer SetLogger(oldLogger) + + // Make a new DSN that uses the MySQL socket file and a bad password, which + // we can make by simply appending any character to the real password. + badPass := pass + "x" + socket := "" + if prot == "unix" { + socket = addr + } else { + // Get socket file from MySQL. + err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) + if err != nil { + t.Fatalf("error on SELECT @@socket: %s", err.Error()) + } + } + t.Logf("socket: %s", socket) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname) + db, err := sql.Open("mysql", badDSN) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + // Connect to MySQL for real. This will cause an auth failure. + err = db.Ping() + if err == nil { + t.Error("expected Ping() to return an error") + } + + // The driver should not log anything. + if actual := buffer.String(); actual != "" { + t.Errorf("expected no output, got %q", actual) + } + }) +} + +// See Issue #422 +func TestInterruptBySignal(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + DROP PROCEDURE IF EXISTS test_signal; + CREATE PROCEDURE test_signal(ret INT) + BEGIN + SELECT ret; + SIGNAL SQLSTATE + '45001' + SET + MESSAGE_TEXT = "an error", + MYSQL_ERRNO = 45001; + END + `) + defer dbt.mustExec("DROP PROCEDURE test_signal") + + var val int + + // text protocol + rows, err := dbt.db.Query("CALL test_signal(42)") + if err != nil { + dbt.Fatalf("error on text query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + + // binary protocol + rows, err = dbt.db.Query("CALL test_signal(?)", 42) + if err != nil { + dbt.Fatalf("error on binary query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + }) +} diff --git a/driver_test.go b/driver_test.go index 94a1224f5..94aa37c21 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. @@ -60,10 +62,10 @@ func init() { user = env("MYSQL_TEST_USER", "root") pass = env("MYSQL_TEST_PASS", "") prot = env("MYSQL_TEST_PROT", "tcp") - addr = env("MYSQL_TEST_ADDR", "localhost:3306") + addr = env("MYSQL_TEST_ADDR", "127.0.0.1:3306") dbname = env("MYSQL_TEST_DBNAME", "gotest") netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname) + dsn = fmt.Sprintf("%s@%s/%s?timeout=30s&strict=true", user, netAddr, dbname) c, err := net.Dial(prot, addr) if err == nil { available = true diff --git a/infile.go b/infile.go index 6d90d82a3..bd7bb3abb 100644 --- a/infile.go +++ b/infile.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. diff --git a/infile_deprecated.go b/infile_deprecated.go new file mode 100644 index 000000000..38166dd98 --- /dev/null +++ b/infile_deprecated.go @@ -0,0 +1,184 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.8 + +package mysql + +import ( + "fmt" + "io" + "os" + "strings" + "sync" +) + +var ( + fileRegister map[string]bool + fileRegisterLock sync.RWMutex + readerRegister map[string]func() io.Reader + readerRegisterLock sync.RWMutex +) + +// RegisterLocalFile adds the given file to the file whitelist, +// so that it can be used by "LOAD DATA LOCAL INFILE ". +// Alternatively you can allow the use of all local files with +// the DSN parameter 'allowAllFiles=true' +// +// filePath := "/home/gopher/data.csv" +// mysql.RegisterLocalFile(filePath) +// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterLocalFile(filePath string) { + fileRegisterLock.Lock() + // lazy map init + if fileRegister == nil { + fileRegister = make(map[string]bool) + } + + fileRegister[strings.Trim(filePath, `"`)] = true + fileRegisterLock.Unlock() +} + +// DeregisterLocalFile removes the given filepath from the whitelist. +func DeregisterLocalFile(filePath string) { + fileRegisterLock.Lock() + delete(fileRegister, strings.Trim(filePath, `"`)) + fileRegisterLock.Unlock() +} + +// RegisterReaderHandler registers a handler function which is used +// to receive a io.Reader. +// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::". +// If the handler returns a io.ReadCloser Close() is called when the +// request is finished. +// +// mysql.RegisterReaderHandler("data", func() io.Reader { +// var csvReader io.Reader // Some Reader that returns CSV data +// ... // Open Reader here +// return csvReader +// }) +// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterReaderHandler(name string, handler func() io.Reader) { + readerRegisterLock.Lock() + // lazy map init + if readerRegister == nil { + readerRegister = make(map[string]func() io.Reader) + } + + readerRegister[name] = handler + readerRegisterLock.Unlock() +} + +// DeregisterReaderHandler removes the ReaderHandler function with +// the given name from the registry. +func DeregisterReaderHandler(name string) { + readerRegisterLock.Lock() + delete(readerRegister, name) + readerRegisterLock.Unlock() +} + +func deferredClose(err *error, closer io.Closer) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} + +func (mc *mysqlConn) handleInFileRequest(name string) (err error) { + var rdr io.Reader + var data []byte + packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + if mc.maxWriteSize < packetSize { + packetSize = mc.maxWriteSize + } + + if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader + // The server might return an an absolute path. See issue #355. + name = name[idx+8:] + + readerRegisterLock.RLock() + handler, inMap := readerRegister[name] + readerRegisterLock.RUnlock() + + if inMap { + rdr = handler() + if rdr != nil { + if cl, ok := rdr.(io.Closer); ok { + defer deferredClose(&err, cl) + } + } else { + err = fmt.Errorf("Reader '%s' is ", name) + } + } else { + err = fmt.Errorf("Reader '%s' is not registered", name) + } + } else { // File + name = strings.Trim(name, `"`) + fileRegisterLock.RLock() + fr := fileRegister[name] + fileRegisterLock.RUnlock() + if mc.cfg.AllowAllFiles || fr { + var file *os.File + var fi os.FileInfo + + if file, err = os.Open(name); err == nil { + defer deferredClose(&err, file) + + // get file size + if fi, err = file.Stat(); err == nil { + rdr = file + if fileSize := int(fi.Size()); fileSize < packetSize { + packetSize = fileSize + } + } + } + } else { + err = fmt.Errorf("local file '%s' is not registered", name) + } + } + + // send content packets + if err == nil { + data := make([]byte, 4+packetSize) + var n int + for err == nil { + n, err = rdr.Read(data[4:]) + if n > 0 { + if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + return ioErr + } + } + } + if err == io.EOF { + err = nil + } + } + + // send empty packet (termination) + if data == nil { + data = make([]byte, 4) + } + if ioErr := mc.writePacket(data[:4]); ioErr != nil { + return ioErr + } + + // read OK packet + if err == nil { + _, err = mc.readResultOK() + return err + } + + mc.readPacket() + return err +} diff --git a/packets.go b/packets.go index 5b8bd68be..0bc637e72 100644 --- a/packets.go +++ b/packets.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. diff --git a/packets_deprecated.go b/packets_deprecated.go new file mode 100644 index 000000000..5e6a97db3 --- /dev/null +++ b/packets_deprecated.go @@ -0,0 +1,1289 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.8 + +package mysql + +import ( + "bytes" + "crypto/tls" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "time" +) + +// Packets documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +// Read packet to buffer 'data' +func (mc *mysqlConn) readPacket() ([]byte, error) { + var prevData []byte + for { + // read packet header + data, err := mc.buf.readNext(4) + if err != nil { + errLog.Print(err) + mc.Close() + return nil, driver.ErrBadConn + } + + // packet length [24 bit] + pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + + // check packet sync [8 bit] + if data[3] != mc.sequence { + if data[3] > mc.sequence { + return nil, ErrPktSyncMul + } + return nil, ErrPktSync + } + mc.sequence++ + + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)−1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, driver.ErrBadConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] + data, err = mc.buf.readNext(pktLen) + if err != nil { + errLog.Print(err) + mc.Close() + return nil, driver.ErrBadConn + } + + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } + + return append(prevData, data...), nil + } + + prevData = append(prevData, data...) + } +} + +// Write packet buffer 'data' +func (mc *mysqlConn) writePacket(data []byte) error { + pktLen := len(data) - 4 + + if pktLen > mc.maxAllowedPacket { + return ErrPktTooLarge + } + + for { + var size int + if pktLen >= maxPacketSize { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = maxPacketSize + } else { + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + size = pktLen + } + data[3] = mc.sequence + + // Write packet + if mc.writeTimeout > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + return err + } + } + + n, err := mc.netConn.Write(data[:4+size]) + if err == nil && n == 4+size { + mc.sequence++ + if size != maxPacketSize { + return nil + } + pktLen -= size + data = data[size:] + continue + } + + // Handle error + if err == nil { // n != len(data) + errLog.Print(ErrMalformPkt) + } else { + errLog.Print(err) + } + return driver.ErrBadConn + } +} + +/****************************************************************************** +* Initialisation Process * +******************************************************************************/ + +// Handshake Initialization Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +func (mc *mysqlConn) readInitPacket() ([]byte, error) { + data, err := mc.readPacket() + if err != nil { + return nil, err + } + + if data[0] == iERR { + return nil, mc.handleErrorPacket(data) + } + + // protocol version [1 byte] + if data[0] < minProtocolVersion { + return nil, fmt.Errorf( + "unsupported protocol version %d. Version %d or higher is required", + data[0], + minProtocolVersion, + ) + } + + // server version [null terminated string] + // connection id [4 bytes] + pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 + + // first part of the password cipher [8 bytes] + cipher := data[pos : pos+8] + + // (filler) always 0x00 [1 byte] + pos += 8 + 1 + + // capability flags (lower 2 bytes) [2 bytes] + mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if mc.flags&clientProtocol41 == 0 { + return nil, ErrOldProtocol + } + if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { + return nil, ErrNoTLS + } + pos += 2 + + if len(data) > pos { + // character set [1 byte] + // status flags [2 bytes] + // capability flags (upper 2 bytes) [2 bytes] + // length of auth-plugin-data [1 byte] + // reserved (all [00]) [10 bytes] + pos += 1 + 2 + 2 + 1 + 10 + + // second part of the password cipher [mininum 13 bytes], + // where len=MAX(13, length of auth-plugin-data - 8) + // + // The web documentation is ambiguous about the length. However, + // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, + // the 13th byte is "\0 byte, terminating the second part of + // a scramble". So the second part of the password cipher is + // a NULL terminated string that's at least 13 bytes with the + // last byte being NULL. + // + // The official Python library uses the fixed length 12 + // which seems to work but technically could have a hidden bug. + cipher = append(cipher, data[pos:pos+12]...) + + // TODO: Verify string termination + // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + // \NUL otherwise + // + //if data[len(data)-1] == 0 { + // return + //} + //return ErrMalformPkt + + // make a memory safe copy of the cipher slice + var b [20]byte + copy(b[:], cipher) + return b[:], nil + } + + // make a memory safe copy of the cipher slice + var b [8]byte + copy(b[:], cipher) + return b[:], nil +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { + // Adjust client flags based on server support + clientFlags := clientProtocol41 | + clientSecureConn | + clientLongPassword | + clientTransactions | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + mc.flags&clientLongFlag + + if mc.cfg.ClientFoundRows { + clientFlags |= clientFoundRows + } + + // To enable TLS / SSL + if mc.cfg.tls != nil { + clientFlags |= clientSSL + } + + if mc.cfg.MultiStatements { + clientFlags |= clientMultiStatements + } + + // User Password + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + + // To specify a db name + if n := len(mc.cfg.DBName); n > 0 { + clientFlags |= clientConnectWithDB + pktLen += n + 1 + } + + // Calculate packet length and get buffer with that size + data := mc.buf.takeSmallBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // ClientFlags [32 bit] + data[4] = byte(clientFlags) + data[5] = byte(clientFlags >> 8) + data[6] = byte(clientFlags >> 16) + data[7] = byte(clientFlags >> 24) + + // MaxPacketSize [32 bit] (none) + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 + + // Charset [1 byte] + var found bool + data[12], found = collations[mc.cfg.Collation] + if !found { + // Note possibility for false negatives: + // could be triggered although the collation is valid if the + // collations map does not contain entries the server supports. + return errors.New("unknown collation") + } + + // SSL Connection Request Packet + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if mc.cfg.tls != nil { + // Send TLS / SSL request packet + if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + return err + } + + // Switch to TLS + tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + if err := tlsConn.Handshake(); err != nil { + return err + } + mc.netConn = tlsConn + mc.buf.nc = tlsConn + } + + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + + // User [null terminated string] + if len(mc.cfg.User) > 0 { + pos += copy(data[pos:], mc.cfg.User) + } + data[pos] = 0x00 + pos++ + + // ScrambleBuffer [length encoded integer] + data[pos] = byte(len(scrambleBuff)) + pos += 1 + copy(data[pos+1:], scrambleBuff) + + // Databasename [null terminated string] + if len(mc.cfg.DBName) > 0 { + pos += copy(data[pos:], mc.cfg.DBName) + data[pos] = 0x00 + pos++ + } + + // Assume native client during response + pos += copy(data[pos:], "mysql_native_password") + data[pos] = 0x00 + + // Send Auth packet + return mc.writePacket(data) +} + +// Client old authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { + // User password + scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) + + // Calculate the packet length and add a tailing 0 + pktLen := len(scrambleBuff) + 1 + data := mc.buf.takeSmallBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add the scrambled password [null terminated string] + copy(data[4:], scrambleBuff) + data[4+pktLen-1] = 0x00 + + return mc.writePacket(data) +} + +// Client clear text authentication packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeClearAuthPacket() error { + // Calculate the packet length and add a tailing 0 + pktLen := len(mc.cfg.Passwd) + 1 + data := mc.buf.takeSmallBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add the clear password [null terminated string] + copy(data[4:], mc.cfg.Passwd) + data[4+pktLen-1] = 0x00 + + return mc.writePacket(data) +} + +// Native password authentication method +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + + // Calculate the packet length and add a tailing 0 + pktLen := len(scrambleBuff) + data := mc.buf.takeSmallBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add the scramble + copy(data[4:], scrambleBuff) + + return mc.writePacket(data) +} + +/****************************************************************************** +* Command Packets * +******************************************************************************/ + +func (mc *mysqlConn) writeCommandPacket(command byte) error { + // Reset Packet Sequence + mc.sequence = 0 + + data := mc.buf.takeSmallBuffer(4 + 1) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add command byte + data[4] = command + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { + // Reset Packet Sequence + mc.sequence = 0 + + pktLen := 1 + len(arg) + data := mc.buf.takeBuffer(pktLen + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add command byte + data[4] = command + + // Add arg + copy(data[5:], arg) + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { + // Reset Packet Sequence + mc.sequence = 0 + + data := mc.buf.takeSmallBuffer(4 + 1 + 4) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add command byte + data[4] = command + + // Add arg [32 bit] + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) + + // Send CMD packet + return mc.writePacket(data) +} + +/****************************************************************************** +* Result Packets * +******************************************************************************/ + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() ([]byte, error) { + data, err := mc.readPacket() + if err == nil { + // packet indicator + switch data[0] { + + case iOK: + return nil, mc.handleOkPacket(data) + + case iEOF: + if len(data) > 1 { + pluginEndIndex := bytes.IndexByte(data, 0x00) + plugin := string(data[1:pluginEndIndex]) + cipher := data[pluginEndIndex+1 : len(data)-1] + + if plugin == "mysql_old_password" { + // using old_passwords + return cipher, ErrOldPassword + } else if plugin == "mysql_clear_password" { + // using clear text password + return cipher, ErrCleartextPassword + } else if plugin == "mysql_native_password" { + // using mysql default authentication method + return cipher, ErrNativePassword + } else { + return cipher, ErrUnknownPlugin + } + } else { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, ErrOldPassword + } + + default: // Error otherwise + return nil, mc.handleErrorPacket(data) + } + } + return nil, err +} + +// Result Set Header Packet +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset +func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { + data, err := mc.readPacket() + if err == nil { + switch data[0] { + + case iOK: + return 0, mc.handleOkPacket(data) + + case iERR: + return 0, mc.handleErrorPacket(data) + + case iLocalInFile: + return 0, mc.handleInFileRequest(string(data[1:])) + } + + // column count + num, _, n := readLengthEncodedInteger(data) + if n-len(data) == 0 { + return int(num), nil + } + + return 0, ErrMalformPkt + } + return 0, err +} + +// Error Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet +func (mc *mysqlConn) handleErrorPacket(data []byte) error { + if data[0] != iERR { + return ErrMalformPkt + } + + // 0xff [1 byte] + + // Error Number [16 bit uint] + errno := binary.LittleEndian.Uint16(data[1:3]) + + pos := 3 + + // SQL State [optional: # + 5bytes string] + if data[3] == 0x23 { + //sqlstate := string(data[4 : 4+5]) + pos = 9 + } + + // Error Message [string] + return &MySQLError{ + Number: errno, + Message: string(data[pos:]), + } +} + +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + +// Ok Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +func (mc *mysqlConn) handleOkPacket(data []byte) error { + var n, m int + + // 0x00 [1 byte] + + // Affected rows [Length Coded Binary] + mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + + // Insert id [Length Coded Binary] + mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + + // server_status [2 bytes] + mc.status = readStatus(data[1+n+m : 1+n+m+2]) + if err := mc.discardResults(); err != nil { + return err + } + + // warning count [2 bytes] + if !mc.strict { + return nil + } + + pos := 1 + n + m + 2 + if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { + return mc.getWarnings() + } + return nil +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 +func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { + columns := make([]mysqlField, count) + + for i := 0; ; i++ { + data, err := mc.readPacket() + if err != nil { + return nil, err + } + + // EOF Packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if i == count { + return columns, nil + } + return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) + } + + // Catalog + pos, err := skipLengthEncodedString(data) + if err != nil { + return nil, err + } + + // Database [len coded string] + n, err := skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Table [len coded string] + if mc.cfg.ColumnsWithAlias { + tableName, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + columns[i].tableName = string(tableName) + } else { + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + } + + // Original table [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Name [len coded string] + name, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + columns[i].name = string(name) + pos += n + + // Original name [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + + // Filler [uint8] + // Charset [charset, collation uint8] + // Length [uint32] + pos += n + 1 + 2 + 4 + + // Field type [uint8] + columns[i].fieldType = data[pos] + pos++ + + // Flags [uint16] + columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + pos += 2 + + // Decimals [uint8] + columns[i].decimals = data[pos] + //pos++ + + // Default value [len coded binary] + //if pos < len(data) { + // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) + //} + } +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +func (rows *textRows) readRow(dest []driver.Value) error { + mc := rows.mc + + data, err := mc.readPacket() + if err != nil { + return err + } + + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + err = rows.mc.discardResults() + if err == nil { + err = io.EOF + } else { + // connection unusable + rows.mc.Close() + } + rows.mc = nil + return err + } + if data[0] == iERR { + rows.mc = nil + return mc.handleErrorPacket(data) + } + + // RowSet Packet + var n int + var isNull bool + pos := 0 + + for i := range dest { + // Read bytes and convert to string + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + if !mc.parseTime { + continue + } else { + switch rows.columns[i].fieldType { + case fieldTypeTimestamp, fieldTypeDateTime, + fieldTypeDate, fieldTypeNewDate: + dest[i], err = parseDateTime( + string(dest[i].([]byte)), + mc.cfg.Loc, + ) + if err == nil { + continue + } + default: + continue + } + } + + } else { + dest[i] = nil + continue + } + } + return err // err != nil + } + + return nil +} + +// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read +func (mc *mysqlConn) readUntilEOF() error { + for { + data, err := mc.readPacket() + if err != nil { + return err + } + + switch data[0] { + case iERR: + return mc.handleErrorPacket(data) + case iEOF: + if len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return nil + } + } +} + +/****************************************************************************** +* Prepared Statements * +******************************************************************************/ + +// Prepare Result Packets +// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { + data, err := stmt.mc.readPacket() + if err == nil { + // packet indicator [1 byte] + if data[0] != iOK { + return 0, stmt.mc.handleErrorPacket(data) + } + + // statement id [4 bytes] + stmt.id = binary.LittleEndian.Uint32(data[1:5]) + + // Column count [16 bit uint] + columnCount := binary.LittleEndian.Uint16(data[5:7]) + + // Param count [16 bit uint] + stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) + + // Reserved [8 bit] + + // Warning count [16 bit uint] + if !stmt.mc.strict { + return columnCount, nil + } + + // Check for warnings count > 0, only available in MySQL > 4.1 + if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { + return columnCount, stmt.mc.getWarnings() + } + return columnCount, nil + } + return 0, err +} + +// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { + maxLen := stmt.mc.maxAllowedPacket - 1 + pktLen := maxLen + + // After the header (bytes 0-3) follows before the data: + // 1 byte command + // 4 bytes stmtID + // 2 bytes paramID + const dataOffset = 1 + 4 + 2 + + // Can not use the write buffer since + // a) the buffer is too small + // b) it is in use + data := make([]byte, 4+1+4+2+len(arg)) + + copy(data[4+dataOffset:], arg) + + for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { + if dataOffset+argLen < maxLen { + pktLen = dataOffset + argLen + } + + stmt.mc.sequence = 0 + // Add command byte [1 byte] + data[4] = comStmtSendLongData + + // Add stmtID [32 bit] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // Add paramID [16 bit] + data[9] = byte(paramID) + data[10] = byte(paramID >> 8) + + // Send CMD packet + err := stmt.mc.writePacket(data[:4+pktLen]) + if err == nil { + data = data[pktLen-dataOffset:] + continue + } + return err + + } + + // Reset Packet Sequence + stmt.mc.sequence = 0 + return nil +} + +// Execute Prepared Statement +// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html +func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { + if len(args) != stmt.paramCount { + return fmt.Errorf( + "argument count mismatch (got: %d; has: %d)", + len(args), + stmt.paramCount, + ) + } + + const minPktLen = 4 + 1 + 4 + 1 + 4 + mc := stmt.mc + + // Reset packet-sequence + mc.sequence = 0 + + var data []byte + + if len(args) == 0 { + data = mc.buf.takeBuffer(minPktLen) + } else { + data = mc.buf.takeCompleteBuffer() + } + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // command [1 byte] + data[4] = comStmtExecute + + // statement_id [4 bytes] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] + data[9] = 0x00 + + // iteration_count (uint32(1)) [4 bytes] + data[10] = 0x01 + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 + + if len(args) > 0 { + pos := minPktLen + + var nullMask []byte + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + // buffer has to be extended but we don't know by how much so + // we depend on append after all data with known sizes fit. + // We stop at that because we deal with a lot of columns here + // which makes the required allocation size hard to guess. + tmp := make([]byte, pos+maskLen+typesLen) + copy(tmp[:pos], data[:pos]) + data = tmp + nullMask = data[pos : pos+maskLen] + pos += maskLen + } else { + nullMask = data[pos : pos+maskLen] + for i := 0; i < maskLen; i++ { + nullMask[i] = 0 + } + pos += maskLen + } + + // newParameterBoundFlag 1 [1 byte] + data[pos] = 0x01 + pos++ + + // type of each parameter [len(args)*2 bytes] + paramTypes := data[pos:] + pos += len(args) * 2 + + // value of each parameter [n bytes] + paramValues := data[pos:pos] + valuesCap := cap(paramValues) + + for i, arg := range args { + // build NULL-bitmap + if arg == nil { + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 + continue + } + + // cache types and values + switch v := arg.(type) { + case int64: + paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case float64: + paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + math.Float64bits(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(math.Float64bits(v))..., + ) + } + + case bool: + paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i+1] = 0x00 + + if v { + paramValues = append(paramValues, 0x01) + } else { + paramValues = append(paramValues, 0x00) + } + + case []byte: + // Common case (non-nil value) first + if v != nil { + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } + } + continue + } + + // Handle []byte(nil) as a NULL value + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i+1] = 0x00 + + case string: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + + case time.Time: + paramTypes[i+i] = fieldTypeString + paramTypes[i+i+1] = 0x00 + + var val []byte + if v.IsZero() { + val = []byte("0000-00-00") + } else { + val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) + } + + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(val)), + ) + paramValues = append(paramValues, val...) + + default: + return fmt.Errorf("can not convert type: %T", arg) + } + } + + // Check if param values exceeded the available buffer + // In that case we must build the data packet with the new values buffer + if valuesCap != cap(paramValues) { + data = append(data[:pos], paramValues...) + mc.buf.buf = data + } + + pos += len(paramValues) + data = data[:pos] + } + + return mc.writePacket(data) +} + +func (mc *mysqlConn) discardResults() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } else { + mc.status &^= statusMoreResultsExists + } + } + return nil +} + +// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +func (rows *binaryRows) readRow(dest []driver.Value) error { + data, err := rows.mc.readPacket() + if err != nil { + return err + } + + // packet indicator [1 byte] + if data[0] != iOK { + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + err = rows.mc.discardResults() + if err == nil { + err = io.EOF + } else { + // connection unusable + rows.mc.Close() + } + rows.mc = nil + return err + } + rows.mc = nil + + // Error otherwise + return rows.mc.handleErrorPacket(data) + } + + // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] + pos := 1 + (len(dest)+7+2)>>3 + nullMask := data[1:pos] + + for i := range dest { + // Field is NULL + // (byte >> bit-pos) % 2 == 1 + if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { + dest[i] = nil + continue + } + + // Convert to byte-coded string + switch rows.columns[i].fieldType { + case fieldTypeNULL: + dest[i] = nil + continue + + // Numeric Types + case fieldTypeTiny: + if rows.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(data[pos]) + } else { + dest[i] = int64(int8(data[pos])) + } + pos++ + continue + + case fieldTypeShort, fieldTypeYear: + if rows.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) + } else { + dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) + } + pos += 2 + continue + + case fieldTypeInt24, fieldTypeLong: + if rows.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) + } else { + dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) + } + pos += 4 + continue + + case fieldTypeLongLong: + if rows.columns[i].flags&flagUnsigned != 0 { + val := binary.LittleEndian.Uint64(data[pos : pos+8]) + if val > math.MaxInt64 { + dest[i] = uint64ToString(val) + } else { + dest[i] = int64(val) + } + } else { + dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) + } + pos += 8 + continue + + case fieldTypeFloat: + dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + pos += 4 + continue + + case fieldTypeDouble: + dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) + pos += 8 + continue + + // Length coded Binary Strings + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: + var isNull bool + var n int + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + continue + } else { + dest[i] = nil + continue + } + } + return err + + case + fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD + fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] + fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] + + num, isNull, n := readLengthEncodedInteger(data[pos:]) + pos += n + + switch { + case isNull: + dest[i] = nil + continue + case rows.columns[i].fieldType == fieldTypeTime: + // database/sql does not support an equivalent to TIME, return a string + var dstlen uint8 + switch decimals := rows.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 8 + case 1, 2, 3, 4, 5, 6: + dstlen = 8 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.columns[i].decimals, + ) + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + case rows.mc.parseTime: + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) + default: + var dstlen uint8 + if rows.columns[i].fieldType == fieldTypeDate { + dstlen = 10 + } else { + switch decimals := rows.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 19 + case 1, 2, 3, 4, 5, 6: + dstlen = 19 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.columns[i].decimals, + ) + } + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + } + + if err == nil { + pos += int(num) + continue + } else { + return err + } + + // Please report if this happens! + default: + return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + } + } + + return nil +} diff --git a/statement.go b/statement.go index 5cdc35d7c..3ed483899 100644 --- a/statement.go +++ b/statement.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. diff --git a/statement_deprecated.go b/statement_deprecated.go new file mode 100644 index 000000000..53d18575b --- /dev/null +++ b/statement_deprecated.go @@ -0,0 +1,155 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.8 + +package mysql + +import ( + "database/sql/driver" + "fmt" + "reflect" + "strconv" +) + +type mysqlStmt struct { + mc *mysqlConn + id uint32 + paramCount int + columns []mysqlField // cached from the first query +} + +func (stmt *mysqlStmt) Close() error { + if stmt.mc == nil || stmt.mc.netConn == nil { + // driver.Stmt.Close can be called more than once, thus this function + // has to be idempotent. + // See also Issue #450 and golang/go#16019. + //errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + stmt.mc = nil + return err +} + +func (stmt *mysqlStmt) NumInput() int { + return stmt.paramCount +} + +func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { + return converter{} +} + +func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, err + } + + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil { + if resLen > 0 { + // Columns + err = mc.readUntilEOF() + if err != nil { + return nil, err + } + + // Rows + err = mc.readUntilEOF() + } + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil + } + } + + return nil, err +} + +func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, err + } + + mc := stmt.mc + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + rows := new(binaryRows) + + if resLen > 0 { + rows.mc = mc + // Columns + // If not cached, read them and cache them + if stmt.columns == nil { + rows.columns, err = mc.readColumns(resLen) + stmt.columns = rows.columns + } else { + rows.columns = stmt.columns + err = mc.readUntilEOF() + } + } + + return rows, err +} + +type converter struct{} + +func (c converter) ConvertValue(v interface{}) (driver.Value, error) { + if driver.IsValue(v) { + return v, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } + return c.ConvertValue(rv.Elem().Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return strconv.FormatUint(u64, 10), nil + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} diff --git a/transaction.go b/transaction.go index c7338c891..0f4d4faed 100644 --- a/transaction.go +++ b/transaction.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. diff --git a/transaction_deprecated.go b/transaction_deprecated.go new file mode 100644 index 000000000..95d374d75 --- /dev/null +++ b/transaction_deprecated.go @@ -0,0 +1,33 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !go1.8 + +package mysql + +type mysqlTx struct { + mc *mysqlConn +} + +func (tx *mysqlTx) Commit() (err error) { + if tx.mc == nil || tx.mc.netConn == nil { + return ErrInvalidConn + } + err = tx.mc.exec("COMMIT") + tx.mc = nil + return +} + +func (tx *mysqlTx) Rollback() (err error) { + if tx.mc == nil || tx.mc.netConn == nil { + return ErrInvalidConn + } + err = tx.mc.exec("ROLLBACK") + tx.mc = nil + return +} From 08e3647b4ba6f93e832cd22311a67e5e19c85a7a Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Fri, 17 Mar 2017 16:28:38 +0800 Subject: [PATCH 6/9] rename files to reduce the diff size --- connection.go | 61 +- connection_deprecated.go => connection_ctx.go | 61 +- driver.go | 16 +- driver_deprecated.go => driver_ctx.go | 16 +- driver_deprecated_test.go | 1906 ----------------- driver_test.go | 15 +- infile.go | 11 +- infile_deprecated.go => infile_ctx.go | 11 +- packets.go | 69 +- packets_deprecated.go => packets_ctx.go | 69 +- statement.go | 26 +- statement_deprecated.go => statement_ctx.go | 24 +- transaction.go | 12 +- ...action_deprecated.go => transaction_ctx.go | 12 +- 14 files changed, 196 insertions(+), 2113 deletions(-) rename connection_deprecated.go => connection_ctx.go (80%) rename driver_deprecated.go => driver_ctx.go (92%) delete mode 100644 driver_deprecated_test.go rename infile_deprecated.go => infile_ctx.go (94%) rename packets_deprecated.go => packets_ctx.go (93%) rename statement_deprecated.go => statement_ctx.go (79%) rename transaction_deprecated.go => transaction_ctx.go (75%) diff --git a/connection.go b/connection.go index 3cb2724f2..eb8c33a55 100644 --- a/connection.go +++ b/connection.go @@ -6,12 +6,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build go1.8 +// +build !go1.8 package mysql import ( - "context" "database/sql/driver" "net" "strconv" @@ -44,7 +43,7 @@ func (mc *mysqlConn) handleParams() (err error) { charsets := strings.Split(val, ",") for i := range charsets { // ignore errors here - a charset may not exist - err = mc.exec(context.Background(), "SET NAMES "+charsets[i]) + err = mc.exec("SET NAMES " + charsets[i]) if err == nil { break } @@ -55,7 +54,7 @@ func (mc *mysqlConn) handleParams() (err error) { // System Vars default: - err = mc.exec(context.Background(), "SET "+param+"="+val+"") + err = mc.exec("SET " + param + "=" + val + "") if err != nil { return } @@ -65,36 +64,12 @@ func (mc *mysqlConn) handleParams() (err error) { return } -// Begin implements driver.Conn interface func (mc *mysqlConn) Begin() (driver.Tx, error) { - return mc.ConnBeginTx(context.Background(), driver.TxOptions{}) -} - -// Ping implements drvier.Pinger interface -func (mc *mysqlConn) Ping(ctx context.Context) error { - if mc.netConn == nil { - errLog.Print(ErrInvalidConn) - return driver.ErrBadConn - } - if err := mc.writeCommandPacket(ctx, comPing); err != nil { - errLog.Print(err) - return err - } - - if _, err := mc.readResultOK(); err != nil { - errLog.Print(err) - return err - } - return nil -} - -// ConnBeginTx implements driver.ConnBeginTx interface -func (mc *mysqlConn) ConnBeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec(ctx, "START TRANSACTION") + err := mc.exec("START TRANSACTION") if err == nil { return &mysqlTx{mc}, err } @@ -105,7 +80,7 @@ func (mc *mysqlConn) ConnBeginTx(ctx context.Context, opts driver.TxOptions) (dr func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent if mc.netConn != nil { - err = mc.writeCommandPacket(context.Background(), comQuit) + err = mc.writeCommandPacket(comQuit) } mc.cleanup() @@ -130,16 +105,12 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - return mc.PrepareContext(context.Background(), query) -} - -func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := mc.writeCommandPacketStr(ctx, comStmtPrepare, query) + err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { return nil, err } @@ -288,10 +259,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return mc.ExecContext(context.Background(), query, args) -} - -func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -311,7 +278,7 @@ func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []drive mc.affectedRows = 0 mc.insertId = 0 - err := mc.exec(ctx, query) + err := mc.exec(query) if err == nil { return &mysqlResult{ affectedRows: int64(mc.affectedRows), @@ -322,9 +289,9 @@ func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []drive } // Internal function to execute commands -func (mc *mysqlConn) exec(ctx context.Context, query string) error { +func (mc *mysqlConn) exec(query string) error { // Send command - err := mc.writeCommandPacketStr(ctx, comQuery, query) + err := mc.writeCommandPacketStr(comQuery, query) if err != nil { return err } @@ -342,13 +309,7 @@ func (mc *mysqlConn) exec(ctx context.Context, query string) error { return err } -// Query implements driver.Queryer interface func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return mc.QueryContext(context.Background(), query, args) -} - -// QueryContext implements driver.QueryerContext interface -func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -366,7 +327,7 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv args = nil } // Send command - err := mc.writeCommandPacketStr(ctx, comQuery, query) + err := mc.writeCommandPacketStr(comQuery, query) if err == nil { // Read Result var resLen int @@ -391,7 +352,7 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command - if err := mc.writeCommandPacketStr(context.Background(), comQuery, "SELECT @@"+name); err != nil { + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { return nil, err } diff --git a/connection_deprecated.go b/connection_ctx.go similarity index 80% rename from connection_deprecated.go rename to connection_ctx.go index eb8c33a55..3cb2724f2 100644 --- a/connection_deprecated.go +++ b/connection_ctx.go @@ -6,11 +6,12 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 +// +build go1.8 package mysql import ( + "context" "database/sql/driver" "net" "strconv" @@ -43,7 +44,7 @@ func (mc *mysqlConn) handleParams() (err error) { charsets := strings.Split(val, ",") for i := range charsets { // ignore errors here - a charset may not exist - err = mc.exec("SET NAMES " + charsets[i]) + err = mc.exec(context.Background(), "SET NAMES "+charsets[i]) if err == nil { break } @@ -54,7 +55,7 @@ func (mc *mysqlConn) handleParams() (err error) { // System Vars default: - err = mc.exec("SET " + param + "=" + val + "") + err = mc.exec(context.Background(), "SET "+param+"="+val+"") if err != nil { return } @@ -64,12 +65,36 @@ func (mc *mysqlConn) handleParams() (err error) { return } +// Begin implements driver.Conn interface func (mc *mysqlConn) Begin() (driver.Tx, error) { + return mc.ConnBeginTx(context.Background(), driver.TxOptions{}) +} + +// Ping implements drvier.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) error { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + if err := mc.writeCommandPacket(ctx, comPing); err != nil { + errLog.Print(err) + return err + } + + if _, err := mc.readResultOK(); err != nil { + errLog.Print(err) + return err + } + return nil +} + +// ConnBeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) ConnBeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + err := mc.exec(ctx, "START TRANSACTION") if err == nil { return &mysqlTx{mc}, err } @@ -80,7 +105,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent if mc.netConn != nil { - err = mc.writeCommandPacket(comQuit) + err = mc.writeCommandPacket(context.Background(), comQuit) } mc.cleanup() @@ -105,12 +130,16 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + return mc.PrepareContext(context.Background(), query) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := mc.writeCommandPacketStr(comStmtPrepare, query) + err := mc.writeCommandPacketStr(ctx, comStmtPrepare, query) if err != nil { return nil, err } @@ -259,6 +288,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return mc.ExecContext(context.Background(), query, args) +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -278,7 +311,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err mc.affectedRows = 0 mc.insertId = 0 - err := mc.exec(query) + err := mc.exec(ctx, query) if err == nil { return &mysqlResult{ affectedRows: int64(mc.affectedRows), @@ -289,9 +322,9 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } // Internal function to execute commands -func (mc *mysqlConn) exec(query string) error { +func (mc *mysqlConn) exec(ctx context.Context, query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(ctx, comQuery, query) if err != nil { return err } @@ -309,7 +342,13 @@ func (mc *mysqlConn) exec(query string) error { return err } +// Query implements driver.Queryer interface func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return mc.QueryContext(context.Background(), query, args) +} + +// QueryContext implements driver.QueryerContext interface +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -327,7 +366,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro args = nil } // Send command - err := mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(ctx, comQuery, query) if err == nil { // Read Result var resLen int @@ -352,7 +391,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command - if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + if err := mc.writeCommandPacketStr(context.Background(), comQuery, "SELECT @@"+name); err != nil { return nil, err } diff --git a/driver.go b/driver.go index ae8bb8708..63eb546a6 100644 --- a/driver.go +++ b/driver.go @@ -1,5 +1,3 @@ -// +build go1.8 - // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public @@ -16,10 +14,12 @@ // db, err := sql.Open("mysql", "user:password@/dbname") // // See https://github.com/go-sql-driver/mysql#usage for details + +// +build !go1.8 + package mysql import ( - "context" "database/sql" "database/sql/driver" "net" @@ -68,7 +68,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.DialContext(context.Background(), mc.cfg.Net, mc.cfg.Addr) + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { return nil, err @@ -98,7 +98,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Send Client Authentication Packet - if err = mc.writeAuthPacket(context.Background(), cipher); err != nil { + if err = mc.writeAuthPacket(cipher); err != nil { mc.cleanup() return nil, err } @@ -160,7 +160,7 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { cipher = oldCipher } - if err = mc.writeOldAuthPacket(context.Background(), cipher); err != nil { + if err = mc.writeOldAuthPacket(cipher); err != nil { return err } _, err = mc.readResultOK() @@ -168,12 +168,12 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { // Retry with clear text password for // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(context.Background()); err != nil { + if err = mc.writeClearAuthPacket(); err != nil { return err } _, err = mc.readResultOK() } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(context.Background(), cipher); err != nil { + if err = mc.writeNativeAuthPacket(cipher); err != nil { return err } _, err = mc.readResultOK() diff --git a/driver_deprecated.go b/driver_ctx.go similarity index 92% rename from driver_deprecated.go rename to driver_ctx.go index 63eb546a6..ae8bb8708 100644 --- a/driver_deprecated.go +++ b/driver_ctx.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public @@ -14,12 +16,10 @@ // db, err := sql.Open("mysql", "user:password@/dbname") // // See https://github.com/go-sql-driver/mysql#usage for details - -// +build !go1.8 - package mysql import ( + "context" "database/sql" "database/sql/driver" "net" @@ -68,7 +68,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + mc.netConn, err = nd.DialContext(context.Background(), mc.cfg.Net, mc.cfg.Addr) } if err != nil { return nil, err @@ -98,7 +98,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + if err = mc.writeAuthPacket(context.Background(), cipher); err != nil { mc.cleanup() return nil, err } @@ -160,7 +160,7 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { cipher = oldCipher } - if err = mc.writeOldAuthPacket(cipher); err != nil { + if err = mc.writeOldAuthPacket(context.Background(), cipher); err != nil { return err } _, err = mc.readResultOK() @@ -168,12 +168,12 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { // Retry with clear text password for // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { + if err = mc.writeClearAuthPacket(context.Background()); err != nil { return err } _, err = mc.readResultOK() } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(cipher); err != nil { + if err = mc.writeNativeAuthPacket(context.Background(), cipher); err != nil { return err } _, err = mc.readResultOK() diff --git a/driver_deprecated_test.go b/driver_deprecated_test.go deleted file mode 100644 index 45b9e7a18..000000000 --- a/driver_deprecated_test.go +++ /dev/null @@ -1,1906 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build !go1.8 - -package mysql - -import ( - "bytes" - "crypto/tls" - "database/sql" - "database/sql/driver" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "net/url" - "os" - "strings" - "sync" - "sync/atomic" - "testing" - "time" -) - -var ( - user string - pass string - prot string - addr string - dbname string - dsn string - netAddr string - available bool -) - -var ( - tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) - sDate = "2012-06-14" - tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) - sDateTime = "2011-11-20 21:27:37" - tDate0 = time.Time{} - sDate0 = "0000-00-00" - sDateTime0 = "0000-00-00 00:00:00" -) - -// See https://github.com/go-sql-driver/mysql/wiki/Testing -func init() { - // get environment variables - env := func(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue - } - user = env("MYSQL_TEST_USER", "root") - pass = env("MYSQL_TEST_PASS", "") - prot = env("MYSQL_TEST_PROT", "tcp") - addr = env("MYSQL_TEST_ADDR", "localhost:3306") - dbname = env("MYSQL_TEST_DBNAME", "gotest") - netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname) - c, err := net.Dial(prot, addr) - if err == nil { - available = true - c.Close() - } -} - -type DBTest struct { - *testing.T - db *sql.DB -} - -func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - dsn += "&multiStatements=true" - var db *sql.DB - if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db.Close() - } - - dbt := &DBTest{t, db} - for _, test := range tests { - test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") - } -} - -func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db.Close() - - db.Exec("DROP TABLE IF EXISTS test") - - dsn2 := dsn + "&interpolateParams=true" - var db2 *sql.DB - if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { - db2, err = sql.Open("mysql", dsn2) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db2.Close() - } - - dsn3 := dsn + "&multiStatements=true" - var db3 *sql.DB - if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { - db3, err = sql.Open("mysql", dsn3) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db3.Close() - } - - dbt := &DBTest{t, db} - dbt2 := &DBTest{t, db2} - dbt3 := &DBTest{t, db3} - for _, test := range tests { - test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") - if db2 != nil { - test(dbt2) - dbt2.db.Exec("DROP TABLE IF EXISTS test") - } - if db3 != nil { - test(dbt3) - dbt3.db.Exec("DROP TABLE IF EXISTS test") - } - } -} - -func (dbt *DBTest) fail(method, query string, err error) { - if len(query) > 300 { - query = "[query too large to print]" - } - dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) -} - -func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.Exec(query, args...) - if err != nil { - dbt.fail("exec", query, err) - } - return res -} - -func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { - rows, err := dbt.db.Query(query, args...) - if err != nil { - dbt.fail("query", query, err) - } - return rows -} - -func TestEmptyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - // just a comment, no query - rows := dbt.mustQuery("--") - // will hang before #255 - if rows.Next() { - dbt.Errorf("next on rows must be false") - } - }) -} - -func TestCRUD(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - // Create Table - dbt.mustExec("CREATE TABLE test (value BOOL)") - - // Test for unexpected data - var out bool - rows := dbt.mustQuery("SELECT * FROM test") - if rows.Next() { - dbt.Error("unexpected data in empty table") - } - - // Create Data - res := dbt.mustExec("INSERT INTO test VALUES (1)") - count, err := res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 1 { - dbt.Fatalf("expected 1 affected row, got %d", count) - } - - id, err := res.LastInsertId() - if err != nil { - dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) - } - if id != 0 { - dbt.Fatalf("expected InsertId 0, got %d", id) - } - - // Read - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if true != out { - dbt.Errorf("true != %t", out) - } - - if rows.Next() { - dbt.Error("unexpected data") - } - } else { - dbt.Error("no data") - } - - // Update - res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 1 { - dbt.Fatalf("expected 1 affected row, got %d", count) - } - - // Check Update - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if false != out { - dbt.Errorf("false != %t", out) - } - - if rows.Next() { - dbt.Error("unexpected data") - } - } else { - dbt.Error("no data") - } - - // Delete - res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 1 { - dbt.Fatalf("expected 1 affected row, got %d", count) - } - - // Check for unexpected rows - res = dbt.mustExec("DELETE FROM test") - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 0 { - dbt.Fatalf("expected 0 affected row, got %d", count) - } - }) -} - -func TestMultiQuery(t *testing.T) { - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { - // Create Table - dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") - - // Create Data - res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") - count, err := res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 1 { - dbt.Fatalf("expected 1 affected row, got %d", count) - } - - // Update - res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 1 { - dbt.Fatalf("expected 1 affected row, got %d", count) - } - - // Read - var out int - rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") - if rows.Next() { - rows.Scan(&out) - if 5 != out { - dbt.Errorf("5 != %d", out) - } - - if rows.Next() { - dbt.Error("unexpected data") - } - } else { - dbt.Error("no data") - } - - }) -} - -func TestInt(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} - in := int64(42) - var out int64 - var rows *sql.Rows - - // SIGNED - for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - - dbt.mustExec("INSERT INTO test VALUES (?)", in) - - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %d != %d", v, in, out) - } - } else { - dbt.Errorf("%s: no data", v) - } - - dbt.mustExec("DROP TABLE IF EXISTS test") - } - - // UNSIGNED ZEROFILL - for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") - - dbt.mustExec("INSERT INTO test VALUES (?)", in) - - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) - } - } else { - dbt.Errorf("%s ZEROFILL: no data", v) - } - - dbt.mustExec("DROP TABLE IF EXISTS test") - } - }) -} - -func TestFloat32(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := [2]string{"FLOAT", "DOUBLE"} - in := float32(42.23) - var out float32 - var rows *sql.Rows - for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %g != %g", v, in, out) - } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") - } - }) -} - -func TestFloat64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := [2]string{"FLOAT", "DOUBLE"} - var expected float64 = 42.23 - var out float64 - var rows *sql.Rows - for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (42.23)") - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if expected != out { - dbt.Errorf("%s: %g != %g", v, expected, out) - } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") - } - }) -} - -func TestFloat64Placeholder(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := [2]string{"FLOAT", "DOUBLE"} - var expected float64 = 42.23 - var out float64 - var rows *sql.Rows - for _, v := range types { - dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") - rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) - if rows.Next() { - rows.Scan(&out) - if expected != out { - dbt.Errorf("%s: %g != %g", v, expected, out) - } - } else { - dbt.Errorf("%s: no data", v) - } - dbt.mustExec("DROP TABLE IF EXISTS test") - } - }) -} - -func TestString(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} - in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" - var out string - var rows *sql.Rows - - for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") - - dbt.mustExec("INSERT INTO test VALUES (?)", in) - - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Errorf("%s: %s != %s", v, in, out) - } - } else { - dbt.Errorf("%s: no data", v) - } - - dbt.mustExec("DROP TABLE IF EXISTS test") - } - - // BLOB - dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") - - id := 2 - in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + - "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + - "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + - "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + - "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + - "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + - "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + - "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." - dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) - - err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) - if err != nil { - dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) - } else if out != in { - dbt.Errorf("BLOB: %s != %s", in, out) - } - }) -} - -type timeTests struct { - dbtype string - tlayout string - tests []timeTest -} - -type timeTest struct { - s string // leading "!": do not use t as value in queries - t time.Time -} - -type timeMode byte - -func (t timeMode) String() string { - switch t { - case binaryString: - return "binary:string" - case binaryTime: - return "binary:time.Time" - case textString: - return "text:string" - } - panic("unsupported timeMode") -} - -func (t timeMode) Binary() bool { - switch t { - case binaryString, binaryTime: - return true - } - return false -} - -const ( - binaryString timeMode = iota - binaryTime - textString -) - -func (t timeTest) genQuery(dbtype string, mode timeMode) string { - var inner string - if mode.Binary() { - inner = "?" - } else { - inner = `"%s"` - } - return `SELECT cast(` + inner + ` as ` + dbtype + `)` -} - -func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { - var rows *sql.Rows - query := t.genQuery(dbtype, mode) - switch mode { - case binaryString: - rows = dbt.mustQuery(query, t.s) - case binaryTime: - rows = dbt.mustQuery(query, t.t) - case textString: - query = fmt.Sprintf(query, t.s) - rows = dbt.mustQuery(query) - default: - panic("unsupported mode") - } - defer rows.Close() - var err error - if !rows.Next() { - err = rows.Err() - if err == nil { - err = fmt.Errorf("no data") - } - dbt.Errorf("%s [%s]: %s", dbtype, mode, err) - return - } - var dst interface{} - err = rows.Scan(&dst) - if err != nil { - dbt.Errorf("%s [%s]: %s", dbtype, mode, err) - return - } - switch val := dst.(type) { - case []uint8: - str := string(val) - if str == t.s { - return - } - if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s { - // a fix mainly for TravisCI: - // accept full microsecond resolution in result for DATETIME columns - // where the binary protocol was used - return - } - dbt.Errorf("%s [%s] to string: expected %q, got %q", - dbtype, mode, - t.s, str, - ) - case time.Time: - if val == t.t { - return - } - dbt.Errorf("%s [%s] to string: expected %q, got %q", - dbtype, mode, - t.s, val.Format(tlayout), - ) - default: - fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) - dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", - dbtype, mode, - val, val, - ) - } -} - -func TestDateTime(t *testing.T) { - afterTime := func(t time.Time, d string) time.Time { - dur, err := time.ParseDuration(d) - if err != nil { - panic(err) - } - return t.Add(dur) - } - // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests - format := "2006-01-02 15:04:05.999999" - t0 := time.Time{} - tstr0 := "0000-00-00 00:00:00.000000" - testcases := []timeTests{ - {"DATE", format[:10], []timeTest{ - {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, - {t: t0, s: tstr0[:10]}, - }}, - {"DATETIME", format[:19], []timeTest{ - {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, - {t: t0, s: tstr0[:19]}, - }}, - {"DATETIME(0)", format[:21], []timeTest{ - {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, - {t: t0, s: tstr0[:19]}, - }}, - {"DATETIME(1)", format[:21], []timeTest{ - {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, - {t: t0, s: tstr0[:21]}, - }}, - {"DATETIME(6)", format, []timeTest{ - {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, - {t: t0, s: tstr0}, - }}, - {"TIME", format[11:19], []timeTest{ - {t: afterTime(t0, "12345s")}, - {s: "!-12:34:56"}, - {s: "!-838:59:59"}, - {s: "!838:59:59"}, - {t: t0, s: tstr0[11:19]}, - }}, - {"TIME(0)", format[11:19], []timeTest{ - {t: afterTime(t0, "12345s")}, - {s: "!-12:34:56"}, - {s: "!-838:59:59"}, - {s: "!838:59:59"}, - {t: t0, s: tstr0[11:19]}, - }}, - {"TIME(1)", format[11:21], []timeTest{ - {t: afterTime(t0, "12345600ms")}, - {s: "!-12:34:56.7"}, - {s: "!-838:59:58.9"}, - {s: "!838:59:58.9"}, - {t: t0, s: tstr0[11:21]}, - }}, - {"TIME(6)", format[11:], []timeTest{ - {t: afterTime(t0, "1234567890123000ns")}, - {s: "!-12:34:56.789012"}, - {s: "!-838:59:58.999999"}, - {s: "!838:59:58.999999"}, - {t: t0, s: tstr0[11:]}, - }}, - } - dsns := []string{ - dsn + "&parseTime=true", - dsn + "&parseTime=false", - } - for _, testdsn := range dsns { - runTests(t, testdsn, func(dbt *DBTest) { - microsecsSupported := false - zeroDateSupported := false - var rows *sql.Rows - var err error - rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) - if err == nil { - rows.Scan(µsecsSupported) - rows.Close() - } - rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) - if err == nil { - rows.Scan(&zeroDateSupported) - rows.Close() - } - for _, setups := range testcases { - if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" { - // skip fractional second tests if unsupported by server - continue - } - for _, setup := range setups.tests { - allowBinTime := true - if setup.s == "" { - // fill time string whereever Go can reliable produce it - setup.s = setup.t.Format(setups.tlayout) - } else if setup.s[0] == '!' { - // skip tests using setup.t as source in queries - allowBinTime = false - // fix setup.s - remove the "!" - setup.s = setup.s[1:] - } - if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] { - // skip disallowed 0000-00-00 date - continue - } - setup.run(dbt, setups.dbtype, setups.tlayout, textString) - setup.run(dbt, setups.dbtype, setups.tlayout, binaryString) - if allowBinTime { - setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime) - } - } - } - }) - } -} - -func TestTimestampMicros(t *testing.T) { - format := "2006-01-02 15:04:05.999999" - f0 := format[:19] - f1 := format[:21] - f6 := format[:26] - runTests(t, dsn, func(dbt *DBTest) { - // check if microseconds are supported. - // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width - // and not precision. - // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html - microsecsSupported := false - if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { - rows.Scan(µsecsSupported) - rows.Close() - } - if !microsecsSupported { - // skip test - return - } - _, err := dbt.db.Exec(` - CREATE TABLE test ( - value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', - value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', - value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' - )`, - ) - if err != nil { - dbt.Error(err) - } - defer dbt.mustExec("DROP TABLE IF EXISTS test") - dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) - var res0, res1, res6 string - rows := dbt.mustQuery("SELECT * FROM test") - if !rows.Next() { - dbt.Errorf("test contained no selectable values") - } - err = rows.Scan(&res0, &res1, &res6) - if err != nil { - dbt.Error(err) - } - if res0 != f0 { - dbt.Errorf("expected %q, got %q", f0, res0) - } - if res1 != f1 { - dbt.Errorf("expected %q, got %q", f1, res1) - } - if res6 != f6 { - dbt.Errorf("expected %q, got %q", f6, res6) - } - }) -} - -func TestNULL(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - nullStmt, err := dbt.db.Prepare("SELECT NULL") - if err != nil { - dbt.Fatal(err) - } - defer nullStmt.Close() - - nonNullStmt, err := dbt.db.Prepare("SELECT 1") - if err != nil { - dbt.Fatal(err) - } - defer nonNullStmt.Close() - - // NullBool - var nb sql.NullBool - // Invalid - if err = nullStmt.QueryRow().Scan(&nb); err != nil { - dbt.Fatal(err) - } - if nb.Valid { - dbt.Error("valid NullBool which should be invalid") - } - // Valid - if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { - dbt.Fatal(err) - } - if !nb.Valid { - dbt.Error("invalid NullBool which should be valid") - } else if nb.Bool != true { - dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) - } - - // NullFloat64 - var nf sql.NullFloat64 - // Invalid - if err = nullStmt.QueryRow().Scan(&nf); err != nil { - dbt.Fatal(err) - } - if nf.Valid { - dbt.Error("valid NullFloat64 which should be invalid") - } - // Valid - if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { - dbt.Fatal(err) - } - if !nf.Valid { - dbt.Error("invalid NullFloat64 which should be valid") - } else if nf.Float64 != float64(1) { - dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) - } - - // NullInt64 - var ni sql.NullInt64 - // Invalid - if err = nullStmt.QueryRow().Scan(&ni); err != nil { - dbt.Fatal(err) - } - if ni.Valid { - dbt.Error("valid NullInt64 which should be invalid") - } - // Valid - if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { - dbt.Fatal(err) - } - if !ni.Valid { - dbt.Error("invalid NullInt64 which should be valid") - } else if ni.Int64 != int64(1) { - dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) - } - - // NullString - var ns sql.NullString - // Invalid - if err = nullStmt.QueryRow().Scan(&ns); err != nil { - dbt.Fatal(err) - } - if ns.Valid { - dbt.Error("valid NullString which should be invalid") - } - // Valid - if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { - dbt.Fatal(err) - } - if !ns.Valid { - dbt.Error("invalid NullString which should be valid") - } else if ns.String != `1` { - dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") - } - - // nil-bytes - var b []byte - // Read nil - if err = nullStmt.QueryRow().Scan(&b); err != nil { - dbt.Fatal(err) - } - if b != nil { - dbt.Error("non-nil []byte wich should be nil") - } - // Read non-nil - if err = nonNullStmt.QueryRow().Scan(&b); err != nil { - dbt.Fatal(err) - } - if b == nil { - dbt.Error("nil []byte wich should be non-nil") - } - // Insert nil - b = nil - success := false - if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { - dbt.Fatal(err) - } - if !success { - dbt.Error("inserting []byte(nil) as NULL failed") - } - // Check input==output with input==nil - b = nil - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { - dbt.Fatal(err) - } - if b != nil { - dbt.Error("non-nil echo from nil input") - } - // Check input==output with input!=nil - b = []byte("") - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { - dbt.Fatal(err) - } - if b == nil { - dbt.Error("nil echo from non-nil input") - } - - // Insert NULL - dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") - - dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) - - var out interface{} - rows := dbt.mustQuery("SELECT * FROM test") - if rows.Next() { - rows.Scan(&out) - if out != nil { - dbt.Errorf("%v != nil", out) - } - } else { - dbt.Error("no data") - } - }) -} - -func TestUint64(t *testing.T) { - const ( - u0 = uint64(0) - uall = ^u0 - uhigh = uall >> 1 - utop = ^uhigh - s0 = int64(0) - sall = ^s0 - shigh = int64(uhigh) - stop = ^shigh - ) - runTests(t, dsn, func(dbt *DBTest) { - stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) - if err != nil { - dbt.Fatal(err) - } - defer stmt.Close() - row := stmt.QueryRow( - u0, uhigh, utop, uall, - s0, shigh, stop, sall, - ) - - var ua, ub, uc, ud uint64 - var sa, sb, sc, sd int64 - - err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd) - if err != nil { - dbt.Fatal(err) - } - switch { - case ua != u0, - ub != uhigh, - uc != utop, - ud != uall, - sa != s0, - sb != shigh, - sc != stop, - sd != sall: - dbt.Fatal("unexpected result value") - } - }) -} - -func TestLongData(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - var maxAllowedPacketSize int - err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) - if err != nil { - dbt.Fatal(err) - } - maxAllowedPacketSize-- - - // don't get too ambitious - if maxAllowedPacketSize > 1<<25 { - maxAllowedPacketSize = 1 << 25 - } - - dbt.mustExec("CREATE TABLE test (value LONGBLOB)") - - in := strings.Repeat(`a`, maxAllowedPacketSize+1) - var out string - var rows *sql.Rows - - // Long text data - const nonDataQueryLen = 28 // length query w/o value - inS := in[:maxAllowedPacketSize-nonDataQueryLen] - dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") - rows = dbt.mustQuery("SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if inS != out { - dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) - } - if rows.Next() { - dbt.Error("LONGBLOB: unexpexted row") - } - } else { - dbt.Fatalf("LONGBLOB: no data") - } - - // Empty table - dbt.mustExec("TRUNCATE TABLE test") - - // Long binary data - dbt.mustExec("INSERT INTO test VALUES(?)", in) - rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) - if rows.Next() { - rows.Scan(&out) - if in != out { - dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) - } - if rows.Next() { - dbt.Error("LONGBLOB: unexpexted row") - } - } else { - if err = rows.Err(); err != nil { - dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error()) - } else { - dbt.Fatal("LONGBLOB: no data (err: )") - } - } - }) -} - -func TestLoadData(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - verifyLoadDataResult := func() { - rows, err := dbt.db.Query("SELECT * FROM test") - if err != nil { - dbt.Fatal(err.Error()) - } - - i := 0 - values := [4]string{ - "a string", - "a string containing a \t", - "a string containing a \n", - "a string containing both \t\n", - } - - var id int - var value string - - for rows.Next() { - i++ - err = rows.Scan(&id, &value) - if err != nil { - dbt.Fatal(err.Error()) - } - if i != id { - dbt.Fatalf("%d != %d", i, id) - } - if values[i-1] != value { - dbt.Fatalf("%q != %q", values[i-1], value) - } - } - err = rows.Err() - if err != nil { - dbt.Fatal(err.Error()) - } - - if i != 4 { - dbt.Fatalf("rows count mismatch. Got %d, want 4", i) - } - } - file, err := ioutil.TempFile("", "gotest") - defer os.Remove(file.Name()) - if err != nil { - dbt.Fatal(err) - } - file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") - file.Close() - - dbt.db.Exec("DROP TABLE IF EXISTS test") - dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") - - // Local File - RegisterLocalFile(file.Name()) - dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) - verifyLoadDataResult() - // negative test - _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") - if err == nil { - dbt.Fatal("load non-existent file didn't fail") - } else if err.Error() != "local file 'doesnotexist' is not registered" { - dbt.Fatal(err.Error()) - } - - // Empty table - dbt.mustExec("TRUNCATE TABLE test") - - // Reader - RegisterReaderHandler("test", func() io.Reader { - file, err = os.Open(file.Name()) - if err != nil { - dbt.Fatal(err) - } - return file - }) - dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") - verifyLoadDataResult() - // negative test - _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") - if err == nil { - dbt.Fatal("load non-existent Reader didn't fail") - } else if err.Error() != "Reader 'doesnotexist' is not registered" { - dbt.Fatal(err.Error()) - } - }) -} - -func TestFoundRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") - count, err := res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 2 { - dbt.Fatalf("Expected 2 affected rows, got %d", count) - } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 2 { - dbt.Fatalf("Expected 2 affected rows, got %d", count) - } - }) - runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") - count, err := res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 2 { - dbt.Fatalf("Expected 2 matched rows, got %d", count) - } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") - count, err = res.RowsAffected() - if err != nil { - dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) - } - if count != 3 { - dbt.Fatalf("Expected 3 matched rows, got %d", count) - } - }) -} - -func TestStrict(t *testing.T) { - // ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors - relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'" - // make sure the MySQL version is recent enough with a separate connection - // before running the test - conn, err := MySQLDriver{}.Open(relaxedDsn) - if conn != nil { - conn.Close() - } - if me, ok := err.(*MySQLError); ok && me.Number == 1231 { - // Error 1231: Variable 'sql_mode' can't be set to the value of 'ALLOW_INVALID_DATES' - // => skip test, MySQL server version is too old - return - } - runTests(t, relaxedDsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))") - - var queries = [...]struct { - in string - codes []string - }{ - {"DROP TABLE IF EXISTS no_such_table", []string{"1051"}}, - {"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}}, - } - var err error - - var checkWarnings = func(err error, mode string, idx int) { - if err == nil { - dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in) - } - - if warnings, ok := err.(MySQLWarnings); ok { - var codes = make([]string, len(warnings)) - for i := range warnings { - codes[i] = warnings[i].Code - } - if len(codes) != len(queries[idx].codes) { - dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) - } - - for i := range warnings { - if codes[i] != queries[idx].codes[i] { - dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) - return - } - } - - } else { - dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error()) - } - } - - // text protocol - for i := range queries { - _, err = dbt.db.Exec(queries[i].in) - checkWarnings(err, "text", i) - } - - var stmt *sql.Stmt - - // binary protocol - for i := range queries { - stmt, err = dbt.db.Prepare(queries[i].in) - if err != nil { - dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error()) - } - - _, err = stmt.Exec() - checkWarnings(err, "binary", i) - - err = stmt.Close() - if err != nil { - dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error()) - } - } - }) -} - -func TestTLS(t *testing.T) { - tlsTest := func(dbt *DBTest) { - if err := dbt.db.Ping(); err != nil { - if err == ErrNoTLS { - dbt.Skip("server does not support TLS") - } else { - dbt.Fatalf("error on Ping: %s", err.Error()) - } - } - - rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") - - var variable, value *sql.RawBytes - for rows.Next() { - if err := rows.Scan(&variable, &value); err != nil { - dbt.Fatal(err.Error()) - } - - if value == nil { - dbt.Fatal("no Cipher") - } - } - } - - runTests(t, dsn+"&tls=skip-verify", tlsTest) - - // Verify that registering / using a custom cfg works - RegisterTLSConfig("custom-skip-verify", &tls.Config{ - InsecureSkipVerify: true, - }) - runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) -} - -func TestReuseClosedConnection(t *testing.T) { - // this test does not use sql.database, it uses the driver directly - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - md := &MySQLDriver{} - conn, err := md.Open(dsn) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - stmt, err := conn.Prepare("DO 1") - if err != nil { - t.Fatalf("error preparing statement: %s", err.Error()) - } - _, err = stmt.Exec(nil) - if err != nil { - t.Fatalf("error executing statement: %s", err.Error()) - } - err = conn.Close() - if err != nil { - t.Fatalf("error closing connection: %s", err.Error()) - } - - defer func() { - if err := recover(); err != nil { - t.Errorf("panic after reusing a closed connection: %v", err) - } - }() - _, err = stmt.Exec(nil) - if err != nil && err != driver.ErrBadConn { - t.Errorf("unexpected error '%s', expected '%s'", - err.Error(), driver.ErrBadConn.Error()) - } -} - -func TestCharset(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - mustSetCharset := func(charsetParam, expected string) { - runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { - rows := dbt.mustQuery("SELECT @@character_set_connection") - defer rows.Close() - - if !rows.Next() { - dbt.Fatalf("error getting connection charset: %s", rows.Err()) - } - - var got string - rows.Scan(&got) - - if got != expected { - dbt.Fatalf("expected connection charset %s but got %s", expected, got) - } - }) - } - - // non utf8 test - mustSetCharset("charset=ascii", "ascii") - - // when the first charset is invalid, use the second - mustSetCharset("charset=none,utf8", "utf8") - - // when the first charset is valid, use it - mustSetCharset("charset=ascii,utf8", "ascii") - mustSetCharset("charset=utf8,ascii", "utf8") -} - -func TestFailingCharset(t *testing.T) { - runTests(t, dsn+"&charset=none", func(dbt *DBTest) { - // run query to really establish connection... - _, err := dbt.db.Exec("SELECT 1") - if err == nil { - dbt.db.Close() - t.Fatalf("connection must not succeed without a valid charset") - } - }) -} - -func TestCollation(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - defaultCollation := "utf8_general_ci" - testCollations := []string{ - "", // do not set - defaultCollation, // driver default - "latin1_general_ci", - "binary", - "utf8_unicode_ci", - "cp1257_bin", - } - - for _, collation := range testCollations { - var expected, tdsn string - if collation != "" { - tdsn = dsn + "&collation=" + collation - expected = collation - } else { - tdsn = dsn - expected = defaultCollation - } - - runTests(t, tdsn, func(dbt *DBTest) { - var got string - if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { - dbt.Fatal(err) - } - - if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) - } - }) - } -} - -func TestColumnsWithAlias(t *testing.T) { - runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { - rows := dbt.mustQuery("SELECT 1 AS A") - defer rows.Close() - cols, _ := rows.Columns() - if len(cols) != 1 { - t.Fatalf("expected 1 column, got %d", len(cols)) - } - if cols[0] != "A" { - t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) - } - rows.Close() - - rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") - cols, _ = rows.Columns() - if len(cols) != 1 { - t.Fatalf("expected 1 column, got %d", len(cols)) - } - if cols[0] != "A.one" { - t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0]) - } - }) -} - -func TestRawBytesResultExceedsBuffer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - // defaultBufSize from buffer.go - expected := strings.Repeat("abc", defaultBufSize) - - rows := dbt.mustQuery("SELECT '" + expected + "'") - defer rows.Close() - if !rows.Next() { - dbt.Error("expected result, got none") - } - var result sql.RawBytes - rows.Scan(&result) - if expected != string(result) { - dbt.Error("result did not match expected value") - } - }) -} - -func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} - - // Regression test for timezone handling - tzTest := func(dbt *DBTest) { - - // Create table - dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") - - // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) - dbt.mustExec("INSERT INTO test VALUE (?)", reftime) - - // Retrieve time from DB - rows := dbt.mustQuery("SELECT ts FROM test") - if !rows.Next() { - dbt.Fatal("did not get any rows out") - } - - var dbTime time.Time - err := rows.Scan(&dbTime) - if err != nil { - dbt.Fatal("Err", err) - } - - // Check that dates match - if reftime.Unix() != dbTime.Unix() { - dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) - dbt.Errorf(" Now(UTC)=%v\n", dbTime) - } - } - - for _, tz := range zones { - runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) - } -} - -// Special cases - -func TestRowsClose(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - rows, err := dbt.db.Query("SELECT 1") - if err != nil { - dbt.Fatal(err) - } - - err = rows.Close() - if err != nil { - dbt.Fatal(err) - } - - if rows.Next() { - dbt.Fatal("unexpected row after rows.Close()") - } - - err = rows.Err() - if err != nil { - dbt.Fatal(err) - } - }) -} - -// dangling statements -// http://code.google.com/p/go/issues/detail?id=3865 -func TestCloseStmtBeforeRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - stmt, err := dbt.db.Prepare("SELECT 1") - if err != nil { - dbt.Fatal(err) - } - - rows, err := stmt.Query() - if err != nil { - stmt.Close() - dbt.Fatal(err) - } - defer rows.Close() - - err = stmt.Close() - if err != nil { - dbt.Fatal(err) - } - - if !rows.Next() { - dbt.Fatal("getting row failed") - } else { - err = rows.Err() - if err != nil { - dbt.Fatal(err) - } - - var out bool - err = rows.Scan(&out) - if err != nil { - dbt.Fatalf("error on rows.Scan(): %s", err.Error()) - } - if out != true { - dbt.Errorf("true != %t", out) - } - } - }) -} - -// It is valid to have multiple Rows for the same Stmt -// http://code.google.com/p/go/issues/detail?id=3734 -func TestStmtMultiRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") - if err != nil { - dbt.Fatal(err) - } - - rows1, err := stmt.Query() - if err != nil { - stmt.Close() - dbt.Fatal(err) - } - defer rows1.Close() - - rows2, err := stmt.Query() - if err != nil { - stmt.Close() - dbt.Fatal(err) - } - defer rows2.Close() - - var out bool - - // 1 - if !rows1.Next() { - dbt.Fatal("first rows1.Next failed") - } else { - err = rows1.Err() - if err != nil { - dbt.Fatal(err) - } - - err = rows1.Scan(&out) - if err != nil { - dbt.Fatalf("error on rows.Scan(): %s", err.Error()) - } - if out != true { - dbt.Errorf("true != %t", out) - } - } - - if !rows2.Next() { - dbt.Fatal("first rows2.Next failed") - } else { - err = rows2.Err() - if err != nil { - dbt.Fatal(err) - } - - err = rows2.Scan(&out) - if err != nil { - dbt.Fatalf("error on rows.Scan(): %s", err.Error()) - } - if out != true { - dbt.Errorf("true != %t", out) - } - } - - // 2 - if !rows1.Next() { - dbt.Fatal("second rows1.Next failed") - } else { - err = rows1.Err() - if err != nil { - dbt.Fatal(err) - } - - err = rows1.Scan(&out) - if err != nil { - dbt.Fatalf("error on rows.Scan(): %s", err.Error()) - } - if out != false { - dbt.Errorf("false != %t", out) - } - - if rows1.Next() { - dbt.Fatal("unexpected row on rows1") - } - err = rows1.Close() - if err != nil { - dbt.Fatal(err) - } - } - - if !rows2.Next() { - dbt.Fatal("second rows2.Next failed") - } else { - err = rows2.Err() - if err != nil { - dbt.Fatal(err) - } - - err = rows2.Scan(&out) - if err != nil { - dbt.Fatalf("error on rows.Scan(): %s", err.Error()) - } - if out != false { - dbt.Errorf("false != %t", out) - } - - if rows2.Next() { - dbt.Fatal("unexpected row on rows2") - } - err = rows2.Close() - if err != nil { - dbt.Fatal(err) - } - } - }) -} - -// Regression test for -// * more than 32 NULL parameters (issue 209) -// * more parameters than fit into the buffer (issue 201) -func TestPreparedManyCols(t *testing.T) { - const numParams = defaultBufSize - runTests(t, dsn, func(dbt *DBTest) { - query := "SELECT ?" + strings.Repeat(",?", numParams-1) - stmt, err := dbt.db.Prepare(query) - if err != nil { - dbt.Fatal(err) - } - defer stmt.Close() - // create more parameters than fit into the buffer - // which will take nil-values - params := make([]interface{}, numParams) - rows, err := stmt.Query(params...) - if err != nil { - stmt.Close() - dbt.Fatal(err) - } - defer rows.Close() - }) -} - -func TestConcurrent(t *testing.T) { - if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { - t.Skip("MYSQL_TEST_CONCURRENT env var not set") - } - - runTests(t, dsn, func(dbt *DBTest) { - var max int - err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) - if err != nil { - dbt.Fatalf("%s", err.Error()) - } - dbt.Logf("testing up to %d concurrent connections \r\n", max) - - var remaining, succeeded int32 = int32(max), 0 - - var wg sync.WaitGroup - wg.Add(max) - - var fatalError string - var once sync.Once - fatalf := func(s string, vals ...interface{}) { - once.Do(func() { - fatalError = fmt.Sprintf(s, vals...) - }) - } - - for i := 0; i < max; i++ { - go func(id int) { - defer wg.Done() - - tx, err := dbt.db.Begin() - atomic.AddInt32(&remaining, -1) - - if err != nil { - if err.Error() != "Error 1040: Too many connections" { - fatalf("error on conn %d: %s", id, err.Error()) - } - return - } - - // keep the connection busy until all connections are open - for remaining > 0 { - if _, err = tx.Exec("DO 1"); err != nil { - fatalf("error on conn %d: %s", id, err.Error()) - return - } - } - - if err = tx.Commit(); err != nil { - fatalf("error on conn %d: %s", id, err.Error()) - return - } - - // everything went fine with this connection - atomic.AddInt32(&succeeded, 1) - }(i) - } - - // wait until all conections are open - wg.Wait() - - if fatalError != "" { - dbt.Fatal(fatalError) - } - - dbt.Logf("reached %d concurrent connections\r\n", succeeded) - }) -} - -// Tests custom dial functions -func TestCustomDial(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - // our custom dial function which justs wraps net.Dial here - RegisterDial("mydial", func(addr string) (net.Conn, error) { - return net.Dial(prot, addr) - }) - - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname)) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db.Close() - - if _, err = db.Exec("DO 1"); err != nil { - t.Fatalf("connection failed: %s", err.Error()) - } -} - -func TestSQLInjection(t *testing.T) { - createTest := func(arg string) func(dbt *DBTest) { - return func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") - dbt.mustExec("INSERT INTO test VALUES (?)", 1) - - var v int - // NULL can't be equal to anything, the idea here is to inject query so it returns row - // This test verifies that escapeQuotes and escapeBackslash are working properly - err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) - if err == sql.ErrNoRows { - return // success, sql injection failed - } else if err == nil { - dbt.Errorf("sql injection successful with arg: %s", arg) - } else { - dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) - } - } - } - - dsns := []string{ - dsn, - dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", - } - for _, testdsn := range dsns { - runTests(t, testdsn, createTest("1 OR 1=1")) - runTests(t, testdsn, createTest("' OR '1'='1")) - } -} - -// Test if inserted data is correctly retrieved after being escaped -func TestInsertRetrieveEscapedData(t *testing.T) { - testData := func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") - - // All sequences that are escaped by escapeQuotes and escapeBackslash - v := "foo \x00\n\r\x1a\"'\\" - dbt.mustExec("INSERT INTO test VALUES (?)", v) - - var out string - err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) - if err != nil { - dbt.Fatalf("%s", err.Error()) - } - - if out != v { - dbt.Errorf("%q != %q", out, v) - } - } - - dsns := []string{ - dsn, - dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'", - } - for _, testdsn := range dsns { - runTests(t, testdsn, testData) - } -} - -func TestUnixSocketAuthFail(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - // Save the current logger so we can restore it. - oldLogger := errLog - - // Set a new logger so we can capture its output. - buffer := bytes.NewBuffer(make([]byte, 0, 64)) - newLogger := log.New(buffer, "prefix: ", 0) - SetLogger(newLogger) - - // Restore the logger. - defer SetLogger(oldLogger) - - // Make a new DSN that uses the MySQL socket file and a bad password, which - // we can make by simply appending any character to the real password. - badPass := pass + "x" - socket := "" - if prot == "unix" { - socket = addr - } else { - // Get socket file from MySQL. - err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) - if err != nil { - t.Fatalf("error on SELECT @@socket: %s", err.Error()) - } - } - t.Logf("socket: %s", socket) - badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname) - db, err := sql.Open("mysql", badDSN) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db.Close() - - // Connect to MySQL for real. This will cause an auth failure. - err = db.Ping() - if err == nil { - t.Error("expected Ping() to return an error") - } - - // The driver should not log anything. - if actual := buffer.String(); actual != "" { - t.Errorf("expected no output, got %q", actual) - } - }) -} - -// See Issue #422 -func TestInterruptBySignal(t *testing.T) { - runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { - dbt.mustExec(` - DROP PROCEDURE IF EXISTS test_signal; - CREATE PROCEDURE test_signal(ret INT) - BEGIN - SELECT ret; - SIGNAL SQLSTATE - '45001' - SET - MESSAGE_TEXT = "an error", - MYSQL_ERRNO = 45001; - END - `) - defer dbt.mustExec("DROP PROCEDURE test_signal") - - var val int - - // text protocol - rows, err := dbt.db.Query("CALL test_signal(42)") - if err != nil { - dbt.Fatalf("error on text query: %s", err.Error()) - } - for rows.Next() { - if err := rows.Scan(&val); err != nil { - dbt.Error(err) - } else if val != 42 { - dbt.Errorf("expected val to be 42") - } - } - - // binary protocol - rows, err = dbt.db.Query("CALL test_signal(?)", 42) - if err != nil { - dbt.Fatalf("error on binary query: %s", err.Error()) - } - for rows.Next() { - if err := rows.Scan(&val); err != nil { - dbt.Error(err) - } else if val != 42 { - dbt.Errorf("expected val to be 42") - } - } - }) -} diff --git a/driver_test.go b/driver_test.go index 94aa37c21..78e68f5d0 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1,5 +1,3 @@ -// +build go1.8 - // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. @@ -62,10 +60,10 @@ func init() { user = env("MYSQL_TEST_USER", "root") pass = env("MYSQL_TEST_PASS", "") prot = env("MYSQL_TEST_PROT", "tcp") - addr = env("MYSQL_TEST_ADDR", "127.0.0.1:3306") + addr = env("MYSQL_TEST_ADDR", "localhost:3306") dbname = env("MYSQL_TEST_DBNAME", "gotest") netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s@%s/%s?timeout=30s&strict=true", user, netAddr, dbname) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname) c, err := net.Dial(prot, addr) if err == nil { available = true @@ -173,15 +171,6 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) return rows } -func TestPing(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - err := dbt.db.Ping() - if err != nil { - dbt.fail("Ping", "Ping", err) - } - }) -} - func TestEmptyQuery(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // just a comment, no query diff --git a/infile.go b/infile.go index bd7bb3abb..38166dd98 100644 --- a/infile.go +++ b/infile.go @@ -1,5 +1,3 @@ -// +build go1.8 - // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. @@ -8,10 +6,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +// +build !go1.8 + package mysql import ( - "context" "fmt" "io" "os" @@ -96,7 +95,7 @@ func deferredClose(err *error, closer io.Closer) { } } -func (mc *mysqlConn) handleInFileRequest(ctx context.Context, name string) (err error) { +func (mc *mysqlConn) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP @@ -156,7 +155,7 @@ func (mc *mysqlConn) handleInFileRequest(ctx context.Context, name string) (err for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.writePacket(ctx, data[:4+n]); ioErr != nil { + if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { return ioErr } } @@ -170,7 +169,7 @@ func (mc *mysqlConn) handleInFileRequest(ctx context.Context, name string) (err if data == nil { data = make([]byte, 4) } - if ioErr := mc.writePacket(ctx, data[:4]); ioErr != nil { + if ioErr := mc.writePacket(data[:4]); ioErr != nil { return ioErr } diff --git a/infile_deprecated.go b/infile_ctx.go similarity index 94% rename from infile_deprecated.go rename to infile_ctx.go index 38166dd98..bd7bb3abb 100644 --- a/infile_deprecated.go +++ b/infile_ctx.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,11 +8,10 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql import ( + "context" "fmt" "io" "os" @@ -95,7 +96,7 @@ func deferredClose(err *error, closer io.Closer) { } } -func (mc *mysqlConn) handleInFileRequest(name string) (err error) { +func (mc *mysqlConn) handleInFileRequest(ctx context.Context, name string) (err error) { var rdr io.Reader var data []byte packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP @@ -155,7 +156,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + if ioErr := mc.writePacket(ctx, data[:4+n]); ioErr != nil { return ioErr } } @@ -169,7 +170,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if data == nil { data = make([]byte, 4) } - if ioErr := mc.writePacket(data[:4]); ioErr != nil { + if ioErr := mc.writePacket(ctx, data[:4]); ioErr != nil { return ioErr } diff --git a/packets.go b/packets.go index 0bc637e72..5e6a97db3 100644 --- a/packets.go +++ b/packets.go @@ -1,5 +1,3 @@ -// +build go1.8 - // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -8,11 +6,12 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +// +build !go1.8 + package mysql import ( "bytes" - "context" "crypto/tls" "database/sql/driver" "encoding/binary" @@ -86,15 +85,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // Write packet buffer 'data' -func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error { - if ctx == nil { - panic("context cannot be nil") - } - ctxDeadline, isCtxDeadlineSet := ctx.Deadline() - if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) { - return errors.New("timeout") - } - +func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 if pktLen > mc.maxAllowedPacket { @@ -117,16 +108,8 @@ func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error { data[3] = mc.sequence // Write packet - var timeNow = time.Now() - var deadline = timeNow if mc.writeTimeout > 0 { - deadline = timeNow.Add(mc.writeTimeout) - if isCtxDeadlineSet && deadline.After(ctxDeadline) { - deadline = ctxDeadline - } - } - if deadline.After(timeNow) { - if err := mc.netConn.SetWriteDeadline(deadline); err != nil { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { return err } } @@ -242,7 +225,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(ctx context.Context, cipher []byte) error { +func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -311,7 +294,7 @@ func (mc *mysqlConn) writeAuthPacket(ctx context.Context, cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Send TLS / SSL request packet - if err := mc.writePacket(ctx, data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { return err } @@ -353,12 +336,12 @@ func (mc *mysqlConn) writeAuthPacket(ctx context.Context, cipher []byte) error { data[pos] = 0x00 // Send Auth packet - return mc.writePacket(ctx, data) + return mc.writePacket(data) } // Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(ctx context.Context, cipher []byte) error { +func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) @@ -375,12 +358,12 @@ func (mc *mysqlConn) writeOldAuthPacket(ctx context.Context, cipher []byte) erro copy(data[4:], scrambleBuff) data[4+pktLen-1] = 0x00 - return mc.writePacket(ctx, data) + return mc.writePacket(data) } // Client clear text authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket(ctx context.Context) error { +func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) @@ -394,12 +377,12 @@ func (mc *mysqlConn) writeClearAuthPacket(ctx context.Context) error { copy(data[4:], mc.cfg.Passwd) data[4+pktLen-1] = 0x00 - return mc.writePacket(ctx, data) + return mc.writePacket(data) } // Native password authentication method // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(ctx context.Context, cipher []byte) error { +func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 @@ -414,14 +397,14 @@ func (mc *mysqlConn) writeNativeAuthPacket(ctx context.Context, cipher []byte) e // Add the scramble copy(data[4:], scrambleBuff) - return mc.writePacket(ctx, data) + return mc.writePacket(data) } /****************************************************************************** * Command Packets * ******************************************************************************/ -func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error { +func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 @@ -436,10 +419,10 @@ func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error data[4] = command // Send CMD packet - return mc.writePacket(ctx, data) + return mc.writePacket(data) } -func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, arg string) error { +func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 @@ -458,10 +441,10 @@ func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, ar copy(data[5:], arg) // Send CMD packet - return mc.writePacket(ctx, data) + return mc.writePacket(data) } -func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte, arg uint32) error { +func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 @@ -482,7 +465,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte, data[8] = byte(arg >> 24) // Send CMD packet - return mc.writePacket(ctx, data) + return mc.writePacket(data) } /****************************************************************************** @@ -543,7 +526,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { return 0, mc.handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(context.Background(), string(data[1:])) + return 0, mc.handleInFileRequest(string(data[1:])) } // column count @@ -841,7 +824,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html -func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, arg []byte) error { +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -878,7 +861,7 @@ func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, ar data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(ctx, data[:4+pktLen]) + err := stmt.mc.writePacket(data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -894,7 +877,7 @@ func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, ar // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html -func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Value) error { +func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "argument count mismatch (got: %d; has: %d)", @@ -1039,7 +1022,7 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(ctx, i, v); err != nil { + if err := stmt.writeCommandLongData(i, v); err != nil { return err } } @@ -1061,7 +1044,7 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(ctx, i, []byte(v)); err != nil { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { return err } } @@ -1098,7 +1081,7 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Val data = data[:pos] } - return mc.writePacket(ctx, data) + return mc.writePacket(data) } func (mc *mysqlConn) discardResults() error { diff --git a/packets_deprecated.go b/packets_ctx.go similarity index 93% rename from packets_deprecated.go rename to packets_ctx.go index 5e6a97db3..0bc637e72 100644 --- a/packets_deprecated.go +++ b/packets_ctx.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,12 +8,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql import ( "bytes" + "context" "crypto/tls" "database/sql/driver" "encoding/binary" @@ -85,7 +86,15 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // Write packet buffer 'data' -func (mc *mysqlConn) writePacket(data []byte) error { +func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error { + if ctx == nil { + panic("context cannot be nil") + } + ctxDeadline, isCtxDeadlineSet := ctx.Deadline() + if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) { + return errors.New("timeout") + } + pktLen := len(data) - 4 if pktLen > mc.maxAllowedPacket { @@ -108,8 +117,16 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet + var timeNow = time.Now() + var deadline = timeNow if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + deadline = timeNow.Add(mc.writeTimeout) + if isCtxDeadlineSet && deadline.After(ctxDeadline) { + deadline = ctxDeadline + } + } + if deadline.After(timeNow) { + if err := mc.netConn.SetWriteDeadline(deadline); err != nil { return err } } @@ -225,7 +242,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeAuthPacket(ctx context.Context, cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -294,7 +311,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(ctx, data[:(4+4+1+23)+4]); err != nil { return err } @@ -336,12 +353,12 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeOldAuthPacket(ctx context.Context, cipher []byte) error { // User password scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) @@ -358,12 +375,12 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { copy(data[4:], scrambleBuff) data[4+pktLen-1] = 0x00 - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Client clear text authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { +func (mc *mysqlConn) writeClearAuthPacket(ctx context.Context) error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) @@ -377,12 +394,12 @@ func (mc *mysqlConn) writeClearAuthPacket() error { copy(data[4:], mc.cfg.Passwd) data[4+pktLen-1] = 0x00 - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Native password authentication method // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeNativeAuthPacket(ctx context.Context, cipher []byte) error { scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 @@ -397,14 +414,14 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Add the scramble copy(data[4:], scrambleBuff) - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** * Command Packets * ******************************************************************************/ -func (mc *mysqlConn) writeCommandPacket(command byte) error { +func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error { // Reset Packet Sequence mc.sequence = 0 @@ -419,10 +436,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data[4] = command // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { +func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 @@ -441,10 +458,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { +func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 @@ -465,7 +482,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data[8] = byte(arg >> 24) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** @@ -526,7 +543,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { return 0, mc.handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, mc.handleInFileRequest(context.Background(), string(data[1:])) } // column count @@ -824,7 +841,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html -func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { +func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -861,7 +878,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(data[:4+pktLen]) + err := stmt.mc.writePacket(ctx, data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html -func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { +func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "argument count mismatch (got: %d; has: %d)", @@ -1022,7 +1039,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, v); err != nil { + if err := stmt.writeCommandLongData(ctx, i, v); err != nil { return err } } @@ -1044,7 +1061,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + if err := stmt.writeCommandLongData(ctx, i, []byte(v)); err != nil { return err } } @@ -1081,7 +1098,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + return mc.writePacket(ctx, data) } func (mc *mysqlConn) discardResults() error { diff --git a/statement.go b/statement.go index 3ed483899..53d18575b 100644 --- a/statement.go +++ b/statement.go @@ -1,5 +1,3 @@ -// +build go1.8 - // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -8,10 +6,11 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +// +build !go1.8 + package mysql import ( - "context" "database/sql/driver" "fmt" "reflect" @@ -25,7 +24,6 @@ type mysqlStmt struct { columns []mysqlField // cached from the first query } -// Close implements driver.Conn interface func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.netConn == nil { // driver.Stmt.Close can be called more than once, thus this function @@ -35,12 +33,11 @@ func (stmt *mysqlStmt) Close() error { return driver.ErrBadConn } - err := stmt.mc.writeCommandPacketUint32(context.Background(), comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) stmt.mc = nil return err } -// NumInput implements driver.Stmt interface func (stmt *mysqlStmt) NumInput() int { return stmt.paramCount } @@ -49,20 +46,13 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } -// Exec implements driver.Stmt interface func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - return stmt.ExecContext(context.Background(), args) -} - -// ExecContent implements driver.StmtExecContext interface -func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(ctx, args) + err := stmt.writeExecutePacket(args) if err != nil { return nil, err } @@ -96,19 +86,13 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (dr return nil, err } -// Query implements driver.Stmt interface func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - return stmt.QueryContext(context.Background(), args) -} - -// QueryContext implements driver.StmtQueryContext interface -func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(ctx, args) + err := stmt.writeExecutePacket(args) if err != nil { return nil, err } diff --git a/statement_deprecated.go b/statement_ctx.go similarity index 79% rename from statement_deprecated.go rename to statement_ctx.go index 53d18575b..030dd5eb5 100644 --- a/statement_deprecated.go +++ b/statement_ctx.go @@ -6,11 +6,12 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 +// +build go1.8 package mysql import ( + "context" "database/sql/driver" "fmt" "reflect" @@ -24,6 +25,7 @@ type mysqlStmt struct { columns []mysqlField // cached from the first query } +// Close implements driver.Conn interface func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.netConn == nil { // driver.Stmt.Close can be called more than once, thus this function @@ -33,11 +35,12 @@ func (stmt *mysqlStmt) Close() error { return driver.ErrBadConn } - err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(context.Background(), comStmtClose, stmt.id) stmt.mc = nil return err } +// NumInput implements driver.Stmt interface func (stmt *mysqlStmt) NumInput() int { return stmt.paramCount } @@ -46,13 +49,20 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } +// Exec implements driver.Stmt interface func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.ExecContext(context.Background(), args) +} + +// ExecContent implements driver.StmtExecContext interface +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (driver.Result, error) { + if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, err } @@ -86,13 +96,19 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { return nil, err } +// Query implements driver.Stmt interface func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.QueryContext(context.Background(), args) +} + +// QueryContext implements driver.StmtQueryContext interface +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, err } diff --git a/transaction.go b/transaction.go index 0f4d4faed..95d374d75 100644 --- a/transaction.go +++ b/transaction.go @@ -1,5 +1,3 @@ -// +build go1.8 - // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -8,30 +6,28 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -package mysql +// +build !go1.8 -import "context" +package mysql type mysqlTx struct { mc *mysqlConn } -// Commit implements driver.Tx interface func (tx *mysqlTx) Commit() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec(context.Background(), "COMMIT") + err = tx.mc.exec("COMMIT") tx.mc = nil return } -// Rollback implements driver.Tx interface func (tx *mysqlTx) Rollback() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec(context.Background(), "ROLLBACK") + err = tx.mc.exec("ROLLBACK") tx.mc = nil return } diff --git a/transaction_deprecated.go b/transaction_ctx.go similarity index 75% rename from transaction_deprecated.go rename to transaction_ctx.go index 95d374d75..0f4d4faed 100644 --- a/transaction_deprecated.go +++ b/transaction_ctx.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,28 +8,30 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql +import "context" + type mysqlTx struct { mc *mysqlConn } +// Commit implements driver.Tx interface func (tx *mysqlTx) Commit() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec("COMMIT") + err = tx.mc.exec(context.Background(), "COMMIT") tx.mc = nil return } +// Rollback implements driver.Tx interface func (tx *mysqlTx) Rollback() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec("ROLLBACK") + err = tx.mc.exec(context.Background(), "ROLLBACK") tx.mc = nil return } From 4921cf53957fb97c2d8678d2e0c2314ba26e28ce Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Fri, 17 Mar 2017 16:33:51 +0800 Subject: [PATCH 7/9] move build tag to top of files --- connection.go | 4 ++-- connection_ctx.go | 4 ++-- driver.go | 4 ++-- driver_ctx.go | 1 + infile.go | 4 ++-- packets.go | 4 ++-- statement.go | 4 ++-- statement_ctx.go | 4 ++-- transaction.go | 4 ++-- 9 files changed, 17 insertions(+), 16 deletions(-) diff --git a/connection.go b/connection.go index eb8c33a55..816586048 100644 --- a/connection.go +++ b/connection.go @@ -1,3 +1,5 @@ +// +build !go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql import ( diff --git a/connection_ctx.go b/connection_ctx.go index 3cb2724f2..b1fddab58 100644 --- a/connection_ctx.go +++ b/connection_ctx.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build go1.8 - package mysql import ( diff --git a/driver.go b/driver.go index 63eb546a6..d94b4f0f3 100644 --- a/driver.go +++ b/driver.go @@ -1,3 +1,5 @@ +// +build !go1.8 + // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public @@ -15,8 +17,6 @@ // // See https://github.com/go-sql-driver/mysql#usage for details -// +build !go1.8 - package mysql import ( diff --git a/driver_ctx.go b/driver_ctx.go index ae8bb8708..9a9aecd7a 100644 --- a/driver_ctx.go +++ b/driver_ctx.go @@ -16,6 +16,7 @@ // db, err := sql.Open("mysql", "user:password@/dbname") // // See https://github.com/go-sql-driver/mysql#usage for details + package mysql import ( diff --git a/infile.go b/infile.go index 38166dd98..60f06c13f 100644 --- a/infile.go +++ b/infile.go @@ -1,3 +1,5 @@ +// +build !go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql import ( diff --git a/packets.go b/packets.go index 5e6a97db3..444fb9d68 100644 --- a/packets.go +++ b/packets.go @@ -1,3 +1,5 @@ +// +build !go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql import ( diff --git a/statement.go b/statement.go index 53d18575b..60c708861 100644 --- a/statement.go +++ b/statement.go @@ -1,3 +1,5 @@ +// +build !go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql import ( diff --git a/statement_ctx.go b/statement_ctx.go index 030dd5eb5..3ed483899 100644 --- a/statement_ctx.go +++ b/statement_ctx.go @@ -1,3 +1,5 @@ +// +build go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build go1.8 - package mysql import ( diff --git a/transaction.go b/transaction.go index 95d374d75..c47b52b45 100644 --- a/transaction.go +++ b/transaction.go @@ -1,3 +1,5 @@ +// +build !go1.8 + // Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. @@ -6,8 +8,6 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// +build !go1.8 - package mysql type mysqlTx struct { From 33b837adb5c7a8c8d5e50077510cc76f0709318e Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Fri, 17 Mar 2017 17:13:18 +0800 Subject: [PATCH 8/9] Implements driver.Pinger interface --- driver_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/driver_test.go b/driver_test.go index 78e68f5d0..d0772a443 100644 --- a/driver_test.go +++ b/driver_test.go @@ -182,6 +182,14 @@ func TestEmptyQuery(t *testing.T) { }) } +func (dbt *DBTest) TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + func TestCRUD(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // Create Table From b27039834eb0dab217ea4c6b5cc6f2168ce6be18 Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Tue, 21 Mar 2017 00:14:54 +0800 Subject: [PATCH 9/9] bugfix for implementing ConnBeginTx interface --- connection_ctx.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/connection_ctx.go b/connection_ctx.go index b1fddab58..f92cf7bcb 100644 --- a/connection_ctx.go +++ b/connection_ctx.go @@ -67,7 +67,7 @@ func (mc *mysqlConn) handleParams() (err error) { // Begin implements driver.Conn interface func (mc *mysqlConn) Begin() (driver.Tx, error) { - return mc.ConnBeginTx(context.Background(), driver.TxOptions{}) + return mc.BeginTx(context.Background(), driver.TxOptions{}) } // Ping implements drvier.Pinger interface @@ -88,8 +88,8 @@ func (mc *mysqlConn) Ping(ctx context.Context) error { return nil } -// ConnBeginTx implements driver.ConnBeginTx interface -func (mc *mysqlConn) ConnBeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn