diff --git a/README.md b/README.md index f44cc6f9b..fb694706b 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,7 @@ Possible Parameters are: * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!* * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string` * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. + * `strict`: Enable strict mode. MySQL warnings are treated as errors. All other parameters are interpreted as system variables: * `autocommit`: *"SET autocommit=`value`"* @@ -154,7 +155,8 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive ### `time.Time` support The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm. -However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter. +However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter. + **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). diff --git a/connection.go b/connection.go index 2520be0c5..c0cdd575b 100644 --- a/connection.go +++ b/connection.go @@ -31,6 +31,7 @@ type mysqlConn struct { maxPacketAllowed int maxWriteSize int parseTime bool + strict bool } type config struct { @@ -67,9 +68,11 @@ func (mc *mysqlConn) handleParams() (err error) { // time.Time parsing case "parseTime": - if val == "true" { - mc.parseTime = true - } + mc.parseTime = readBool(val) + + // Strict mode + case "strict": + mc.strict = readBool(val) // TLS-Encryption case "tls": diff --git a/driver_test.go b/driver_test.go index 4a5d2b9bc..c45b9cfda 100644 --- a/driver_test.go +++ b/driver_test.go @@ -34,7 +34,7 @@ func init() { dbname := env("MYSQL_TEST_DBNAME", "gotest") charset = "charset=utf8" netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname) c, err := net.Dial(prot, addr) if err == nil { available = true @@ -118,12 +118,13 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) { } defer db.Close() + db.Exec("DROP TABLE IF EXISTS test") + dbt := &DBTest{t, db} - dbt.mustExec("DROP TABLE IF EXISTS test") for _, test := range tests { test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") } - dbt.mustExec("DROP TABLE IF EXISTS test") } func (dbt *DBTest) fail(method, query string, err error) { @@ -446,7 +447,6 @@ func TestDateTime(t *testing.T) { testTime := func(dbt *DBTest) { var rows *sql.Rows for sqltype, tests := range timetests { - dbt.mustExec("DROP TABLE IF EXISTS test") dbt.mustExec("CREATE TABLE test (value " + sqltype + ")") for _, test := range tests { for mode, q := range modes { @@ -466,6 +466,7 @@ func TestDateTime(t *testing.T) { } } } + dbt.mustExec("DROP TABLE IF EXISTS test") } } @@ -701,7 +702,7 @@ func TestLoadData(t *testing.T) { file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") file.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.db.Exec("DROP TABLE IF EXISTS test") dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") // Local File @@ -739,6 +740,71 @@ func TestLoadData(t *testing.T) { }) } +func TestStrict(t *testing.T) { + runTests(t, "TestStrict", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))") + + var queries = [...]struct { + in string + codes []string + }{ + {"DROP TABLE IF EXISTS no_such_table", []string{"1051"}}, + {"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}}, + } + var err error + + var checkWarnings = func(err error, mode string, idx int) { + if err == nil { + dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in) + } + + if warnings, ok := err.(MySQLWarnings); ok { + var codes = make([]string, len(warnings)) + for i := range warnings { + codes[i] = warnings[i].Code + } + if len(codes) != len(queries[idx].codes) { + dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) + } + + for i := range warnings { + if codes[i] != queries[idx].codes[i] { + dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) + return + } + } + + } else { + dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error()) + } + } + + // text protocol + for i := range queries { + _, err = dbt.db.Exec(queries[i].in) + checkWarnings(err, "text", i) + } + + var stmt *sql.Stmt + + // binary protocol + for i := range queries { + stmt, err = dbt.db.Prepare(queries[i].in) + if err != nil { + dbt.Error("Error on preparing query %: ", queries[i].in, err.Error()) + } + + _, err = stmt.Exec() + checkWarnings(err, "binary", i) + + err = stmt.Close() + if err != nil { + dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error()) + } + } + }) +} + // Special cases func TestRowsClose(t *testing.T) { @@ -919,7 +985,7 @@ func TestStmtMultiRows(t *testing.T) { } func TestConcurrent(t *testing.T) { - if os.Getenv("MYSQL_TEST_CONCURRENT") != "1" { + if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true { t.Log("CONCURRENT env var not set. Skipping TestConcurrent") return } diff --git a/errors.go b/errors.go index e71478ab2..004f78b16 100644 --- a/errors.go +++ b/errors.go @@ -9,7 +9,12 @@ package mysql -import "errors" +import ( + "database/sql/driver" + "errors" + "fmt" + "io" +) var ( errMalformPkt = errors.New("Malformed Packet") @@ -18,3 +23,82 @@ var ( errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords") errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.") ) + +// error type which represents a single MySQL error +type MySQLError struct { + Number uint16 + Message string +} + +func (me *MySQLError) Error() string { + return fmt.Sprintf("Error %d: %s", me.Number, me.Message) +} + +// error type which represents a group (one ore more) MySQL warnings +type MySQLWarnings []mysqlWarning + +func (mws MySQLWarnings) Error() string { + var msg string + for i, warning := range mws { + if i > 0 { + msg += "\r\n" + } + msg += fmt.Sprintf("%s %s: %s", warning.Level, warning.Code, warning.Message) + } + return msg +} + +// error type which represents a single MySQL warning +type mysqlWarning struct { + Level string + Code string + Message string +} + +func (mc *mysqlConn) getWarnings() (err error) { + rows, err := mc.Query("SHOW WARNINGS", []driver.Value{}) + if err != nil { + return + } + + var warnings = MySQLWarnings{} + var values = make([]driver.Value, 3) + + var warning mysqlWarning + var raw []byte + var ok bool + + for { + err = rows.Next(values) + switch err { + case nil: + warning = mysqlWarning{} + + if raw, ok = values[0].([]byte); ok { + warning.Level = string(raw) + } else { + warning.Level = fmt.Sprintf("%s", values[0]) + } + if raw, ok = values[1].([]byte); ok { + warning.Code = string(raw) + } else { + warning.Code = fmt.Sprintf("%s", values[1]) + } + if raw, ok = values[2].([]byte); ok { + warning.Message = string(raw) + } else { + warning.Message = fmt.Sprintf("%s", values[0]) + } + + warnings = append(warnings, warning) + + case io.EOF: + return warnings + + default: + rows.Close() + return + } + } + return +} diff --git a/packets.go b/packets.go index 774ce0187..515f25227 100644 --- a/packets.go +++ b/packets.go @@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error { switch data[0] { case iOK: - mc.handleOkPacket(data) - return nil + return mc.handleOkPacket(data) case iEOF: // someone is using old_passwords return errOldPassword @@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { switch data[0] { case iOK: - mc.handleOkPacket(data) - return 0, nil + return 0, mc.handleOkPacket(data) case iERR: return 0, mc.handleErrorPacket(data) @@ -415,13 +413,16 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { } // Error Message [string] - return fmt.Errorf("Error %d: %s", errno, string(data[pos:])) + return &MySQLError{ + Number: errno, + Message: string(data[pos:]), + } } // Ok Packet // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet -func (mc *mysqlConn) handleOkPacket(data []byte) { - var n int +func (mc *mysqlConn) handleOkPacket(data []byte) (err error) { + var n, m int // 0x00 [1 byte] @@ -429,11 +430,22 @@ func (mc *mysqlConn) handleOkPacket(data []byte) { mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) // Insert id [Length Coded Binary] - mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:]) + mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) // server_status [2 bytes] + // warning count [2 bytes] + if !mc.strict { + return + } else { + pos := 1 + n + m + 2 + if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { + err = mc.getWarnings() + } + } + // message [until end of packet] + return } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -625,7 +637,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) pos += 2 // Warning count [16 bit uint] - // bytesToUint16(data[pos : pos+2]) + if !stmt.mc.strict { + return + } else { + if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { + err = stmt.mc.getWarnings() + } + } } return } diff --git a/utils.go b/utils.go index 5eb248eac..b6fe0d2fc 100644 --- a/utils.go +++ b/utils.go @@ -234,6 +234,16 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } +func readBool(value string) bool { + switch strings.ToLower(value) { + case "true": + return true + case "1": + return true + } + return false +} + /****************************************************************************** * Convert from and to bytes * ******************************************************************************/