diff --git a/client/auth.go b/client/auth.go index 952accec5..b7a52a0e9 100644 --- a/client/auth.go +++ b/client/auth.go @@ -154,9 +154,9 @@ func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { // password hashing switch c.authPluginName { case mysql.AUTH_NATIVE_PASSWORD: - return mysql.CalcPassword(authData[:20], []byte(c.password)), false, nil + return mysql.CalcNativePassword(authData[:20], []byte(c.password)), false, nil case mysql.AUTH_CACHING_SHA2_PASSWORD: - return mysql.CalcCachingSha2Password(authData, c.password), false, nil + return mysql.CalcCachingSha2Password(authData, []byte(c.password)), false, nil case mysql.AUTH_CLEAR_PASSWORD: return []byte(c.password), true, nil case mysql.AUTH_SHA256_PASSWORD: diff --git a/mysql/util.go b/mysql/util.go index 2e426262f..961998630 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -9,7 +9,9 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" + "crypto/subtle" "encoding/binary" + "encoding/hex" "fmt" "io" mrand "math/rand" @@ -29,7 +31,7 @@ func Pstack() string { return string(buf[0:n]) } -func CalcPassword(scramble, password []byte) []byte { +func CalcNativePassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } @@ -39,27 +41,93 @@ func CalcPassword(scramble, password []byte) []byte { crypt.Write(password) stage1 := crypt.Sum(nil) - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash + // stage2Hash = SHA1(stage1Hash) crypt.Reset() crypt.Write(stage1) - hash := crypt.Sum(nil) + stage2 := crypt.Sum(nil) - // outer Hash + // scrambleHash = SHA1(scramble + stage2Hash) crypt.Reset() crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) + crypt.Write(stage2) + scrambleHash := crypt.Sum(nil) // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] + return Xor(scrambleHash, stage1) +} + +// Xor hash1 modified in-place with XOR against hash2 +func Xor(hash1 []byte, hash2 []byte) []byte { + for i := range hash1 { + hash1[i] ^= hash2[i] + } + return hash1 +} + +// hash_stage1 = xor(reply, sha1(public_seed, hash_stage2)) +func stage1FromReply(scramble []byte, seed []byte, stage2 []byte) []byte { + crypt := sha1.New() + crypt.Write(seed) + crypt.Write(stage2) + seededHash := crypt.Sum(nil) + + return Xor(scramble, seededHash) +} + +// FROM vitess.io/vitess/go/mysql/auth_server.go +// DecodePasswordHex decodes the standard format used by MySQL +// Password hashes in the 4.1 format always begin with a * character +// see https://dev.mysql.com/doc/mysql-security-excerpt/5.7/en/password-hashing.html +func DecodePasswordHex(hexEncodedPassword string) ([]byte, error) { + if hexEncodedPassword[0] == '*' { + hexEncodedPassword = hexEncodedPassword[1:] + } + return hex.DecodeString(hexEncodedPassword) +} + +// EncodePasswordHex encodes to the standard format used by MySQL +// adds the optionally leading * to the hashed password +func EncodePasswordHex(passwordHash []byte) string { + hexstr := strings.ToUpper(hex.EncodeToString(passwordHash)) + return "*" + hexstr +} + +// NativePasswordHash = sha1(sha1(password)) +func NativePasswordHash(password []byte) []byte { + if len(password) == 0 { + return nil } - return scramble + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write(password) + stage1 := crypt.Sum(nil) + + // stage2Hash = SHA1(stage1Hash) + crypt.Reset() + crypt.Write(stage1) + return crypt.Sum(nil) +} + +func CompareNativePassword(reply []byte, stored []byte, seed []byte) bool { + if len(stored) == 0 { + return false + } + + // hash_stage1 = xor(reply, sha1(public_seed, hash_stage2)) + stage1 := stage1FromReply(reply, seed, stored) + // andidate_hash2 = sha1(hash_stage1) + crypt := sha1.New() + crypt.Write(stage1) + stage2 := crypt.Sum(nil) + + // check(candidate_hash2 == hash_stage2) + // use ConstantTimeCompare to mitigate timing based attacks + return subtle.ConstantTimeCompare(stage2, stored) == 1 } // CalcCachingSha2Password: Hash password using MySQL 8+ method (SHA256) -func CalcCachingSha2Password(scramble []byte, password string) []byte { +func CalcCachingSha2Password(scramble []byte, password []byte) []byte { if len(password) == 0 { return nil } @@ -67,7 +135,7 @@ func CalcCachingSha2Password(scramble []byte, password string) []byte { // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) crypt := sha256.New() - crypt.Write([]byte(password)) + crypt.Write(password) message1 := crypt.Sum(nil) crypt.Reset() @@ -79,11 +147,7 @@ func CalcCachingSha2Password(scramble []byte, password string) []byte { crypt.Write(scramble) message2 := crypt.Sum(nil) - for i := range message1 { - message1[i] ^= message2[i] - } - - return message1 + return Xor(message1, message2) } // Taken from https://github.com/go-sql-driver/mysql/pull/1518 @@ -135,6 +199,89 @@ func EncryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1v, rand.Reader, pub, plain, nil) } +const ( + SALT_LENGTH = 16 + ITERATION_MULTIPLIER = 1000 + SHA256_PASSWORD_ITERATIONS = 5 +) + +// generateUserSalt generate salt of given length for sha256_password hash +func generateUserSalt(length int) ([]byte, error) { + // Generate a random salt of the given length + // Implement this function for your project + salt := make([]byte, length) + _, err := rand.Read(salt) + if err != nil { + return []byte(""), err + } + + // Restrict to 7-bit to avoid multi-byte UTF-8 + for i := range salt { + salt[i] = salt[i] &^ 128 + for salt[i] == 36 || salt[i] == 0 { // '$' or NUL + newval := make([]byte, 1) + _, err := rand.Read(newval) + if err != nil { + return []byte(""), err + } + salt[i] = newval[0] &^ 128 + } + } + return salt, nil +} + +// hashCrypt256 salt and hash a password the given number of iterations +func hashCrypt256(source, salt string, iterations uint64) (string, error) { + actualIterations := iterations * ITERATION_MULTIPLIER + hashInput := []byte(source + salt) + var hash [32]byte + for i := uint64(0); i < actualIterations; i++ { + hash = sha256.Sum256(hashInput) + hashInput = hash[:] + } + + hashHex := hex.EncodeToString(hash[:]) + digest := fmt.Sprintf("$%d$%s$%s", iterations, salt, hashHex) + return digest, nil +} + +// Check256HashingPassword compares a password to a hash for sha256_password +// rather than trying to recreate just the hash we recreate the full hash +// and use that for comparison +func Check256HashingPassword(pwhash []byte, password string) (bool, error) { + pwHashParts := bytes.Split(pwhash, []byte("$")) + if len(pwHashParts) != 4 { + return false, errors.New("failed to decode hash parts") + } + + iterationsPart := pwHashParts[1] + if len(iterationsPart) == 0 { + return false, errors.New("iterations part is empty") + } + + iterations, err := strconv.ParseUint(string(iterationsPart), 10, 64) + if err != nil { + return false, errors.New("failed to decode iterations") + } + salt := pwHashParts[2][:SALT_LENGTH] + + newHash, err := hashCrypt256(password, string(salt), iterations) + if err != nil { + return false, err + } + + return subtle.ConstantTimeCompare(pwhash, []byte(newHash)) == 1, nil +} + +// NewSha256PasswordHash creates a new password hash for sha256_password +func NewSha256PasswordHash(pwd string) (string, error) { + salt, err := generateUserSalt(SALT_LENGTH) + if err != nil { + return "", err + } + return hashCrypt256(pwd, string(salt), SHA256_PASSWORD_ITERATIONS) +} + func DecompressMariadbData(data []byte) ([]byte, error) { // algorithm always 0=zlib // algorithm := (data[pos] & 0x07) >> 4 diff --git a/server/auth.go b/server/auth.go index 9d1f8101c..13402dd45 100644 --- a/server/auth.go +++ b/server/auth.go @@ -1,11 +1,11 @@ package server import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/subtle" "crypto/tls" "fmt" @@ -19,26 +19,32 @@ var ( ) func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error { - switch authPluginName { - case mysql.AUTH_NATIVE_PASSWORD: - if err := c.acquirePassword(); err != nil { + if authPluginName != c.credential.AuthPluginName { + err := c.writeAuthSwitchRequest(c.credential.AuthPluginName) + if err != nil { return err } - return c.compareNativePasswordAuthData(clientAuthData, c.password) + return c.handleAuthSwitchResponse() + } + + switch authPluginName { + case mysql.AUTH_NATIVE_PASSWORD: + return c.compareNativePasswordAuthData(clientAuthData, c.credential) case mysql.AUTH_CACHING_SHA2_PASSWORD: - if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { - return err - } - if c.cachingSha2FullAuth { - return c.handleAuthSwitchResponse() + if !c.cachingSha2FullAuth { + if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { + return err + } + if c.cachingSha2FullAuth { + return c.handleAuthSwitchResponse() + } + return nil } - return nil + // AuthMoreData packet already sent, do full auth + return c.handleCachingSha2PasswordFullAuth(clientAuthData) case mysql.AUTH_SHA256_PASSWORD: - if err := c.acquirePassword(); err != nil { - return err - } cont, err := c.handlePublicKeyRetrieval(clientAuthData) if err != nil { return err @@ -46,7 +52,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err if !cont { return nil } - return c.compareSha256PasswordAuthData(clientAuthData, c.password) + return c.compareSha256PasswordAuthData(clientAuthData, c.credential) default: return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) @@ -54,19 +60,22 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err } func (c *Conn) acquirePassword() error { - password, found, err := c.credentialProvider.GetCredential(c.user) + if c.credential.Password != "" { + return nil + } + credential, found, err := c.credentialProvider.GetCredential(c.user) if err != nil { return err } if !found { return mysql.NewDefaultError(mysql.ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) } - c.password = password + c.credential = credential return nil } -func errAccessDenied(password string) error { - if password == "" { +func errAccessDenied(credential Credential) error { + if credential.Password == "" { return ErrAccessDeniedNoPassword } @@ -90,20 +99,24 @@ func scrambleValidation(cached, nonce, scramble []byte) bool { crypt.Reset() crypt.Write(message2) m := crypt.Sum(nil) - return bytes.Equal(m, cached) + return subtle.ConstantTimeCompare(m, cached) == 1 } -func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error { - if bytes.Equal(mysql.CalcPassword(c.salt, []byte(password)), clientAuthData) { +func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error { + password, err := mysql.DecodePasswordHex(c.credential.Password) + if err != nil { + return errAccessDenied(credential) + } + if mysql.CompareNativePassword(clientAuthData, password, c.salt) { return nil } - return errAccessDenied(password) + return errAccessDenied(credential) } -func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password string) error { +func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error { // Empty passwords are not hashed, but sent as empty string if len(clientAuthData) == 0 { - if password == "" { + if credential.Password == "" { return nil } return ErrAccessDenied @@ -117,10 +130,6 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str if l := len(clientAuthData); l != 0 && clientAuthData[l-1] == 0x00 { clientAuthData = clientAuthData[:l-1] } - if bytes.Equal(clientAuthData, []byte(password)) { - return nil - } - return errAccessDenied(password) } else { // client should send encrypted password // decrypt @@ -128,46 +137,31 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, password str if err != nil { return err } - plain := make([]byte, len(password)+1) - copy(plain, password) - for i := range plain { - j := i % len(c.salt) - plain[i] ^= c.salt[j] - } - if bytes.Equal(plain, dbytes) { - return nil + clientAuthData = mysql.Xor(dbytes, c.salt) + if l := len(clientAuthData); l != 0 && clientAuthData[l-1] == 0x00 { + clientAuthData = clientAuthData[:l-1] } - return errAccessDenied(password) } + check, err := mysql.Check256HashingPassword([]byte(credential.Password), string(clientAuthData)) + if err != nil { + return err + } + if check { + return nil + } + return ErrAccessDenied } func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { // Empty passwords are not hashed, but sent as empty string if len(clientAuthData) == 0 { - if err := c.acquirePassword(); err != nil { - return err - } - if c.password == "" { + if c.credential.Password == "" { return nil } return ErrAccessDenied } // the caching of 'caching_sha2_password' in MySQL, see: https://dev.mysql.com/worklog/task/?id=9591 - if _, ok := c.credentialProvider.(*InMemoryProvider); ok { - // since we have already kept the password in memory and calculate the scramble is not that high of cost, we eliminate - // the caching part. So our server will never ask the client to do a full authentication via RSA key exchange and it appears - // like the auth will always hit the cache. - if err := c.acquirePassword(); err != nil { - return err - } - if bytes.Equal(mysql.CalcCachingSha2Password(c.salt, c.password), clientAuthData) { - // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 - return c.writeAuthMoreDataFastAuth() - } - - return errAccessDenied(c.password) - } - // other type of credential provider, we use the cache + // check if we have a cached value cached, ok := c.serverConf.cacheShaPassword.Load(fmt.Sprintf("%s@%s", c.user, c.LocalAddr())) if ok { // Scramble validation @@ -176,7 +170,7 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { return c.writeAuthMoreDataFastAuth() } - return errAccessDenied(c.password) + return errAccessDenied(c.credential) } // cache miss, do full auth if err := c.writeAuthMoreDataFullAuth(); err != nil { diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go index 9baccc03e..5de841acb 100644 --- a/server/auth_switch_response.go +++ b/server/auth_switch_response.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -9,6 +8,8 @@ import ( "crypto/tls" "fmt" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) @@ -19,48 +20,7 @@ func (c *Conn) handleAuthSwitchResponse() error { return err } - switch c.authPluginName { - case mysql.AUTH_NATIVE_PASSWORD: - if err := c.acquirePassword(); err != nil { - return err - } - return c.compareNativePasswordAuthData(authData, c.password) - - case mysql.AUTH_CACHING_SHA2_PASSWORD: - if !c.cachingSha2FullAuth { - // Switched auth method but no MoreData packet send yet - if err := c.compareCacheSha2PasswordAuthData(authData); err != nil { - return err - } else { - if c.cachingSha2FullAuth { - return c.handleAuthSwitchResponse() - } - return nil - } - } - // AuthMoreData packet already sent, do full auth - if err := c.handleCachingSha2PasswordFullAuth(authData); err != nil { - return err - } - c.writeCachingSha2Cache() - return nil - - case mysql.AUTH_SHA256_PASSWORD: - cont, err := c.handlePublicKeyRetrieval(authData) - if err != nil { - return err - } - if !cont { - return nil - } - if err := c.acquirePassword(); err != nil { - return err - } - return c.compareSha256PasswordAuthData(authData, c.password) - - default: - return errors.Errorf("unknown authentication plugin name '%s'", c.authPluginName) - } + return c.compareAuthData(c.authPluginName, authData) } func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { @@ -76,10 +36,6 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { if l := len(authData); l != 0 && authData[l-1] == 0x00 { authData = authData[:l-1] } - if bytes.Equal(authData, []byte(c.password)) { - return nil - } - return errAccessDenied(c.password) } else { // client either request for the public key or send the encrypted password if len(authData) == 1 && authData[0] == 0x02 { @@ -99,27 +55,41 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { if err != nil { return err } - plain := make([]byte, len(c.password)+1) - copy(plain, c.password) - for i := range plain { - j := i % len(c.salt) - plain[i] ^= c.salt[j] - } - if bytes.Equal(plain, dbytes) { - return nil + authData = mysql.Xor(dbytes, c.salt) + if l := len(authData); l != 0 && authData[l-1] == 0x00 { + authData = authData[:l-1] } - return errAccessDenied(c.password) } + err := c.checkSha2CacheCredentials(authData, c.credential) + if err != nil { + return err + } + // write cache on successful auth - needs to be here as we have the decrypted password + // and we need to store an unsalted hashed version of the plaintext password in the cache + c.writeCachingSha2Cache(authData) + return nil +} + +func (c *Conn) checkSha2CacheCredentials(clientAuthData []byte, credential Credential) error { + match, err := auth.CheckHashingPassword([]byte(credential.Password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD) + if match && err == nil { + return nil + } + return errAccessDenied(credential) } -func (c *Conn) writeCachingSha2Cache() { +func (c *Conn) writeCachingSha2Cache(authData []byte) { // write cache - if c.password == "" { + if authData == nil { return } + + if l := len(authData); l != 0 && authData[l-1] == 0x00 { + authData = authData[:l-1] + } // SHA256(PASSWORD) crypt := sha256.New() - crypt.Write([]byte(c.password)) + crypt.Write(authData) m1 := crypt.Sum(nil) // SHA256(SHA256(PASSWORD)) crypt.Reset() diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go index 9c9530811..9d53a7105 100644 --- a/server/caching_sha2_cache_test.go +++ b/server/caching_sha2_cache_test.go @@ -58,7 +58,7 @@ type RemoteThrottleProvider struct { getCredCallCount atomic.Int64 } -func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { +func (m *RemoteThrottleProvider) GetCredential(username string) (credential Credential, found bool, err error) { m.getCredCallCount.Add(1) return m.InMemoryProvider.GetCredential(username) } @@ -107,7 +107,7 @@ func (s *cacheTestSuite) onAccept() { func (s *cacheTestSuite) onConn(conn net.Conn) { // co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) - co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testCacheHandler{s}) + co, err := s.server.NewCustomizedConn(conn, s.credProvider, &testCacheHandler{s}) require.NoError(s.T(), err) for { err = co.HandleCommand() @@ -147,7 +147,7 @@ func (s *cacheTestSuite) TestCache() { s.db.SetMaxIdleConns(4) s.runSelect() got = s.credProvider.(*RemoteThrottleProvider).getCredCallCount.Load() - require.Equal(s.T(), int64(1), got) + require.Equal(s.T(), int64(2), got) if s.db != nil { s.db.Close() diff --git a/server/conn.go b/server/conn.go index ad3509aea..1f99bdfa2 100644 --- a/server/conn.go +++ b/server/conn.go @@ -26,7 +26,7 @@ type Conn struct { credentialProvider CredentialProvider user string - password string + credential Credential cachingSha2FullAuth bool h Handler diff --git a/server/credential_provider.go b/server/credential_provider.go index 11014d916..552b62b9a 100644 --- a/server/credential_provider.go +++ b/server/credential_provider.go @@ -1,6 +1,12 @@ package server -import "sync" +import ( + "sync" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/auth" +) // interface for user credential provider // hint: can be extended for more functionality @@ -13,18 +19,62 @@ type CredentialProvider interface { // check if the user exists CheckUsername(username string) (bool, error) // get user credential - GetCredential(username string) (password string, found bool, err error) + GetCredential(username string) (credential Credential, found bool, err error) } -func NewInMemoryProvider() *InMemoryProvider { +func NewInMemoryProvider(defaultAuthMethod ...string) *InMemoryProvider { + d := mysql.AUTH_CACHING_SHA2_PASSWORD + if len(defaultAuthMethod) > 0 { + d = defaultAuthMethod[0] + } return &InMemoryProvider{ - userPool: sync.Map{}, + userPool: sync.Map{}, + defaultAuthMethod: d, + } +} + +type Credential struct { + Password string + AuthPluginName string +} + +func NewCredential(password string, authPluginName string) (Credential, error) { + c := Credential{ + AuthPluginName: authPluginName, + } + + if password == "" { + c.Password = "" + return c, nil + } + + switch c.AuthPluginName { + case mysql.AUTH_NATIVE_PASSWORD: + c.Password = mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password))) + + case mysql.AUTH_CACHING_SHA2_PASSWORD: + c.Password = auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD) + + case mysql.AUTH_SHA256_PASSWORD: + hash, err := mysql.NewSha256PasswordHash(password) + if err != nil { + return c, err + } + c.Password = hash + + case mysql.AUTH_CLEAR_PASSWORD: + c.Password = password + + default: + return c, errors.Errorf("unknown authentication plugin name '%s'", c.AuthPluginName) } + return c, nil } -// implements a in memory credential provider +// implements an in memory credential provider type InMemoryProvider struct { - userPool sync.Map // username -> password + userPool sync.Map // username -> password + defaultAuthMethod string } func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error) { @@ -32,16 +82,31 @@ func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error return ok, nil } -func (m *InMemoryProvider) GetCredential(username string) (password string, found bool, err error) { +func (m *InMemoryProvider) GetCredential(username string) (credential Credential, found bool, err error) { v, ok := m.userPool.Load(username) if !ok { - return "", false, nil + return Credential{}, false, nil } - return v.(string), true, nil + c, valid := v.(Credential) + if !valid { + return Credential{}, true, errors.Errorf("invalid credential") + } + return c, true, nil } -func (m *InMemoryProvider) AddUser(username, password string) { - m.userPool.Store(username, password) +func (m *InMemoryProvider) AddUser(username, password string, optionalAuthPluginName ...string) error { + authPluginName := m.defaultAuthMethod + if len(optionalAuthPluginName) > 0 { + authPluginName = optionalAuthPluginName[0] + } + + c, err := NewCredential(password, authPluginName) + if err != nil { + return err + } + + m.userPool.Store(username, c) + return nil } type Provider InMemoryProvider diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 68cade122..c50e2a612 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -200,12 +200,14 @@ func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { func (c *Conn) handleAuthMatch() (bool, error) { // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet // to the client to ask the client to switch. + if err := c.acquirePassword(); err != nil { + return false, err + } - if c.authPluginName != c.serverConf.defaultAuthMethod { - if err := c.writeAuthSwitchRequest(c.serverConf.defaultAuthMethod); err != nil { + if c.authPluginName != c.credential.AuthPluginName { + if err := c.writeAuthSwitchRequest(c.credential.AuthPluginName); err != nil { return false, err } - c.authPluginName = c.serverConf.defaultAuthMethod // handle AuthSwitchResponse return false, c.handleAuthSwitchResponse() } diff --git a/server/resp.go b/server/resp.go index c6f13fe9e..e4b21aad2 100644 --- a/server/resp.go +++ b/server/resp.go @@ -66,6 +66,7 @@ func (c *Conn) writeEOF() error { // see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html func (c *Conn) writeAuthSwitchRequest(newAuthPluginName string) error { + c.authPluginName = newAuthPluginName data := make([]byte, 4) data = append(data, mysql.EOF_HEADER) data = append(data, []byte(newAuthPluginName)...) diff --git a/server/server_test.go b/server/server_test.go index 2ad8f4e7b..c1e827f05 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -61,11 +61,13 @@ func prepareServerConf() []*Server { func Test(t *testing.T) { // general tests inMemProvider := NewInMemoryProvider() - inMemProvider.AddUser(*testUser, *testPassword) - servers := prepareServerConf() // no TLS for _, svr := range servers { + inMemProvider.userPool.Clear() + err := inMemProvider.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) + require.NoError(t, err) + suite.Run(t, &serverTestSuite{ server: svr, credProvider: inMemProvider, @@ -76,6 +78,10 @@ func Test(t *testing.T) { // TLS if server supports for _, svr := range servers { if svr.tlsConfig != nil { + inMemProvider.userPool.Clear() + err := inMemProvider.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) + require.NoError(t, err) + suite.Run(t, &serverTestSuite{ server: svr, credProvider: inMemProvider, @@ -138,7 +144,7 @@ func (s *serverTestSuite) onAccept() { func (s *serverTestSuite) onConn(conn net.Conn) { // co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) - co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testHandler{s}) + co, err := s.server.NewCustomizedConn(conn, s.credProvider, &testHandler{s}) require.NoError(s.T(), err) // set SSL if defined for {