diff --git a/auth.go b/auth.go index 658259b24..74e1bd03e 100644 --- a/auth.go +++ b/auth.go @@ -338,7 +338,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { return authEd25519(authData, mc.cfg.Passwd) default: - mc.cfg.Logger.Print("unknown auth plugin:", plugin) + mc.log("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin } } diff --git a/connection.go b/connection.go index 55e42eb18..5061b69ca 100644 --- a/connection.go +++ b/connection.go @@ -45,6 +45,11 @@ type mysqlConn struct { closed atomic.Bool // set when conn is closed, before closech is closed } +// Helper function to call per-connection logger. +func (mc *mysqlConn) log(v ...any) { + mc.cfg.Logger.Print(v...) +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -110,7 +115,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -152,7 +157,7 @@ func (mc *mysqlConn) cleanup() { return } if err := mc.netConn.Close(); err != nil { - mc.cfg.Logger.Print(err) + mc.log(err) } mc.clearResult() } @@ -169,14 +174,14 @@ func (mc *mysqlConn) error() error { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. - mc.cfg.Logger.Print(err) + mc.log(err) return nil, driver.ErrBadConn } @@ -210,7 +215,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf, err := mc.buf.takeCompleteBuffer() if err != nil { // can not take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return "", ErrInvalidConn } buf = buf[:0] @@ -302,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -362,7 +367,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) handleOk := mc.clearResult() if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -457,7 +462,7 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return driver.ErrBadConn } @@ -666,7 +671,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { err = connCheck(conn) } if err != nil { - mc.cfg.Logger.Print("closing bad idle connection: ", err) + mc.log("closing bad idle connection: ", err) return driver.ErrBadConn } } diff --git a/packets.go b/packets.go index 3d6e5308c..d727f00fe 100644 --- a/packets.go +++ b/packets.go @@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - mc.cfg.Logger.Print(err) + mc.log(err) mc.Close() return nil, ErrInvalidConn } @@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if pktLen == 0 { // there was no previous packet if prevData == nil { - mc.cfg.Logger.Print(ErrMalformPkt) + mc.log(ErrMalformPkt) mc.Close() return nil, ErrInvalidConn } @@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - mc.cfg.Logger.Print(err) + mc.log(err) mc.Close() return nil, ErrInvalidConn } @@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) mc.cleanup() - mc.cfg.Logger.Print(ErrMalformPkt) + mc.log(ErrMalformPkt) } else { if cerr := mc.canceled.Value(); cerr != nil { return cerr @@ -144,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { return errBadConnNoWrite } mc.cleanup() - mc.cfg.Logger.Print(err) + mc.log(err) } return ErrInvalidConn } @@ -302,7 +302,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -392,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { data, err := mc.buf.takeSmallBuffer(pktLen) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -412,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -431,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -452,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -994,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -1193,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } } diff --git a/statement.go b/statement.go index 31e7799c4..860c6588b 100644 --- a/statement.go +++ b/statement.go @@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.Load() { - stmt.mc.cfg.Logger.Print(ErrInvalidConn) + stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -95,7 +95,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if stmt.mc.closed.Load() { - stmt.mc.cfg.Logger.Print(ErrInvalidConn) + stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command