From 59cb281325e2ff8431d2ff5797cea116e278b3ed Mon Sep 17 00:00:00 2001 From: "Gruetzmacher, Anthony" Date: Mon, 13 Nov 2017 12:40:12 -0800 Subject: [PATCH 1/5] Implement Connector interface #671 --- AUTHORS | 1 + driver.go | 120 +++++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 97 insertions(+), 24 deletions(-) diff --git a/AUTHORS b/AUTHORS index ac36be9a7..8376ff9db 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,6 +13,7 @@ Aaron Hopkins Achille Roussel +Anthony Gruetzmacher Arne Hormann Asta Xie Bulat Gaifullin diff --git a/driver.go b/driver.go index d42ce7a3d..8a49cdbcd 100644 --- a/driver.go +++ b/driver.go @@ -17,11 +17,20 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" + "errors" "net" ) +var ( + errInvalidUser = errors.New("invalid Connection: User is not set or longer than 32 chars") + errInvalidAddr = errors.New("invalid Connection: Addr config is missing") + errInvalidNet = errors.New("invalid Connection: Only tcp is valid for Net") + errInvalidDBName = errors.New("invalid Connection: DBName config is missing") +) + // watcher interface is used for context support (From Go 1.8) type watcher interface { startWatcher() @@ -29,7 +38,9 @@ type watcher interface { // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. -type MySQLDriver struct{} +type MySQLDriver struct { + Cfg *Config +} // DialFunc is a function which can be used to establish the network connection. // Custom dial functions must be registered with RegisterDial @@ -47,24 +58,9 @@ func RegisterDial(net string, dial DialFunc) { dials[net] = dial } -// Open new Connection. -// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated -func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { +//Open a new Connection +func (d MySQLDriver) connectServer(mc *mysqlConn) error { var err error - - // New mysqlConn - mc := &mysqlConn{ - maxAllowedPacket: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - closech: make(chan struct{}), - } - mc.cfg, err = ParseDSN(dsn) - if err != nil { - return nil, err - } - mc.parseTime = mc.cfg.ParseTime - // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { mc.netConn, err = dial(mc.cfg.Addr) @@ -73,7 +69,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { - return nil, err + return err } // Enable TCP Keepalives on TCP connections @@ -82,7 +78,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // Don't send COM_QUIT before handshake. mc.netConn.Close() mc.netConn = nil - return nil, err + return err } } @@ -101,13 +97,13 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { cipher, err := mc.readInitPacket() if err != nil { mc.cleanup() - return nil, err + return err } // Send Client Authentication Packet if err = mc.writeAuthPacket(cipher); err != nil { mc.cleanup() - return nil, err + return err } // Handle response to auth packet, switch methods if possible @@ -116,7 +112,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. mc.cleanup() - return nil, err + return err } if mc.cfg.MaxAllowedPacket > 0 { @@ -126,7 +122,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { maxap, err := mc.getSystemVar("max_allowed_packet") if err != nil { mc.Close() - return nil, err + return err } mc.maxAllowedPacket = stringToInt(maxap) - 1 } @@ -134,6 +130,82 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.maxWriteSize = mc.maxAllowedPacket } + return err +} + +//Connect opens a new connection without using a DSN +func (d MySQLDriver) Connect(cxt context.Context) (driver.Conn, error) { + var err error + + //Validate the connection parameters + //the following are required User,Pass,Net,Addr,DBName + //Pass may be blank + //The other optional parameters are not checks + //as GO will automatically enforce proper bool types on the options + if len(d.Cfg.User) > 32 || len(d.Cfg.User) <= 0 { + return nil, errInvalidUser + } + + if len(d.Cfg.Addr) <= 0 { + return nil, errInvalidAddr + } + + if len(d.Cfg.DBName) <= 0 { + return nil, errInvalidDBName + } + + if d.Cfg.Net != "tcp" { + return nil, errInvalidNet + } + + //New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: d.Cfg, + parseTime: d.Cfg.ParseTime, + } + + //Connect to the server and setting the connection settings + err = d.connectServer(mc) + if err != nil { + return nil, err + } + + return mc, nil + +} + +//Driver returns a driver interface +func (d MySQLDriver) Driver() driver.Driver { + return MySQLDriver{} +} + +// Open new Connection using a DSN. +// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formated +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + } + mc.cfg, err = ParseDSN(dsn) + if err != nil { + return nil, err + } + mc.parseTime = mc.cfg.ParseTime + + err = d.connectServer(mc) + // Connect to Server + if err != nil { + return nil, err + } + // Handle DSN Params err = mc.handleParams() if err != nil { From 39714cd2983c2c0b4734747a65d71e5755203e5e Mon Sep 17 00:00:00 2001 From: "Gruetzmacher, Anthony" Date: Tue, 5 Dec 2017 17:21:54 -0800 Subject: [PATCH 2/5] Requested modifications --- driver.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/driver.go b/driver.go index 8a49cdbcd..424948401 100644 --- a/driver.go +++ b/driver.go @@ -39,6 +39,9 @@ type watcher interface { // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct { +} + +type MySQLConnector struct { Cfg *Config } @@ -59,7 +62,7 @@ func RegisterDial(net string, dial DialFunc) { } //Open a new Connection -func (d MySQLDriver) connectServer(mc *mysqlConn) error { +func connectServer(mc *mysqlConn) error { var err error // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { @@ -134,7 +137,7 @@ func (d MySQLDriver) connectServer(mc *mysqlConn) error { } //Connect opens a new connection without using a DSN -func (d MySQLDriver) Connect(cxt context.Context) (driver.Conn, error) { +func (c MySQLConnector) Connect(cxt context.Context) (driver.Conn, error) { var err error //Validate the connection parameters @@ -142,19 +145,19 @@ func (d MySQLDriver) Connect(cxt context.Context) (driver.Conn, error) { //Pass may be blank //The other optional parameters are not checks //as GO will automatically enforce proper bool types on the options - if len(d.Cfg.User) > 32 || len(d.Cfg.User) <= 0 { + if len(c.Cfg.User) > 32 || len(c.Cfg.User) <= 0 { return nil, errInvalidUser } - if len(d.Cfg.Addr) <= 0 { + if len(c.Cfg.Addr) <= 0 { return nil, errInvalidAddr } - if len(d.Cfg.DBName) <= 0 { + if len(c.Cfg.DBName) <= 0 { return nil, errInvalidDBName } - if d.Cfg.Net != "tcp" { + if c.Cfg.Net != "tcp" { return nil, errInvalidNet } @@ -163,12 +166,12 @@ func (d MySQLDriver) Connect(cxt context.Context) (driver.Conn, error) { maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), - cfg: d.Cfg, - parseTime: d.Cfg.ParseTime, + cfg: c.Cfg, + parseTime: c.Cfg.ParseTime, } //Connect to the server and setting the connection settings - err = d.connectServer(mc) + err = connectServer(mc) if err != nil { return nil, err } @@ -182,6 +185,11 @@ func (d MySQLDriver) Driver() driver.Driver { return MySQLDriver{} } +//Driver returns a driver interface +func (c MySQLConnector) Driver() driver.Driver { + return MySQLDriver{} +} + // Open new Connection using a DSN. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how // the DSN string is formated @@ -200,7 +208,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } mc.parseTime = mc.cfg.ParseTime - err = d.connectServer(mc) + err = connectServer(mc) // Connect to Server if err != nil { return nil, err From ad4a9e9d904dcc4342792c635893ce77388460b9 Mon Sep 17 00:00:00 2001 From: "Gruetzmacher, Anthony" Date: Tue, 20 Feb 2018 20:49:35 -0800 Subject: [PATCH 3/5] Added support to check the context passed into the connect function. Added reasonable testing. --- driver.go | 26 +++++++++++--- driver_go110_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 driver_go110_test.go diff --git a/driver.go b/driver.go index 424948401..a5ac25304 100644 --- a/driver.go +++ b/driver.go @@ -170,14 +170,30 @@ func (c MySQLConnector) Connect(cxt context.Context) (driver.Conn, error) { parseTime: c.Cfg.ParseTime, } - //Connect to the server and setting the connection settings - err = connectServer(mc) - if err != nil { - return nil, err + //Check if the there is a canelation before creating the connection + select { + case <-cxt.Done(): + return nil, cxt.Err() + default: + //Connect to the server and setting the connection settings + err = connectServer(mc) + if err != nil { + return nil, err + } } - return mc, nil + //Check to see if there was a canelation during creating the connection + select { + case <-cxt.Done(): + err = mc.Close() + if err != nil { + return nil, errors.New(cxt.Err().Error() + ":" + err.Error()) + } + return nil, cxt.Err() + default: + return mc, nil + } } //Driver returns a driver interface diff --git a/driver_go110_test.go b/driver_go110_test.go new file mode 100644 index 000000000..2af9b8769 --- /dev/null +++ b/driver_go110_test.go @@ -0,0 +1,82 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "sync" + "testing" +) + +type Connector struct { + m sync.Mutex + mysql *MySQLConnector +} + +func (c *Connector) Connect(cxt context.Context) (driver.Conn, error) { + var err error + + if c.mysql == nil { + c.mysql = c.init() + } + + //Just use the global DSN because we just want to test the connector + //interface and we do not care about any custom functionality in the Connector + c.m.Lock() + c.mysql.Cfg, err = ParseDSN(dsn) + c.m.Unlock() + if err != nil { + println(err) + return nil, err + } + + return c.mysql.Connect(cxt) +} + +func (c *Connector) Driver() driver.Driver { + return c.mysql.Driver() +} + +func (c *Connector) init() *MySQLConnector { + return &MySQLConnector{} +} + +func runtestsWithConnector(t *testing.T, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + connector := &Connector{} + + db := sql.OpenDB(connector) + if err := db.Ping(); err != nil { + db.Close() + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } + +} + +func TestPingWithConnector(t *testing.T) { + runtestsWithConnector(t, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping With Connector", "Ping With Connector", err) + } + }) +} From bc89eaee2eaaeee95015bb9fe157a19d4fe601ea Mon Sep 17 00:00:00 2001 From: "Gruetzmacher, Anthony" Date: Wed, 21 Feb 2018 22:13:16 -0800 Subject: [PATCH 4/5] Addresses feedback --- driver.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/driver.go b/driver.go index a5ac25304..59c5999d8 100644 --- a/driver.go +++ b/driver.go @@ -62,14 +62,18 @@ func RegisterDial(net string, dial DialFunc) { } //Open a new Connection -func connectServer(mc *mysqlConn) error { +func connectServer(cxt context.Context, mc *mysqlConn) error { var err error // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + if cxt == nil { + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + } else { + mc.netConn, err = nd.DialContext(cxt, mc.cfg.Net, mc.cfg.Addr) + } } if err != nil { return err @@ -176,22 +180,11 @@ func (c MySQLConnector) Connect(cxt context.Context) (driver.Conn, error) { return nil, cxt.Err() default: //Connect to the server and setting the connection settings - err = connectServer(mc) + err = connectServer(cxt, mc) if err != nil { return nil, err } - } - //Check to see if there was a canelation during creating the connection - select { - case <-cxt.Done(): - err = mc.Close() - if err != nil { - return nil, errors.New(cxt.Err().Error() + ":" + err.Error()) - } - - return nil, cxt.Err() - default: return mc, nil } } @@ -224,7 +217,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } mc.parseTime = mc.cfg.ParseTime - err = connectServer(mc) + err = connectServer(nil, mc) // Connect to Server if err != nil { return nil, err From 46e7d3d740f2a8fb2726fa1bc0c441954b278b85 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Thu, 22 Feb 2018 17:07:06 +0900 Subject: [PATCH 5/5] small codestyle fix --- driver.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/driver.go b/driver.go index 59c5999d8..18f77adb0 100644 --- a/driver.go +++ b/driver.go @@ -184,9 +184,8 @@ func (c MySQLConnector) Connect(cxt context.Context) (driver.Conn, error) { if err != nil { return nil, err } - - return mc, nil } + return mc, nil } //Driver returns a driver interface