-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Implement Connector interface #705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
59cb281
39714cd
ad4a9e9
bc89eae
46e7d3d
965ecef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,19 +17,33 @@ | |
package mysql | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"errors" | ||
"net" | ||
) | ||
|
||
var ( | ||
errInvalidUser = errors.New("invalid Connection: User is not set or longer than 32 chars") | ||
errInvalidAddr = errors.New("invalid Connection: Addr config is missing") | ||
errInvalidNet = errors.New("invalid Connection: Only tcp is valid for Net") | ||
errInvalidDBName = errors.New("invalid Connection: DBName config is missing") | ||
) | ||
|
||
// watcher interface is used for context support (From Go 1.8) | ||
type watcher interface { | ||
startWatcher() | ||
} | ||
|
||
// MySQLDriver is exported to make the driver directly accessible. | ||
// In general the driver is used via the database/sql package. | ||
type MySQLDriver struct{} | ||
type MySQLDriver struct { | ||
} | ||
|
||
type MySQLConnector struct { | ||
Cfg *Config | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer private type and constructor function. type mysqlConnector *Config
func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
// any additional validation here.
return mysqlConnector(cfg)
} Then, we can move validation from |
||
|
||
// DialFunc is a function which can be used to establish the network connection. | ||
// Custom dial functions must be registered with RegisterDial | ||
|
@@ -47,33 +61,22 @@ func RegisterDial(net string, dial DialFunc) { | |
dials[net] = dial | ||
} | ||
|
||
// Open new Connection. | ||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how | ||
// the DSN string is formated | ||
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | ||
//Open a new Connection | ||
func connectServer(cxt context.Context, mc *mysqlConn) error { | ||
var err error | ||
|
||
// New mysqlConn | ||
mc := &mysqlConn{ | ||
maxAllowedPacket: maxPacketSize, | ||
maxWriteSize: maxPacketSize - 1, | ||
closech: make(chan struct{}), | ||
} | ||
mc.cfg, err = ParseDSN(dsn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
mc.parseTime = mc.cfg.ParseTime | ||
|
||
// Connect to Server | ||
if dial, ok := dials[mc.cfg.Net]; ok { | ||
mc.netConn, err = dial(mc.cfg.Addr) | ||
} else { | ||
nd := net.Dialer{Timeout: mc.cfg.Timeout} | ||
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) | ||
if cxt == nil { | ||
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) | ||
} else { | ||
mc.netConn, err = nd.DialContext(cxt, mc.cfg.Net, mc.cfg.Addr) | ||
} | ||
} | ||
if err != nil { | ||
return nil, err | ||
return err | ||
} | ||
|
||
// Enable TCP Keepalives on TCP connections | ||
|
@@ -82,7 +85,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |
// Don't send COM_QUIT before handshake. | ||
mc.netConn.Close() | ||
mc.netConn = nil | ||
return nil, err | ||
return err | ||
} | ||
} | ||
|
||
|
@@ -101,13 +104,13 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |
cipher, err := mc.readInitPacket() | ||
if err != nil { | ||
mc.cleanup() | ||
return nil, err | ||
return err | ||
} | ||
|
||
// Send Client Authentication Packet | ||
if err = mc.writeAuthPacket(cipher); err != nil { | ||
mc.cleanup() | ||
return nil, err | ||
return err | ||
} | ||
|
||
// Handle response to auth packet, switch methods if possible | ||
|
@@ -116,7 +119,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html). | ||
// Do not send COM_QUIT, just cleanup and return the error. | ||
mc.cleanup() | ||
return nil, err | ||
return err | ||
} | ||
|
||
if mc.cfg.MaxAllowedPacket > 0 { | ||
|
@@ -126,14 +129,99 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |
maxap, err := mc.getSystemVar("max_allowed_packet") | ||
if err != nil { | ||
mc.Close() | ||
return nil, err | ||
return err | ||
} | ||
mc.maxAllowedPacket = stringToInt(maxap) - 1 | ||
} | ||
if mc.maxAllowedPacket < maxPacketSize { | ||
mc.maxWriteSize = mc.maxAllowedPacket | ||
} | ||
|
||
return err | ||
} | ||
|
||
//Connect opens a new connection without using a DSN | ||
func (c MySQLConnector) Connect(cxt context.Context) (driver.Conn, error) { | ||
var err error | ||
|
||
//Validate the connection parameters | ||
//the following are required User,Pass,Net,Addr,DBName | ||
//Pass may be blank | ||
//The other optional parameters are not checks | ||
//as GO will automatically enforce proper bool types on the options | ||
if len(c.Cfg.User) > 32 || len(c.Cfg.User) <= 0 { | ||
return nil, errInvalidUser | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this check really needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As far as I am aware an empty username will never authenticate. I guess it is not required press as you would just get a return from the server as a bad connection attempt however it does short cut a connection attempt thus using less resources/connections to the mysql server if the user forgets to provide a username. It also gives a very clear error message as to what is wrong. |
||
} | ||
|
||
if len(c.Cfg.Addr) <= 0 { | ||
return nil, errInvalidAddr | ||
} | ||
|
||
if len(c.Cfg.DBName) <= 0 { | ||
return nil, errInvalidDBName | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Empty DB name should be supported. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I concur. I had not thought of the SQL queries that do not require a database. I will correct. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I use it when I create monitoring tools. (e.g. |
||
|
||
if c.Cfg.Net != "tcp" { | ||
return nil, errInvalidNet | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No "unix"? |
||
|
||
//New mysqlConn | ||
mc := &mysqlConn{ | ||
maxAllowedPacket: maxPacketSize, | ||
maxWriteSize: maxPacketSize - 1, | ||
closech: make(chan struct{}), | ||
cfg: c.Cfg, | ||
parseTime: c.Cfg.ParseTime, | ||
} | ||
|
||
//Check if the there is a canelation before creating the connection | ||
select { | ||
case <-cxt.Done(): | ||
return nil, cxt.Err() | ||
default: | ||
//Connect to the server and setting the connection settings | ||
err = connectServer(cxt, mc) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
return mc, nil | ||
} | ||
|
||
//Driver returns a driver interface | ||
func (d MySQLDriver) Driver() driver.Driver { | ||
return MySQLDriver{} | ||
} | ||
|
||
//Driver returns a driver interface | ||
func (c MySQLConnector) Driver() driver.Driver { | ||
return MySQLDriver{} | ||
} | ||
|
||
// Open new Connection using a DSN. | ||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how | ||
// the DSN string is formated | ||
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | ||
var err error | ||
|
||
// New mysqlConn | ||
mc := &mysqlConn{ | ||
maxAllowedPacket: maxPacketSize, | ||
maxWriteSize: maxPacketSize - 1, | ||
closech: make(chan struct{}), | ||
} | ||
mc.cfg, err = ParseDSN(dsn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
mc.parseTime = mc.cfg.ParseTime | ||
|
||
err = connectServer(nil, mc) | ||
// Connect to Server | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Handle DSN Params | ||
err = mc.handleParams() | ||
if err != nil { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package | ||
// | ||
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this file, | ||
// You can obtain one at http://mozilla.org/MPL/2.0/. | ||
|
||
// +build go1.10 | ||
|
||
package mysql | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"database/sql/driver" | ||
"sync" | ||
"testing" | ||
) | ||
|
||
type Connector struct { | ||
m sync.Mutex | ||
mysql *MySQLConnector | ||
} | ||
|
||
func (c *Connector) Connect(cxt context.Context) (driver.Conn, error) { | ||
var err error | ||
|
||
if c.mysql == nil { | ||
c.mysql = c.init() | ||
} | ||
|
||
//Just use the global DSN because we just want to test the connector | ||
//interface and we do not care about any custom functionality in the Connector | ||
c.m.Lock() | ||
c.mysql.Cfg, err = ParseDSN(dsn) | ||
c.m.Unlock() | ||
if err != nil { | ||
println(err) | ||
return nil, err | ||
} | ||
|
||
return c.mysql.Connect(cxt) | ||
} | ||
|
||
func (c *Connector) Driver() driver.Driver { | ||
return c.mysql.Driver() | ||
} | ||
|
||
func (c *Connector) init() *MySQLConnector { | ||
return &MySQLConnector{} | ||
} | ||
|
||
func runtestsWithConnector(t *testing.T, tests ...func(dbt *DBTest)) { | ||
if !available { | ||
t.Skipf("MySQL server not running on %s", netAddr) | ||
} | ||
|
||
connector := &Connector{} | ||
|
||
db := sql.OpenDB(connector) | ||
if err := db.Ping(); err != nil { | ||
db.Close() | ||
t.Fatalf("error connecting: %s", err.Error()) | ||
} | ||
defer db.Close() | ||
|
||
dbt := &DBTest{t, db} | ||
for _, test := range tests { | ||
test(dbt) | ||
dbt.db.Exec("DROP TABLE IF EXISTS test") | ||
} | ||
|
||
} | ||
|
||
func TestPingWithConnector(t *testing.T) { | ||
runtestsWithConnector(t, func(dbt *DBTest) { | ||
if err := dbt.db.Ping(); err != nil { | ||
dbt.fail("Ping With Connector", "Ping With Connector", err) | ||
} | ||
}) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe, "invalid config" is better than "invalid Connection".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would concur and will make that change.