Skip to content

Commit 3f35029

Browse files
committed
Add strict mode
Closes #40
1 parent 87c17e7 commit 3f35029

File tree

5 files changed

+138
-15
lines changed

5 files changed

+138
-15
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Possible Parameters are:
110110
* `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
111111
* `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
112112
* `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
113+
* `strict`: Enable strict mode. MySQL warnings are treated as errors.
113114

114115
All other parameters are interpreted as system variables:
115116
* `autocommit`: *"SET autocommit=`value`"*
@@ -154,7 +155,7 @@ See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-drive
154155
### `time.Time` support
155156
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
156157

157-
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
158+
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
158159
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
159160

160161

connection.go

+7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type mysqlConn struct {
3131
maxPacketAllowed int
3232
maxWriteSize int
3333
parseTime bool
34+
strict bool
3435
}
3536

3637
type config struct {
@@ -71,6 +72,12 @@ func (mc *mysqlConn) handleParams() (err error) {
7172
mc.parseTime = true
7273
}
7374

75+
// Strict mode
76+
case "strict":
77+
if val == "true" {
78+
mc.strict = true
79+
}
80+
7481
// TLS-Encryption
7582
case "tls":
7683
err = errors.New("TLS-Encryption not implemented yet")

driver_test.go

+63-5
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func init() {
3434
dbname := env("MYSQL_TEST_DBNAME", "gotest")
3535
charset = "charset=utf8"
3636
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
37-
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname)
37+
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true&"+charset, user, pass, netAddr, dbname)
3838
c, err := net.Dial(prot, addr)
3939
if err == nil {
4040
available = true
@@ -119,11 +119,11 @@ func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
119119
defer db.Close()
120120

121121
dbt := &DBTest{t, db}
122-
dbt.mustExec("DROP TABLE IF EXISTS test")
122+
dbt.db.Exec("DROP TABLE IF EXISTS test")
123123
for _, test := range tests {
124124
test(dbt)
125+
dbt.db.Exec("DROP TABLE IF EXISTS test")
125126
}
126-
dbt.mustExec("DROP TABLE IF EXISTS test")
127127
}
128128

129129
func (dbt *DBTest) fail(method, query string, err error) {
@@ -446,7 +446,6 @@ func TestDateTime(t *testing.T) {
446446
testTime := func(dbt *DBTest) {
447447
var rows *sql.Rows
448448
for sqltype, tests := range timetests {
449-
dbt.mustExec("DROP TABLE IF EXISTS test")
450449
dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
451450
for _, test := range tests {
452451
for mode, q := range modes {
@@ -466,6 +465,7 @@ func TestDateTime(t *testing.T) {
466465
}
467466
}
468467
}
468+
dbt.mustExec("DROP TABLE IF EXISTS test")
469469
}
470470
}
471471

@@ -701,7 +701,7 @@ func TestLoadData(t *testing.T) {
701701
file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
702702
file.Close()
703703

704-
dbt.mustExec("DROP TABLE IF EXISTS test")
704+
dbt.db.Exec("DROP TABLE IF EXISTS test")
705705
dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
706706

707707
// Local File
@@ -739,6 +739,64 @@ func TestLoadData(t *testing.T) {
739739
})
740740
}
741741

742+
func TestStrict(t *testing.T) {
743+
runTests(t, "TestStrict", func(dbt *DBTest) {
744+
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
745+
746+
queries := [...][2]string{
747+
{"DROP TABLE IF EXISTS no_such_table", "Note 1051: Unknown table 'no_such_table'"},
748+
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')",
749+
"Warning 1265: Data truncated for column 'b' at row 1\r\n" +
750+
"Warning 1048: Column 'a' cannot be null\r\n" +
751+
"Warning 1264: Out of range value for column 'a' at row 3\r\n" +
752+
"Warning 1265: Data truncated for column 'b' at row 3",
753+
},
754+
}
755+
var rows *sql.Rows
756+
var err error
757+
758+
// text protocol
759+
for i := range queries {
760+
rows, err = dbt.db.Query(queries[i][0])
761+
if rows != nil {
762+
rows.Close()
763+
}
764+
765+
if err == nil {
766+
dbt.Errorf("Expecteded strict error on query [text] %s", queries[i][0])
767+
} else if err.Error() != queries[i][1] {
768+
dbt.Errorf("Unexpected error on query [text] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
769+
}
770+
}
771+
772+
var stmt *sql.Stmt
773+
774+
// binary protocol
775+
for i := range queries {
776+
stmt, err = dbt.db.Prepare(queries[i][0])
777+
if err != nil {
778+
dbt.Error("Error on preparing query %: ", queries[i][0], err.Error())
779+
}
780+
781+
rows, err = stmt.Query()
782+
if rows != nil {
783+
rows.Close()
784+
}
785+
786+
if err == nil {
787+
dbt.Errorf("Expecteded strict error on query [binary] %s", queries[i][0])
788+
} else if err.Error() != queries[i][1] {
789+
dbt.Errorf("Unexpected error on query [binary] %s: %s != %s", queries[i][0], err.Error(), queries[i][1])
790+
}
791+
792+
err = stmt.Close()
793+
if err != nil {
794+
dbt.Error("Error on closing stmt for query %: ", queries[i][0], err.Error())
795+
}
796+
}
797+
})
798+
}
799+
742800
// Special cases
743801

744802
func TestRowsClose(t *testing.T) {

errors.go

+43-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99

1010
package mysql
1111

12-
import "errors"
12+
import (
13+
"database/sql/driver"
14+
"errors"
15+
"fmt"
16+
"io"
17+
)
1318

1419
var (
1520
errMalformPkt = errors.New("Malformed Packet")
@@ -18,3 +23,40 @@ var (
1823
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
1924
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
2025
)
26+
27+
// error type which represents one or more MySQL warnings
28+
type MySQLWarnings []string
29+
30+
func (mw MySQLWarnings) Error() string {
31+
var msg string
32+
for i := range mw {
33+
if i > 0 {
34+
msg += "\r\n"
35+
}
36+
msg += mw[i]
37+
}
38+
return msg
39+
}
40+
41+
func (mc *mysqlConn) getWarnings() (err error) {
42+
rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
43+
if err != nil {
44+
return
45+
}
46+
47+
var warnings = MySQLWarnings{}
48+
var values = make([]driver.Value, 3)
49+
50+
for {
51+
if err = rows.Next(values); err == nil {
52+
warnings = append(warnings,
53+
fmt.Sprintf("%s %s: %s", values[0], values[1], values[2]),
54+
)
55+
} else if err == io.EOF {
56+
return warnings
57+
} else {
58+
return
59+
}
60+
}
61+
return
62+
}

packets.go

+23-8
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,7 @@ func (mc *mysqlConn) readResultOK() error {
352352
switch data[0] {
353353

354354
case iOK:
355-
mc.handleOkPacket(data)
356-
return nil
355+
return mc.handleOkPacket(data)
357356

358357
case iEOF: // someone is using old_passwords
359358
return errOldPassword
@@ -373,8 +372,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
373372
switch data[0] {
374373

375374
case iOK:
376-
mc.handleOkPacket(data)
377-
return 0, nil
375+
return 0, mc.handleOkPacket(data)
378376

379377
case iERR:
380378
return 0, mc.handleErrorPacket(data)
@@ -420,20 +418,31 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
420418

421419
// Ok Packet
422420
// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
423-
func (mc *mysqlConn) handleOkPacket(data []byte) {
424-
var n int
421+
func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
422+
var n, m int
425423

426424
// 0x00 [1 byte]
427425

428426
// Affected rows [Length Coded Binary]
429427
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
430428

431429
// Insert id [Length Coded Binary]
432-
mc.insertId, _, _ = readLengthEncodedInteger(data[1+n:])
430+
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
433431

434432
// server_status [2 bytes]
433+
435434
// warning count [2 bytes]
435+
if !mc.strict {
436+
return
437+
} else {
438+
pos := 1 + n + m + 2
439+
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
440+
err = mc.getWarnings()
441+
}
442+
}
443+
436444
// message [until end of packet]
445+
return
437446
}
438447

439448
// Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -625,7 +634,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error)
625634
pos += 2
626635

627636
// Warning count [16 bit uint]
628-
// bytesToUint16(data[pos : pos+2])
637+
if !stmt.mc.strict {
638+
return
639+
} else {
640+
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
641+
err = stmt.mc.getWarnings()
642+
}
643+
}
629644
}
630645
return
631646
}

0 commit comments

Comments
 (0)