diff --git a/driver_test.go b/driver_test.go index de60d087e..4a5d2b9bc 100644 --- a/driver_test.go +++ b/driver_test.go @@ -8,111 +8,91 @@ import ( "net" "os" "strings" - "sync" "testing" "time" ) var ( - charset string - dsn string - netAddr string - run bool - once sync.Once + charset string + dsn string + netAddr string + available bool ) // See https://github.com/go-sql-driver/mysql/wiki/Testing -func getEnv() bool { - once.Do(func() { - user := os.Getenv("MYSQL_TEST_USER") - if user == "" { - user = "root" - } - - pass := os.Getenv("MYSQL_TEST_PASS") - - prot := os.Getenv("MYSQL_TEST_PROT") - if prot == "" { - prot = "tcp" - } - - addr := os.Getenv("MYSQL_TEST_ADDR") - if addr == "" { - addr = "localhost:3306" - } - - dbname := os.Getenv("MYSQL_TEST_DBNAME") - if dbname == "" { - dbname = "gotest" - } - - charset = "charset=utf8" - netAddr = fmt.Sprintf("%s(%s)", prot, addr) - dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname) - - c, err := net.Dial(prot, addr) - if err == nil { - run = true - c.Close() - } - }) - - return run -} - -func mustExec(t *testing.T, db *sql.DB, query string, args ...interface{}) (res sql.Result) { - res, err := db.Exec(query, args...) - if err != nil { - if len(query) > 300 { - query = "[query too large to print]" - } - t.Fatalf("Error on Exec %s: %v", query, err) +func init() { + env := func(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue + } + user := env("MYSQL_TEST_USER", "root") + pass := env("MYSQL_TEST_PASS", "") + prot := env("MYSQL_TEST_PROT", "tcp") + addr := env("MYSQL_TEST_ADDR", "localhost:3306") + dbname := env("MYSQL_TEST_DBNAME", "gotest") + charset = "charset=utf8" + netAddr = fmt.Sprintf("%s(%s)", prot, addr) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&"+charset, user, pass, netAddr, dbname) + c, err := net.Dial(prot, addr) + if err == nil { + available = true + c.Close() } - return } -func mustQuery(t *testing.T, db *sql.DB, query string, args ...interface{}) (rows *sql.Rows) { - rows, err := db.Query(query, args...) - if err != nil { - if len(query) > 300 { - query = "[query too large to print]" +func TestCharset(t *testing.T) { + mustSetCharset := func(charsetParam, expected string) { + db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1)) + if err != nil { + t.Fatalf("Error on Open: %v", err) } - t.Fatalf("Error on Query %s: %v", query, err) - } - return -} + defer db.Close() -func mustSetCharset(t *testing.T, charsetParam, expected string) { - db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1)) - if err != nil { - t.Fatalf("Error on Open: %v", err) - } + dbt := &DBTest{t, db} + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() - rows := mustQuery(t, db, ("SELECT @@character_set_connection")) - if !rows.Next() { - t.Fatalf("Error getting connection charset: %v", err) - } + if !rows.Next() { + dbt.Fatalf("Error getting connection charset: %v", err) + } - var got string - rows.Scan(&got) + var got string + rows.Scan(&got) - if got != expected { - t.Fatalf("Expected connection charset %s but got %s", expected, got) + if got != expected { + dbt.Fatalf("Expected connection charset %s but got %s", expected, got) + } } -} -func TestCharset(t *testing.T) { - if !getEnv() { + if !available { t.Logf("MySQL-Server not running on %s. Skipping TestCharset", netAddr) return } // non utf8 test - mustSetCharset(t, "charset=ascii", "ascii") + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8", "utf8") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8", "ascii") + mustSetCharset("charset=utf8,ascii", "utf8") } func TestFailingCharset(t *testing.T) { + if !available { + t.Logf("MySQL-Server not running on %s. Skipping TestFailingCharset", netAddr) + return + } db, err := sql.Open("mysql", strings.Replace(dsn, charset, "charset=none", 1)) + if err != nil { + t.Fatalf("Error on Open: %v", err) + } + defer db.Close() + // run query to really establish connection... _, err = db.Exec("SELECT 1") if err == nil { @@ -121,27 +101,14 @@ func TestFailingCharset(t *testing.T) { } } -func TestFallbackCharset(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestFallbackCharset", netAddr) - return - } - - // when the first charset is invalid, use the second - mustSetCharset(t, "charset=none,utf8", "utf8") - - // when the first charset is valid, use it - charsets := []string{"ascii", "utf8"} - for i := range charsets { - expected := charsets[i] - other := charsets[1-i] - mustSetCharset(t, "charset="+expected+","+other, expected) - } +type DBTest struct { + *testing.T + db *sql.DB } -func TestCRUD(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestCRUD", netAddr) +func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) { + if !available { + t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name) return } @@ -149,869 +116,806 @@ func TestCRUD(t *testing.T) { if err != nil { t.Fatalf("Error connecting: %v", err) } - defer db.Close() - mustExec(t, db, "DROP TABLE IF EXISTS test") - - // Create Table - mustExec(t, db, "CREATE TABLE test (value BOOL)") + dbt := &DBTest{t, db} + dbt.mustExec("DROP TABLE IF EXISTS test") + for _, test := range tests { + test(dbt) + } + dbt.mustExec("DROP TABLE IF EXISTS test") +} - // Test for unexpected data - var out bool - rows := mustQuery(t, db, ("SELECT * FROM test")) - if rows.Next() { - t.Error("unexpected data in empty table") +func (dbt *DBTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" } + dbt.Fatalf("Error on %s %s: %v", method, query, err) +} - // Create Data - res := mustExec(t, db, ("INSERT INTO test VALUES (1)")) - count, err := res.RowsAffected() +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.Exec(query, args...) if err != nil { - t.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 1 { - t.Fatalf("Expected 1 affected row, got %d", count) + dbt.fail("Exec", query, err) } + return res +} - id, err := res.LastInsertId() +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.Query(query, args...) if err != nil { - t.Fatalf("res.LastInsertId() returned error: %v", err) - } - if id != 0 { - t.Fatalf("Expected InsertID 0, got %d", id) + dbt.fail("Query", query, err) } + return rows +} - // Read - rows = mustQuery(t, db, ("SELECT value FROM test")) - if rows.Next() { - rows.Scan(&out) - if true != out { - t.Errorf("true != %t", out) +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, "TestRawBytesResultExceedsBuffer", func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestCRUD(t *testing.T) { + runTests(t, "TestCRUD", func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Test for unexpected data + var out bool + rows := dbt.mustQuery("SELECT * FROM test") if rows.Next() { - t.Error("unexpected data") + dbt.Error("unexpected data in empty table") } - } else { - t.Error("no data") - } - // Update - res = mustExec(t, db, "UPDATE test SET value = ? WHERE value = ?", false, true) - count, err = res.RowsAffected() - if err != nil { - t.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 1 { - t.Fatalf("Expected 1 affected row, got %d", count) - } + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 1 { + dbt.Fatalf("Expected 1 affected row, got %d", count) + } - // Check Update - rows = mustQuery(t, db, ("SELECT value FROM test")) - if rows.Next() { - rows.Scan(&out) - if false != out { - t.Errorf("false != %t", out) + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %v", err) + } + if id != 0 { + dbt.Fatalf("Expected InsertID 0, got %d", id) } + // Read + rows = dbt.mustQuery("SELECT value FROM test") if rows.Next() { - t.Error("unexpected data") + rows.Scan(&out) + if true != out { + dbt.Errorf("true != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") } - } else { - t.Error("no data") - } - // Delete - res = mustExec(t, db, "DELETE FROM test WHERE value = ?", false) - count, err = res.RowsAffected() - if err != nil { - t.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 1 { - t.Fatalf("Expected 1 affected row, got %d", count) - } + // Update + res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 1 { + dbt.Fatalf("Expected 1 affected row, got %d", count) + } - // Check for unexpected rows - res = mustExec(t, db, "DELETE FROM test") - count, err = res.RowsAffected() - if err != nil { - t.Fatalf("res.RowsAffected() returned error: %v", err) - } - if count != 0 { - t.Fatalf("Expected 0 affected row, got %d", count) - } + // Check Update + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if false != out { + dbt.Errorf("false != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + + // Delete + res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 1 { + dbt.Fatalf("Expected 1 affected row, got %d", count) + } + + // Check for unexpected rows + res = dbt.mustExec("DELETE FROM test") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %v", err) + } + if count != 0 { + dbt.Fatalf("Expected 0 affected row, got %d", count) + } + }) } func TestInt(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestInt", netAddr) - return - } + runTests(t, "TestInt", func(dbt *DBTest) { + types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} + in := int64(42) + var out int64 + var rows *sql.Rows + + // SIGNED + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } + dbt.mustExec("DROP TABLE IF EXISTS test") + } - defer db.Close() + // UNSIGNED ZEROFILL + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") - mustExec(t, db, "DROP TABLE IF EXISTS test") + dbt.mustExec("INSERT INTO test VALUES (?)", in) - types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} - in := int64(42) - var out int64 - var rows *sql.Rows + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s ZEROFILL: no data", v) + } - // SIGNED - for _, v := range types { - mustExec(t, db, "CREATE TABLE test (value "+v+")") + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} - mustExec(t, db, ("INSERT INTO test VALUES (?)"), in) +func TestFloat(t *testing.T) { + runTests(t, "TestFloat", func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + in := float32(42.23) + var out float32 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} - rows = mustQuery(t, db, ("SELECT value FROM test")) - if rows.Next() { - rows.Scan(&out) - if in != out { - t.Errorf("%s: %d != %d", v, in, out) +func TestString(t *testing.T) { + runTests(t, "TestString", func(dbt *DBTest) { + types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} + in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" + var out string + var rows *sql.Rows + + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) } - } else { - t.Errorf("%s: no data", v) + + dbt.mustExec("DROP TABLE IF EXISTS test") } - mustExec(t, db, "DROP TABLE IF EXISTS test") - } + // BLOB + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") - // UNSIGNED ZEROFILL - for _, v := range types { - mustExec(t, db, "CREATE TABLE test (value "+v+" ZEROFILL)") + id := 2 + in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + + "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." + dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) - mustExec(t, db, ("INSERT INTO test VALUES (?)"), in) + err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + if err != nil { + dbt.Fatalf("Error on BLOB-Query: %v", err) + } else if out != in { + dbt.Errorf("BLOB: %s != %s", in, out) + } + }) +} - rows = mustQuery(t, db, ("SELECT value FROM test")) - if rows.Next() { - rows.Scan(&out) - if in != out { - t.Errorf("%s ZEROFILL: %d != %d", v, in, out) +func TestDateTime(t *testing.T) { + type testmode struct { + selectSuffix string + args []interface{} + } + type timetest struct { + in interface{} + sOut string + tOut time.Time + tIsZero bool + } + type tester func(dbt *DBTest, rows *sql.Rows, + test *timetest, sqltype, resulttype, mode string) + type setup struct { + vartype string + dsnSuffix string + test tester + } + var ( + tdate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) + sdate = "2012-06-14" + tdatetime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) + sdatetime = "2011-11-20 21:27:37" + tdate0 = time.Time{} + sdate0 = "0000-00-00" + sdatetime0 = "0000-00-00 00:00:00" + modes = map[string]*testmode{ + "text": &testmode{}, + "binary": &testmode{" WHERE 1 = ?", []interface{}{1}}, + } + timetests = map[string][]*timetest{ + "DATE": { + {sdate, sdate, tdate, false}, + {sdate0, sdate0, tdate0, true}, + {tdate, sdate, tdate, false}, + {tdate0, sdate0, tdate0, true}, + }, + "DATETIME": { + {sdatetime, sdatetime, tdatetime, false}, + {sdatetime0, sdatetime0, tdate0, true}, + {tdatetime, sdatetime, tdatetime, false}, + {tdate0, sdatetime0, tdate0, true}, + }, + } + setups = []*setup{ + {"string", "&parseTime=false", func( + dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) { + var sOut string + if err := rows.Scan(&sOut); err != nil { + dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err) + } else if test.sOut != sOut { + dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut) + } + }}, + {"time.Time", "&parseTime=true", func( + dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) { + var tOut time.Time + if err := rows.Scan(&tOut); err != nil { + dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err) + } else if test.tOut != tOut || test.tIsZero != tOut.IsZero() { + dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero()) + } + }}, + } + ) + + var s *setup + testTime := func(dbt *DBTest) { + var rows *sql.Rows + for sqltype, tests := range timetests { + dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (value " + sqltype + ")") + for _, test := range tests { + for mode, q := range modes { + dbt.mustExec("TRUNCATE test") + dbt.mustExec("INSERT INTO test VALUES (?)", test.in) + rows = dbt.mustQuery("SELECT value FROM test"+q.selectSuffix, q.args...) + if rows.Next() { + s.test(dbt, rows, test, sqltype, s.vartype, mode) + } else { + if err := rows.Err(); err != nil { + dbt.Errorf("%s (%s %s): %v", + sqltype, s.vartype, mode, err) + } else { + dbt.Errorf("%s (%s %s): no data", + sqltype, s.vartype, mode) + } + } + } } - } else { - t.Errorf("%s ZEROFILL: no data", v) } + } - mustExec(t, db, "DROP TABLE IF EXISTS test") + oldDsn := dsn + usedDsn := oldDsn + "&sql_mode=ALLOW_INVALID_DATES" + for _, v := range setups { + s = v + dsn = usedDsn + s.dsnSuffix + runTests(t, "TestDateTime", testTime) } + dsn = oldDsn } -func TestFloat(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestFloat", netAddr) - return - } +func TestNULL(t *testing.T) { + runTests(t, "TestNULL", func(dbt *DBTest) { + nullStmt, err := dbt.db.Prepare("SELECT NULL") + if err != nil { + dbt.Fatal(err) + } + defer nullStmt.Close() - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } + nonNullStmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + defer nonNullStmt.Close() - defer db.Close() + // NullBool + var nb sql.NullBool + // Invalid + err = nullStmt.QueryRow().Scan(&nb) + if err != nil { + dbt.Fatal(err) + } + if nb.Valid { + dbt.Error("Valid NullBool which should be invalid") + } + // Valid + err = nonNullStmt.QueryRow().Scan(&nb) + if err != nil { + dbt.Fatal(err) + } + if !nb.Valid { + dbt.Error("Invalid NullBool which should be valid") + } else if nb.Bool != true { + dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) + } + + // NullFloat64 + var nf sql.NullFloat64 + // Invalid + err = nullStmt.QueryRow().Scan(&nf) + if err != nil { + dbt.Fatal(err) + } + if nf.Valid { + dbt.Error("Valid NullFloat64 which should be invalid") + } + // Valid + err = nonNullStmt.QueryRow().Scan(&nf) + if err != nil { + dbt.Fatal(err) + } + if !nf.Valid { + dbt.Error("Invalid NullFloat64 which should be valid") + } else if nf.Float64 != float64(1) { + dbt.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + } - mustExec(t, db, "DROP TABLE IF EXISTS test") + // NullInt64 + var ni sql.NullInt64 + // Invalid + err = nullStmt.QueryRow().Scan(&ni) + if err != nil { + dbt.Fatal(err) + } + if ni.Valid { + dbt.Error("Valid NullInt64 which should be invalid") + } + // Valid + err = nonNullStmt.QueryRow().Scan(&ni) + if err != nil { + dbt.Fatal(err) + } + if !ni.Valid { + dbt.Error("Invalid NullInt64 which should be valid") + } else if ni.Int64 != int64(1) { + dbt.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64) + } - types := [2]string{"FLOAT", "DOUBLE"} - in := float32(42.23) - var out float32 - var rows *sql.Rows + // NullString + var ns sql.NullString + // Invalid + err = nullStmt.QueryRow().Scan(&ns) + if err != nil { + dbt.Fatal(err) + } + if ns.Valid { + dbt.Error("Valid NullString which should be invalid") + } + // Valid + err = nonNullStmt.QueryRow().Scan(&ns) + if err != nil { + dbt.Fatal(err) + } + if !ns.Valid { + dbt.Error("Invalid NullString which should be valid") + } else if ns.String != `1` { + dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)") + } - for _, v := range types { - mustExec(t, db, "CREATE TABLE test (value "+v+")") + // Insert NULL + dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") - mustExec(t, db, ("INSERT INTO test VALUES (?)"), in) + dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) - rows = mustQuery(t, db, ("SELECT value FROM test")) + var out interface{} + rows := dbt.mustQuery("SELECT * FROM test") if rows.Next() { rows.Scan(&out) - if in != out { - t.Errorf("%s: %g != %g", v, in, out) + if out != nil { + dbt.Errorf("%v != nil", out) } } else { - t.Errorf("%s: no data", v) + dbt.Error("no data") } - - mustExec(t, db, "DROP TABLE IF EXISTS test") - } + }) } -func TestString(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestString", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } +func TestLongData(t *testing.T) { + runTests(t, "TestLongData", func(dbt *DBTest) { + var maxAllowedPacketSize int + err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) + if err != nil { + dbt.Fatal(err) + } + maxAllowedPacketSize-- - defer db.Close() + // don't get too ambitious + if maxAllowedPacketSize > 1<<25 { + maxAllowedPacketSize = 1 << 25 + } - mustExec(t, db, "DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (value LONGBLOB)") - types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} - in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" - var out string - var rows *sql.Rows + in := strings.Repeat(`0`, maxAllowedPacketSize+1) + var out string + var rows *sql.Rows - for _, v := range types { - mustExec(t, db, "CREATE TABLE test (value "+v+") CHARACTER SET utf8 COLLATE utf8_unicode_ci") + // Long text data + const nonDataQueryLen = 28 // length query w/o value + inS := in[:maxAllowedPacketSize-nonDataQueryLen] + dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if inS != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + dbt.Fatalf("LONGBLOB: no data") + } - mustExec(t, db, ("INSERT INTO test VALUES (?)"), in) + // Empty table + dbt.mustExec("TRUNCATE TABLE test") - rows = mustQuery(t, db, ("SELECT value FROM test")) + // Long binary data + dbt.mustExec("INSERT INTO test VALUES(?)", in) + rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) if rows.Next() { rows.Scan(&out) if in != out { - t.Errorf("%s: %s != %s", v, in, out) + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") } } else { - t.Errorf("%s: no data", v) + dbt.Fatalf("LONGBLOB: no data") } - - mustExec(t, db, "DROP TABLE IF EXISTS test") - } - - // BLOB - mustExec(t, db, "CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8 COLLATE utf8_unicode_ci") - - id := 2 - in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + - "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + - "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + - "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + - "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + - "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + - "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + - "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." - mustExec(t, db, ("INSERT INTO test VALUES (?, ?)"), id, in) - - err = db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) - if err != nil { - t.Fatalf("Error on BLOB-Query: %v", err) - } else if out != in { - t.Errorf("BLOB: %s != %s", in, out) - } - - return + }) } -func TestDateTime(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestString", netAddr) - return - } +func TestLoadData(t *testing.T) { + runTests(t, "TestLoadData", func(dbt *DBTest) { + verifyLoadDataResult := func() { + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + dbt.Fatal(err.Error()) + } - var modes = [2]string{"text", "binary"} - var types = [2]string{"DATE", "DATETIME"} - var tests = [2][]struct { - in interface{} - sOut string - tOut time.Time - tIsZero bool - }{ - { - {"2012-06-14", "2012-06-14", time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC), false}, - {"0000-00-00", "0000-00-00", time.Time{}, true}, - {time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC), "2012-06-14", time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC), false}, - {time.Time{}, "0000-00-00", time.Time{}, true}, - }, - { - {"2011-11-20 21:27:37", "2011-11-20 21:27:37", time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC), false}, - {"0000-00-00 00:00:00", "0000-00-00 00:00:00", time.Time{}, true}, - {time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC), "2011-11-20 21:27:37", time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC), false}, - {time.Time{}, "0000-00-00 00:00:00", time.Time{}, true}, - }, - } - var sOut string - var tOut time.Time + i := 0 + values := [4]string{ + "a string", + "a string containing a \t", + "a string containing a \n", + "a string containing both \t\n", + } - var rows [2]*sql.Rows - var sDB, tDB *sql.DB - var err error + var id int + var value string - sDB, err = sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting (string): %v", err) - } - defer sDB.Close() + for rows.Next() { + i++ + err = rows.Scan(&id, &value) + if err != nil { + dbt.Fatal(err.Error()) + } + if i != id { + dbt.Fatalf("%d != %d", i, id) + } + if values[i-1] != value { + dbt.Fatalf("%s != %s", values[i-1], value) + } + } + err = rows.Err() + if err != nil { + dbt.Fatal(err.Error()) + } - tDB, err = sql.Open("mysql", dsn+"&parseTime=true") - if err != nil { - t.Fatalf("Error connecting (time.Time): %v", err) - } - defer tDB.Close() + if i != 4 { + dbt.Fatalf("Rows count mismatch. Got %d, want 4", i) + } + } + file, err := ioutil.TempFile("", "gotest") + defer os.Remove(file.Name()) + if err != nil { + dbt.Fatal(err) + } + file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") + file.Close() - mustExec(t, sDB, "DROP TABLE IF EXISTS test") - for i, v := range types { - mustExec(t, sDB, "CREATE TABLE test (value "+v+") CHARACTER SET utf8 COLLATE utf8_unicode_ci") + dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") - for j := range tests[i] { - mustExec(t, sDB, "INSERT INTO test VALUES (?)", tests[i][j].in) + // Local File + RegisterLocalFile(file.Name()) + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name())) + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("Load non-existent file didn't fail") + } else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" { + dbt.Fatal(err.Error()) + } - // string - rows[0] = mustQuery(t, sDB, "SELECT value FROM test") // text - rows[1] = mustQuery(t, sDB, "SELECT value FROM test WHERE 1 = ?", 1) // binary + // Empty table + dbt.mustExec("TRUNCATE TABLE test") - for k := range rows { - if rows[k].Next() { - err = rows[k].Scan(&sOut) - if err != nil { - t.Errorf("%s (string %s): %v", v, modes[k], err) - } else if tests[i][j].sOut != sOut { - t.Errorf("%s (string %s): %s != %s", v, modes[k], tests[i][j].sOut, sOut) - } - } else { - err = rows[k].Err() - if err != nil { - t.Errorf("%s (string %s): %v", v, modes[k], err) - } else { - t.Errorf("%s (string %s): no data", v, modes[k]) - } - } - } - - // time.Time - rows[0] = mustQuery(t, tDB, "SELECT value FROM test") // text - rows[1] = mustQuery(t, tDB, "SELECT value FROM test WHERE 1 = ?", 1) // binary - - for k := range rows { - if rows[k].Next() { - err = rows[k].Scan(&tOut) - if err != nil { - t.Errorf("%s (time.Time %s): %v", v, modes[k], err) - } else if tests[i][j].tOut != tOut || tests[i][j].tIsZero != tOut.IsZero() { - t.Errorf("%s (time.Time %s): %s [%t] != %s [%t]", v, modes[k], tests[i][j].tOut, tests[i][j].tIsZero, tOut, tOut.IsZero()) - } - } else { - err = rows[k].Err() - if err != nil { - t.Errorf("%s (time.Time %s): %v", v, modes[k], err) - } else { - t.Errorf("%s (time.Time %s): no data", v, modes[k]) - } - - } + // Reader + RegisterReaderHandler("test", func() io.Reader { + file, err = os.Open(file.Name()) + if err != nil { + dbt.Fatal(err) } - - mustExec(t, sDB, "TRUNCATE TABLE test") + return file + }) + dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("Load non-existent Reader didn't fail") + } else if err.Error() != "Reader 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) } - - mustExec(t, sDB, "DROP TABLE IF EXISTS test") - } + }) } -func TestNULL(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestNULL", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - - defer db.Close() - - nullStmt, err := db.Prepare("SELECT NULL") - if err != nil { - t.Fatal(err) - } - defer nullStmt.Close() - - nonNullStmt, err := db.Prepare("SELECT 1") - if err != nil { - t.Fatal(err) - } - defer nonNullStmt.Close() - - // NullBool - var nb sql.NullBool - // Invalid - err = nullStmt.QueryRow().Scan(&nb) - if err != nil { - t.Fatal(err) - } - if nb.Valid { - t.Error("Valid NullBool which should be invalid") - } - // Valid - err = nonNullStmt.QueryRow().Scan(&nb) - if err != nil { - t.Fatal(err) - } - if !nb.Valid { - t.Error("Invalid NullBool which should be valid") - } else if nb.Bool != true { - t.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) - } - - // NullFloat64 - var nf sql.NullFloat64 - // Invalid - err = nullStmt.QueryRow().Scan(&nf) - if err != nil { - t.Fatal(err) - } - if nf.Valid { - t.Error("Valid NullFloat64 which should be invalid") - } - // Valid - err = nonNullStmt.QueryRow().Scan(&nf) - if err != nil { - t.Fatal(err) - } - if !nf.Valid { - t.Error("Invalid NullFloat64 which should be valid") - } else if nf.Float64 != float64(1) { - t.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) - } - - // NullInt64 - var ni sql.NullInt64 - // Invalid - err = nullStmt.QueryRow().Scan(&ni) - if err != nil { - t.Fatal(err) - } - if ni.Valid { - t.Error("Valid NullInt64 which should be invalid") - } - // Valid - err = nonNullStmt.QueryRow().Scan(&ni) - if err != nil { - t.Fatal(err) - } - if !ni.Valid { - t.Error("Invalid NullInt64 which should be valid") - } else if ni.Int64 != int64(1) { - t.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64) - } - - // NullString - var ns sql.NullString - // Invalid - err = nullStmt.QueryRow().Scan(&ns) - if err != nil { - t.Fatal(err) - } - if ns.Valid { - t.Error("Valid NullString which should be invalid") - } - // Valid - err = nonNullStmt.QueryRow().Scan(&ns) - if err != nil { - t.Fatal(err) - } - if !ns.Valid { - t.Error("Invalid NullString which should be valid") - } else if ns.String != `1` { - t.Error("Unexpected NullString value:" + ns.String + " (should be `1`)") - } - - // Insert NULL - mustExec(t, db, "CREATE TABLE test (dummmy1 int, value int, dummy2 int)") - - mustExec(t, db, ("INSERT INTO test VALUES (?, ?, ?)"), 1, nil, 2) +// Special cases - var out interface{} - rows := mustQuery(t, db, ("SELECT * FROM test")) - if rows.Next() { - rows.Scan(&out) - if out != nil { - t.Errorf("%v != nil", out) +func TestRowsClose(t *testing.T) { + runTests(t, "TestRowsClose", func(dbt *DBTest) { + rows, err := dbt.db.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) } - } else { - t.Error("no data") - } - mustExec(t, db, "DROP TABLE IF EXISTS test") -} - -func TestLongData(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestLongData", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - defer db.Close() - - var maxAllowedPacketSize int - err = db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) - if err != nil { - t.Fatal(err) - } - maxAllowedPacketSize-- - - // don't get too ambitious - if maxAllowedPacketSize > 1<<25 { - maxAllowedPacketSize = 1 << 25 - } - - mustExec(t, db, "DROP TABLE IF EXISTS test") - mustExec(t, db, "CREATE TABLE test (value LONGBLOB) CHARACTER SET utf8 COLLATE utf8_unicode_ci") - - in := strings.Repeat(`0`, maxAllowedPacketSize+1) - var out string - var rows *sql.Rows - - // Long text data - const nonDataQueryLen = 28 // length query w/o value - inS := in[:maxAllowedPacketSize-nonDataQueryLen] - mustExec(t, db, "INSERT INTO test VALUES('"+inS+"')") - rows = mustQuery(t, db, "SELECT value FROM test") - if rows.Next() { - rows.Scan(&out) - if inS != out { - t.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) - } - if rows.Next() { - t.Error("LONGBLOB: unexpexted row") + err = rows.Close() + if err != nil { + dbt.Fatal(err) } - } else { - t.Fatalf("LONGBLOB: no data") - } - - // Empty table - mustExec(t, db, "TRUNCATE TABLE test") - // Long binary data - mustExec(t, db, "INSERT INTO test VALUES(?)", in) - rows = mustQuery(t, db, "SELECT value FROM test WHERE 1=?", 1) - if rows.Next() { - rows.Scan(&out) - if in != out { - t.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) - } if rows.Next() { - t.Error("LONGBLOB: unexpexted row") - } - } else { - t.Fatalf("LONGBLOB: no data") - } - - mustExec(t, db, "DROP TABLE IF EXISTS test") -} - -func verifyLoadDataResult(t *testing.T, db *sql.DB) { - rows, err := db.Query("SELECT * FROM test") - if err != nil { - t.Fatal(err.Error()) - } - - i := 0 - values := [4]string{ - "a string", - "a string containing a \t", - "a string containing a \n", - "a string containing both \t\n", - } - - var id int - var value string - - for rows.Next() { - i++ - err = rows.Scan(&id, &value) - if err != nil { - t.Fatal(err.Error()) - } - if i != id { - t.Fatalf("%d != %d", i, id) - } - if values[i-1] != value { - t.Fatalf("%s != %s", values[i-1], value) + dbt.Fatal("Unexpected row after rows.Close()") } - } - err = rows.Err() - if err != nil { - t.Fatal(err.Error()) - } - - if i != 4 { - t.Fatalf("Rows count mismatch. Got %d, want 4", i) - } -} -func TestLoadData(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestLoadData", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - defer db.Close() - - file, err := ioutil.TempFile("", "gotest") - defer os.Remove(file.Name()) - if err != nil { - t.Fatal(err) - } - file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") - file.Close() - - mustExec(t, db, "DROP TABLE IF EXISTS test") - mustExec(t, db, "CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8 COLLATE utf8_unicode_ci") - - // Local File - RegisterLocalFile(file.Name()) - mustExec(t, db, fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name())) - verifyLoadDataResult(t, db) - // negative test - _, err = db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") - if err == nil { - t.Fatal("Load non-existent file didn't fail") - } else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" { - t.Fatal(err.Error()) - } - - // Empty table - mustExec(t, db, "TRUNCATE TABLE test") - - // Reader - RegisterReaderHandler("test", func() io.Reader { - file, err = os.Open(file.Name()) + err = rows.Err() if err != nil { - t.Fatal(err) + dbt.Fatal(err) } - return file }) - mustExec(t, db, "LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") - verifyLoadDataResult(t, db) - // negative test - _, err = db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") - if err == nil { - t.Fatal("Load non-existent Reader didn't fail") - } else if err.Error() != "Reader 'doesnotexist' is not registered" { - t.Fatal(err.Error()) - } - - mustExec(t, db, "DROP TABLE IF EXISTS test") -} - -// Special cases - -func TestRowsClose(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestRowsClose", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - - defer db.Close() - - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - - err = rows.Close() - if err != nil { - t.Fatal(err) - } - - if rows.Next() { - t.Fatal("Unexpected row after rows.Close()") - } - - err = rows.Err() - if err != nil { - t.Fatal(err) - } } // dangling statements // http://code.google.com/p/go/issues/detail?id=3865 func TestCloseStmtBeforeRows(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestCloseStmtBeforeRows", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - - defer db.Close() - - stmt, err := db.Prepare("SELECT 1") - if err != nil { - t.Fatal(err) - } - - rows, err := stmt.Query() - if err != nil { - stmt.Close() - t.Fatal(err) - } - defer rows.Close() - - err = stmt.Close() - if err != nil { - t.Fatal(err) - } + runTests(t, "TestCloseStmtBeforeRows", func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } - if !rows.Next() { - t.Fatal("Getting row failed") - } else { - err = rows.Err() + rows, err := stmt.Query() if err != nil { - t.Fatal(err) + stmt.Close() + dbt.Fatal(err) } + defer rows.Close() - var out bool - err = rows.Scan(&out) + err = stmt.Close() if err != nil { - t.Fatalf("Error on rows.Scan(): %v", err) + dbt.Fatal(err) } - if out != true { - t.Errorf("true != %t", out) + + if !rows.Next() { + dbt.Fatal("Getting row failed") + } else { + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + + var out bool + err = rows.Scan(&out) + if err != nil { + dbt.Fatalf("Error on rows.Scan(): %v", err) + } + if out != true { + dbt.Errorf("true != %t", out) + } } - } + }) } // It is valid to have multiple Rows for the same Stmt // http://code.google.com/p/go/issues/detail?id=3734 func TestStmtMultiRows(t *testing.T) { - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestStmtMultiRows", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - - defer db.Close() - - stmt, err := db.Prepare("SELECT 1 UNION SELECT 0") - if err != nil { - t.Fatal(err) - } - - rows1, err := stmt.Query() - if err != nil { - stmt.Close() - t.Fatal(err) - } - defer rows1.Close() - - rows2, err := stmt.Query() - if err != nil { - stmt.Close() - t.Fatal(err) - } - defer rows2.Close() - - var out bool - - // 1 - if !rows1.Next() { - t.Fatal("1st rows1.Next failed") - } else { - err = rows1.Err() + runTests(t, "TestStmtMultiRows", func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") if err != nil { - t.Fatal(err) + dbt.Fatal(err) } - err = rows1.Scan(&out) + rows1, err := stmt.Query() if err != nil { - t.Fatalf("Error on rows.Scan(): %v", err) - } - if out != true { - t.Errorf("true != %t", out) + stmt.Close() + dbt.Fatal(err) } - } + defer rows1.Close() - if !rows2.Next() { - t.Fatal("1st rows2.Next failed") - } else { - err = rows2.Err() + rows2, err := stmt.Query() if err != nil { - t.Fatal(err) + stmt.Close() + dbt.Fatal(err) } + defer rows2.Close() - err = rows2.Scan(&out) - if err != nil { - t.Fatalf("Error on rows.Scan(): %v", err) - } - if out != true { - t.Errorf("true != %t", out) - } - } + var out bool - // 2 - if !rows1.Next() { - t.Fatal("2nd rows1.Next failed") - } else { - err = rows1.Err() - if err != nil { - t.Fatal(err) - } + // 1 + if !rows1.Next() { + dbt.Fatal("1st rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } - err = rows1.Scan(&out) - if err != nil { - t.Fatalf("Error on rows.Scan(): %v", err) - } - if out != false { - t.Errorf("false != %t", out) + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("Error on rows.Scan(): %v", err) + } + if out != true { + dbt.Errorf("true != %t", out) + } } - if rows1.Next() { - t.Fatal("Unexpected row on rows1") - } - err = rows1.Close() - if err != nil { - t.Fatal(err) - } - } + if !rows2.Next() { + dbt.Fatal("1st rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } - if !rows2.Next() { - t.Fatal("2nd rows2.Next failed") - } else { - err = rows2.Err() - if err != nil { - t.Fatal(err) + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("Error on rows.Scan(): %v", err) + } + if out != true { + dbt.Errorf("true != %t", out) + } } - err = rows2.Scan(&out) - if err != nil { - t.Fatalf("Error on rows.Scan(): %v", err) - } - if out != false { - t.Errorf("false != %t", out) - } + // 2 + if !rows1.Next() { + dbt.Fatal("2nd rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } - if rows2.Next() { - t.Fatal("Unexpected row on rows2") - } - err = rows2.Close() - if err != nil { - t.Fatal(err) + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("Error on rows.Scan(): %v", err) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows1.Next() { + dbt.Fatal("Unexpected row on rows1") + } + err = rows1.Close() + if err != nil { + dbt.Fatal(err) + } } - } + if !rows2.Next() { + dbt.Fatal("2nd rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("Error on rows.Scan(): %v", err) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows2.Next() { + dbt.Fatal("Unexpected row on rows2") + } + err = rows2.Close() + if err != nil { + dbt.Fatal(err) + } + } + }) } func TestConcurrent(t *testing.T) { @@ -1019,64 +923,47 @@ func TestConcurrent(t *testing.T) { t.Log("CONCURRENT env var not set. Skipping TestConcurrent") return } - if !getEnv() { - t.Logf("MySQL-Server not running on %s. Skipping TestConcurrent", netAddr) - return - } - - db, err := sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("Error connecting: %v", err) - } - - defer db.Close() - - var max int - err = db.QueryRow("SELECT @@max_connections").Scan(&max) - if err != nil { - t.Fatalf("%v", err) - } - - t.Logf("Testing up to %d concurrent connections \r\n", max) - - canStop := false - - c := make(chan struct{}, max) - for i := 0; i < max; i++ { - go func(id int) { - tx, err := db.Begin() - if err != nil { - canStop = true - if err.Error() == "Error 1040: Too many connections" { - max-- - return - } else { - t.Fatalf("Error on Con %d: %s", id, err.Error()) + runTests(t, "TestConcurrent", func(dbt *DBTest) { + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%v", err) + } + dbt.Logf("Testing up to %d concurrent connections \r\n", max) + canStop := false + c := make(chan struct{}, max) + for i := 0; i < max; i++ { + go func(id int) { + tx, err := dbt.db.Begin() + if err != nil { + canStop = true + if err.Error() == "Error 1040: Too many connections" { + max-- + return + } else { + dbt.Fatalf("Error on Con %d: %s", id, err.Error()) + } } - } - - c <- struct{}{} - - for !canStop { - _, err = tx.Exec("SELECT 1") + c <- struct{}{} + for !canStop { + _, err = tx.Exec("SELECT 1") + if err != nil { + canStop = true + dbt.Fatalf("Error on Con %d: %s", id, err.Error()) + } + } + err = tx.Commit() if err != nil { canStop = true - t.Fatalf("Error on Con %d: %s", id, err.Error()) + dbt.Fatalf("Error on Con %d: %s", id, err.Error()) } - } - - err = tx.Commit() - if err != nil { - canStop = true - t.Fatalf("Error on Con %d: %s", id, err.Error()) - } - }(i) - } - - for i := 0; i < max; i++ { - <-c - } - canStop = true + }(i) + } + for i := 0; i < max; i++ { + <-c + } + canStop = true - t.Logf("Reached %d concurrent connections \r\n", max) + dbt.Logf("Reached %d concurrent connections \r\n", max) + }) }