|
1 | 1 | use crate::connection::stream::PgStream; |
2 | 2 | use crate::error::Error; |
3 | | -use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; |
| 3 | +use crate::message::{ |
| 4 | + Authentication, AuthenticationSasl, AuthenticationSaslContinue, SaslInitialResponse, |
| 5 | + SaslResponse, |
| 6 | +}; |
| 7 | +use crate::rt; |
4 | 8 | use crate::PgConnectOptions; |
5 | 9 | use hmac::{Hmac, Mac}; |
6 | 10 | use rand::Rng; |
@@ -85,12 +89,36 @@ pub(crate) async fn authenticate( |
85 | 89 | } |
86 | 90 | }; |
87 | 91 |
|
| 92 | + let password = options.password.clone().unwrap_or_default(); |
| 93 | + let (mac, client_final_message) = rt::spawn_blocking(move || { |
| 94 | + final_message(client_first_message_bare, channel_binding, cont, password) |
| 95 | + }) |
| 96 | + .await?; |
| 97 | + |
| 98 | + stream.send(SaslResponse(&client_final_message)).await?; |
| 99 | + |
| 100 | + let data = match stream.recv_expect().await? { |
| 101 | + Authentication::SaslFinal(data) => data, |
| 102 | + |
| 103 | + auth => { |
| 104 | + return Err(err_protocol!("expected SASLFinal but received {:?}", auth)); |
| 105 | + } |
| 106 | + }; |
| 107 | + |
| 108 | + // authentication is only considered valid if this verification passes |
| 109 | + mac.verify_slice(&data.verifier).map_err(Error::protocol)?; |
| 110 | + |
| 111 | + Ok(()) |
| 112 | +} |
| 113 | + |
| 114 | +fn final_message( |
| 115 | + client_first_message_bare: String, |
| 116 | + channel_binding: String, |
| 117 | + cont: AuthenticationSaslContinue, |
| 118 | + password: String, |
| 119 | +) -> Result<(Hmac<Sha256>, String), Error> { |
88 | 120 | // SaltedPassword := Hi(Normalize(password), salt, i) |
89 | | - let salted_password = hi( |
90 | | - options.password.as_deref().unwrap_or_default(), |
91 | | - &cont.salt, |
92 | | - cont.iterations, |
93 | | - )?; |
| 121 | + let salted_password = hi(&password, &cont.salt, cont.iterations)?; |
94 | 122 |
|
95 | 123 | // ClientKey := HMAC(SaltedPassword, "Client Key") |
96 | 124 | let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?; |
@@ -143,20 +171,7 @@ pub(crate) async fn authenticate( |
143 | 171 | let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}="); |
144 | 172 | BASE64_STANDARD.encode_string(client_proof, &mut client_final_message); |
145 | 173 |
|
146 | | - stream.send(SaslResponse(&client_final_message)).await?; |
147 | | - |
148 | | - let data = match stream.recv_expect().await? { |
149 | | - Authentication::SaslFinal(data) => data, |
150 | | - |
151 | | - auth => { |
152 | | - return Err(err_protocol!("expected SASLFinal but received {:?}", auth)); |
153 | | - } |
154 | | - }; |
155 | | - |
156 | | - // authentication is only considered valid if this verification passes |
157 | | - mac.verify_slice(&data.verifier).map_err(Error::protocol)?; |
158 | | - |
159 | | - Ok(()) |
| 174 | + Ok((mac, client_final_message)) |
160 | 175 | } |
161 | 176 |
|
162 | 177 | // nonce is a sequence of random printable bytes |
@@ -223,3 +238,25 @@ fn bench_sasl_hi(b: &mut test::Bencher) { |
223 | 238 | ); |
224 | 239 | }); |
225 | 240 | } |
| 241 | + |
| 242 | +#[cfg(test)] |
| 243 | +mod tests { |
| 244 | + |
| 245 | + use super::*; |
| 246 | + |
| 247 | + use crate::io::ProtocolDecode; |
| 248 | + |
| 249 | + #[test] |
| 250 | + fn mac() { |
| 251 | + let start = std::time::Instant::now(); |
| 252 | + let (_mac, _client_final_message) = final_message( |
| 253 | + "first_message_bare".to_string(), |
| 254 | + "channel_binding".to_string(), |
| 255 | + AuthenticationSaslContinue::decode_with("r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096".as_bytes().into(), ()).unwrap(), |
| 256 | + "some-password".to_string(), |
| 257 | + ) |
| 258 | + .unwrap(); |
| 259 | + let duration = start.elapsed(); |
| 260 | + dbg!(duration); |
| 261 | + } |
| 262 | +} |
0 commit comments