diff --git a/utils.go b/utils.go index c1afb54ac..9324806b2 100644 --- a/utils.go +++ b/utils.go @@ -26,6 +26,7 @@ var ( errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") ) func init() { @@ -77,8 +78,10 @@ func parseDSN(dsn string) (cfg *config, err error) { // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false for i := len(dsn) - 1; i >= 0; i-- { if dsn[i] == '/' { + foundSlash = true var j, k int // left part is empty if i <= 0 @@ -135,6 +138,10 @@ func parseDSN(dsn string) (cfg *config, err error) { } } + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + // Set default network if empty if cfg.net == "" { cfg.net = "tcp" diff --git a/utils_test.go b/utils_test.go index 088077a7c..352811bba 100644 --- a/utils_test.go +++ b/utils_test.go @@ -9,9 +9,9 @@ package mysql import ( + "bytes" "fmt" "testing" - "bytes" "time" ) @@ -57,11 +57,12 @@ func TestDSNParser(t *testing.T) { func TestDSNParserInvalid(t *testing.T) { var invalidDSNs = []string{ - "@net(addr/", // no closing brace - "@tcp(/", // no closing brace - "tcp(/", // no closing brace - "(/", // no closing brace - "net(addr)//", // unescaped + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "user:pass@tcp(1.2.3.4:3306)", // no trailing slash //"/dbname?arg=/some/unescaped/path", } @@ -126,8 +127,8 @@ func TestScanNullTime(t *testing.T) { func TestLengthEncodedInteger(t *testing.T) { var integerTests = []struct { - num uint64 - encoded []byte + num uint64 + encoded []byte }{ {0x0000000000000000, []byte{0x00}}, {0x0000000000000012, []byte{0x12}}, @@ -155,10 +156,9 @@ func TestLengthEncodedInteger(t *testing.T) { t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen) } encoded := appendLengthEncodedInteger(nil, num) - if (!bytes.Equal(encoded, tst.encoded)) { + if !bytes.Equal(encoded, tst.encoded) { t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded) } } - }