diff --git a/AUTHORS b/AUTHORS index f0b070246..d851a0477 100644 --- a/AUTHORS +++ b/AUTHORS @@ -34,3 +34,4 @@ Xiuming Chen Barracuda Networks, Inc. Google Inc. +Stripe Inc. diff --git a/utils.go b/utils.go index 56f1b082e..98dfc6f5e 100644 --- a/utils.go +++ b/utils.go @@ -16,6 +16,7 @@ import ( "errors" "fmt" "io" + "net" "net/url" "strings" "time" @@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) { if strings.ToLower(value) == "skip-verify" { cfg.tls = &tls.Config{InsecureSkipVerify: true} } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.addr) + if err == nil { + tlsConfig.ServerName = host + } + } + cfg.tls = tlsConfig } else { return fmt.Errorf("Invalid value / unknown config name: %s", value) diff --git a/utils_test.go b/utils_test.go index 6e50b09b9..0855374b7 100644 --- a/utils_test.go +++ b/utils_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "crypto/tls" "encoding/binary" "fmt" "testing" @@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) { } } +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "user:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := parseDSN(tst) + if err == nil { + t.Errorf("Invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = parseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = parseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } + + DeregisterTLSConfig("utils_test") +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs()