Skip to content

Add strict mode #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 24, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Possible Parameters are:
* `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
* `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
* `strict`: Enable strict mode. MySQL warnings are treated as errors.

All other parameters are interpreted as system variables:
* `autocommit`: *"SET autocommit=`value`"*
Expand Down Expand Up @@ -154,7 +155,8 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive
### `time.Time` support
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.

However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.

**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).


Expand Down
9 changes: 6 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type mysqlConn struct {
maxPacketAllowed int
maxWriteSize int
parseTime bool
strict bool
}

type config struct {
Expand Down Expand Up @@ -67,9 +68,11 @@ func (mc *mysqlConn) handleParams() (err error) {

// time.Time parsing
case "parseTime":
if val == "true" {
mc.parseTime = true
}
mc.parseTime = readBool(val)

// Strict mode
case "strict":
mc.strict = readBool(val)

// TLS-Encryption
case "tls":
Expand Down
78 changes: 72 additions & 6 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func init() {
dbname := env("MYSQL_TEST_DBNAME", "gotest")
charset = "charset=utf8"
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname)
c, err := net.Dial(prot, addr)
if err == nil {
available = true
Expand Down Expand Up @@ -118,12 +118,13 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
}
defer db.Close()

db.Exec("DROP TABLE IF EXISTS test")

dbt := &DBTest{t, db}
dbt.mustExec("DROP TABLE IF EXISTS test")
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}

func (dbt *DBTest) fail(method, query string, err error) {
Expand Down Expand Up @@ -446,7 +447,6 @@ func TestDateTime(t *testing.T) {
testTime := func(dbt *DBTest) {
var rows *sql.Rows
for sqltype, tests := range timetests {
dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
for _, test := range tests {
for mode, q := range modes {
Expand All @@ -466,6 +466,7 @@ func TestDateTime(t *testing.T) {
}
}
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
}

Expand Down Expand Up @@ -701,7 +702,7 @@ func TestLoadData(t *testing.T) {
file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
file.Close()

dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.db.Exec("DROP TABLE IF EXISTS test")
dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")

// Local File
Expand Down Expand Up @@ -739,6 +740,71 @@ func TestLoadData(t *testing.T) {
})
}

func TestStrict(t *testing.T) {
runTests(t, "TestStrict", func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")

var queries = [...]struct {
in string
codes []string
}{
{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
}
var err error

var checkWarnings = func(err error, mode string, idx int) {
if err == nil {
dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
}

if warnings, ok := err.(MySQLWarnings); ok {
var codes = make([]string, len(warnings))
for i := range warnings {
codes[i] = warnings[i].Code
}
if len(codes) != len(queries[idx].codes) {
dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
}

for i := range warnings {
if codes[i] != queries[idx].codes[i] {
dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
return
}
}

} else {
dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
}
}

// text protocol
for i := range queries {
_, err = dbt.db.Exec(queries[i].in)
checkWarnings(err, "text", i)
}

var stmt *sql.Stmt

// binary protocol
for i := range queries {
stmt, err = dbt.db.Prepare(queries[i].in)
if err != nil {
dbt.Error("Error on preparing query %: ", queries[i].in, err.Error())
}

_, err = stmt.Exec()
checkWarnings(err, "binary", i)

err = stmt.Close()
if err != nil {
dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error())
}
}
})
}

// Special cases

func TestRowsClose(t *testing.T) {
Expand Down Expand Up @@ -919,7 +985,7 @@ func TestStmtMultiRows(t *testing.T) {
}

func TestConcurrent(t *testing.T) {
if os.Getenv("MYSQL_TEST_CONCURRENT") != "1" {
if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
return
}
Expand Down
86 changes: 85 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

package mysql

import "errors"
import (
"database/sql/driver"
"errors"
"fmt"
"io"
)

var (
errMalformPkt = errors.New("Malformed Packet")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've already got MySQLError - why not make two error types and get rid of the errors import?

type Error string

func (e *Error) Error() string {
    return *e
}

type MysqlError struct {
    Error
    Number uint16
    Message string
}

func newError(message string, number uint16) *MysqlError {  
    return &MysqlError{fmt.Sprintf("Error %d: %s", number, message), number, message}
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the benefit from that. We import the errors package elsewhere and it is very tiny anyways: http://golang.org/src/pkg/errors/errors.go

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but getting rid of errors is not the main benefit. I think a custom error type could be used as a basis for errors and warnings. And you could make the errors const instead of var.

Expand All @@ -18,3 +23,82 @@ var (
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
)

// error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}

func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}

// error type which represents a group (one ore more) MySQL warnings
type MySQLWarnings []mysqlWarning

func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf("%s %s: %s", warning.Level, warning.Code, warning.Message)
}
return msg
}

// error type which represents a single MySQL warning
type mysqlWarning struct {
Level string
Code string
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this a number?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the MySQL spec doesn't say it is a number. In contrast to error numbers, which are transmitted numerical, the warning code is transmitted as text. This is why I named warning Code and error Number different.

Message string
}

func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
if err != nil {
return
}

var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)

var warning mysqlWarning
var raw []byte
var ok bool

for {
err = rows.Next(values)
switch err {
case nil:
warning = mysqlWarning{}

if raw, ok = values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok = values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok = values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}

warnings = append(warnings, warning)

case io.EOF:
return warnings

default:
rows.Close()
return
}
}
return
}
36 changes: 27 additions & 9 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error {
switch data[0] {

case iOK:
mc.handleOkPacket(data)
return nil
return mc.handleOkPacket(data)

case iEOF: // someone is using old_passwords
return errOldPassword
Expand All @@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
switch data[0] {

case iOK:
mc.handleOkPacket(data)
return 0, nil
return 0, mc.handleOkPacket(data)

case iERR:
return 0, mc.handleErrorPacket(data)
Expand Down Expand Up @@ -415,25 +413,39 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
}

// Error Message [string]
return fmt.Errorf("Error %d: %s", errno, string(data[pos:]))
return &MySQLError{
Number: errno,
Message: string(data[pos:]),
}
}

// Ok Packet
// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) {
var n int
func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
var n, m int

// 0x00 [1 byte]

// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])

// Insert id [Length Coded Binary]
mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])

// server_status [2 bytes]

// warning count [2 bytes]
if !mc.strict {
return
} else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not if mc.strict { and drop stuff above?

pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
err = mc.getWarnings()
}
}

// message [until end of packet]
return
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
Expand Down Expand Up @@ -625,7 +637,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
pos += 2

// Warning count [16 bit uint]
// bytesToUint16(data[pos : pos+2])
if !stmt.mc.strict {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like L440: drop branch - but I guess you write it this way because of benchmarks and analysis of disassembled code.
Still, my guess would be if is better than if ... else .... That may be different when a branch contains return, though...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's mirco optimization... I should really stop with that.

return
} else {
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
err = stmt.mc.getWarnings()
}
}
}
return
}
Expand Down
10 changes: 10 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,16 @@ func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
}

func readBool(value string) bool {
switch strings.ToLower(value) {
case "true":
return true
case "1":
return true
}
return false
}

/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
Expand Down