From 9824b74ab1b4fc03b5f35767a29fa5a9712f3e35 Mon Sep 17 00:00:00 2001 From: lishuode Date: Wed, 16 Aug 2017 07:26:19 +0000 Subject: [PATCH] Support SSL in windows --- Cargo.toml | 6 ++- README.md | 2 +- appveyor.yml | 2 + src/conn/mod.rs | 22 +++++---- src/conn/opts.rs | 18 ++++---- src/io/mod.rs | 113 +++++++++++++++++++++++++++++++++++++++++++++-- 6 files changed, 139 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 48f84635..bee9df89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ lto = true default = ['mysql_common'] nightly = ['mysql_common'] rustc_serialize = ['mysql_common/rustc_serialize', 'rustc-serialize'] -ssl = ['mysql_common', "openssl", "security-framework"] +ssl = ['mysql_common', "openssl", "security-framework", "schannel"] [dev-dependencies] serde_derive = "1" @@ -79,6 +79,10 @@ version = "~0.2" optional = true features = ["OSX_10_9"] +[target.'cfg(target_os = "windows")'.dependencies.schannel] +version = "~0.1" +optional = true + [target.'cfg(target_os = "windows")'.dependencies] named_pipe = "~0.3" winapi = "~0.3" diff --git a/README.md b/README.md index fad55d39..2cd7cbc0 100644 --- a/README.md +++ b/README.md @@ -41,4 +41,4 @@ features = ["rustc-serialize"] ``` ### Windows support (since 0.18.0) -Windows is supported but currently rust-mysql-simple has no support for SSL on Windows. +Windows is supported. diff --git a/appveyor.yml b/appveyor.yml index 9257dda9..dd8a1e68 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -46,6 +46,8 @@ before_test: $newText = ([System.IO.File]::ReadAllText($iniPath)).Replace("# enable-named-pipe", "enable-named-pipe") + $newText = $newText + "`nssl-ca=c:/clone/tests/ca-cert.pem`nssl-cert=c:/clone/tests/server-cert.pem`nssl-key=c:/clone/tests/server-key.pem" + [System.IO.File]::WriteAllText($iniPath, $newText) Restart-Service MySQL57 diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 9d640f3a..97395cae 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -779,7 +779,12 @@ impl Conn { } } - #[cfg(all(feature = "ssl", any(unix, target_os = "macos")))] + #[cfg(not(feature = "ssl"))] + fn switch_to_ssl(&mut self) -> MyResult<()> { + unimplemented!(); + } + + #[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))] fn switch_to_ssl(&mut self) -> MyResult<()> { match self.stream.take() { Some(ConnStream::Plain(stream)) => { @@ -798,11 +803,6 @@ impl Conn { Ok(()) } - #[cfg(any(not(feature = "ssl"), target_os = "windows"))] - fn switch_to_ssl(&mut self) -> MyResult<()> { - unimplemented!(); - } - fn connect_stream(&mut self) -> MyResult<()> { let read_timeout = self.opts.get_read_timeout().cloned(); let write_timeout = self.opts.get_write_timeout().cloned(); @@ -2075,7 +2075,11 @@ mod test { builder.into() } - #[cfg(all(feature = "ssl", not(target_os = "macos"), unix))] + #[cfg(all( + feature = "ssl", + not(target_os = "macos"), + any(unix, target_os = "windows") + ))] pub fn get_opts() -> Opts { let pwd: String = env::var("MYSQL_SERVER_PASS").unwrap_or(PASS.to_string()); let port: u16 = env::var("MYSQL_SERVER_PORT") @@ -2099,7 +2103,7 @@ mod test { builder.into() } - #[cfg(any(not(feature = "ssl"), target_os = "windows"))] + #[cfg(not(feature = "ssl"))] pub fn get_opts() -> Opts { let pwd: String = env::var("MYSQL_SERVER_PASS").unwrap_or(PASS.to_string()); let port: u16 = env::var("MYSQL_SERVER_PORT") @@ -2556,7 +2560,7 @@ mod test { } #[test] - #[cfg(all(feature = "ssl", any(target_os = "macos", unix)))] + #[cfg(all(feature = "ssl", any(target_os = "macos", target_os = "windows", unix)))] fn should_connect_via_ssl() { let mut opts = OptsBuilder::from_opts(get_opts()); opts.prefer_socket(false); diff --git a/src/conn/opts.rs b/src/conn/opts.rs index 2c51fb2a..c0a0f982 100644 --- a/src/conn/opts.rs +++ b/src/conn/opts.rs @@ -1,7 +1,7 @@ use crate::consts::CapabilityFlags; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; -#[cfg(all(feature = "ssl", not(target_os = "windows")))] +#[cfg(all(feature = "ssl"))] use std::path; use std::str::FromStr; @@ -29,8 +29,8 @@ pub type SslOpts = Option)>>; pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>; #[cfg(all(feature = "ssl", target_os = "windows"))] -/// Not implemented on Windows -pub type SslOpts = Option<()>; +/// Ssl options: Option<(pem_ca_cert, Option<(pem_client_cert, pem_client_key)>)>.` +pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>; #[cfg(not(feature = "ssl"))] /// Requires `ssl` feature @@ -445,7 +445,11 @@ impl OptsBuilder { self } - #[cfg(all(feature = "ssl", not(target_os = "macos"), unix))] + #[cfg(all( + feature = "ssl", + not(target_os = "macos"), + any(unix, target_os = "windows") + ))] /// SSL certificates and keys in pem format. /// /// If not None, then ssl connection implied. @@ -487,12 +491,6 @@ impl OptsBuilder { self } - /// Not implemented on windows - #[cfg(all(feature = "ssl", target_os = "windows"))] - pub fn ssl_opts(&mut self, _: Option) -> &mut Self { - panic!("OptsBuilder::ssl_opts is not implemented on Windows"); - } - /// Requires `ssl` feature #[cfg(not(feature = "ssl"))] pub fn ssl_opts(&mut self, _: Option) -> &mut Self { diff --git a/src/io/mod.rs b/src/io/mod.rs index 1ee0e4d7..733b53f0 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -9,7 +9,7 @@ use std::net::SocketAddr; use std::slice::Chunks; use std::time::Duration; -#[cfg(all(feature = "ssl", not(target_os = "windows")))] +#[cfg(all(feature = "ssl"))] use crate::conn::SslOpts; use super::consts; @@ -32,6 +32,14 @@ use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression}; use named_pipe as np; #[cfg(all(feature = "ssl", all(unix, not(target_os = "macos"))))] use openssl::ssl::{self, SslContext, SslStream}; +#[cfg(all(feature = "ssl", target_os = "windows"))] +use schannel::cert_context::CertContext; +#[cfg(all(feature = "ssl", target_os = "windows"))] +use schannel::cert_store; +#[cfg(all(feature = "ssl", target_os = "windows"))] +use schannel::schannel_cred; +#[cfg(all(feature = "ssl", target_os = "windows"))] +use schannel::tls_stream; #[cfg(all(feature = "ssl", target_os = "macos"))] use security_framework::certificate::SecCertificate; #[cfg(all(feature = "ssl", target_os = "macos"))] @@ -763,6 +771,103 @@ impl Stream { } } +#[cfg(all(feature = "ssl", target_os = "windows"))] +impl Stream { + pub fn make_secure( + mut self, + verify_peer: bool, + ip_or_hostname: Option<&str>, + ssl_opts: &SslOpts, + ) -> MyResult { + use std::path::Path; + + fn load_cert_data(path: &Path) -> MyResult { + let mut client_file = ::std::fs::File::open(path)?; + let mut client_data = String::new(); + client_file.read_to_string(&mut client_data)?; + Ok(client_data) + } + + fn load_client_cert(path: &Path) -> MyResult { + let cert_data = load_cert_data(path)?; + let cert = CertContext::from_pem(&cert_data)?; + Ok(cert) + } + + fn load_client_cert_with_key(cert_path: &Path, key_path: &Path) -> MyResult { + let mut cert_data = load_cert_data(cert_path)?; + let cert = CertContext::from_pem(&cert_data)?; + let key_data = load_cert_data(key_path)?; + cert_data.push_str(&key_data); + Ok(cert) + } + + fn load_ca_store(path: &Path) -> MyResult { + let ca_cert = load_client_cert(path)?; + let mut cert_store = cert_store::Memory::new().unwrap().into_store(); + cert_store.add_cert(&ca_cert, cert_store::CertAdd::Always)?; + Ok(cert_store) + } + + if self.is_insecure() { + let mut stream_builder = tls_stream::Builder::new(); + let mut cred_builder = schannel_cred::Builder::default(); + cred_builder.enabled_protocols(&[ + schannel_cred::Protocol::Tls10, + schannel_cred::Protocol::Tls11, + ]); + cred_builder.supported_algorithms(&[ + schannel_cred::Algorithm::DhEphem, + schannel_cred::Algorithm::RsaSign, + schannel_cred::Algorithm::Aes256, + schannel_cred::Algorithm::Sha1, + ]); + if verify_peer { + stream_builder.domain(ip_or_hostname.as_ref().unwrap_or(&("localhost".into()))); + } + + match *ssl_opts { + Some((ref ca_cert, None)) => { + stream_builder.cert_store(load_ca_store(&ca_cert)?); + } + Some((ref ca_cert, Some((ref client_cert, ref client_key)))) => { + cred_builder.cert(load_client_cert_with_key(&client_cert, &client_key)?); + stream_builder.cert_store(load_ca_store(&ca_cert)?); + } + _ => unreachable!(), + } + + let cred = cred_builder.acquire(schannel_cred::Direction::Outbound)?; + match self { + Stream::TcpStream(ref mut opt_stream) if opt_stream.is_some() => { + let stream = opt_stream.take().unwrap(); + match stream { + TcpStream::Insecure(mut stream) => { + stream.flush()?; + let s_stream = match stream_builder + .connect(cred, stream.into_inner().unwrap()) + { + Ok(s_stream) => s_stream, + Err(tls_stream::HandshakeError::Failure(err)) => { + return Err(err.into()); + } + Err(tls_stream::HandshakeError::Interrupted(_)) => unreachable!(), + }; + Ok(Stream::TcpStream(Some(TcpStream::Secure(BufStream::new( + s_stream, + ))))) + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } + } else { + Ok(self) + } + } +} + #[cfg(all(feature = "ssl", not(target_os = "macos"), unix))] impl Stream { pub fn make_secure( @@ -838,13 +943,15 @@ impl Drop for Stream { pub enum TcpStream { #[cfg(all(feature = "ssl", any(unix, target_os = "macos")))] Secure(BufStream>), + #[cfg(all(feature = "ssl", target_os = "windows"))] + Secure(BufStream>), Insecure(BufStream), } impl AsMut for TcpStream { fn as_mut(&mut self) -> &mut dyn IoPack { match *self { - #[cfg(all(feature = "ssl", any(unix, target_os = "macos")))] + #[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))] TcpStream::Secure(ref mut stream) => stream, TcpStream::Insecure(ref mut stream) => stream, } @@ -854,7 +961,7 @@ impl AsMut for TcpStream { impl fmt::Debug for TcpStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - #[cfg(all(feature = "ssl", any(unix, target_os = "macos")))] + #[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))] TcpStream::Secure(_) => write!(f, "Secure stream"), TcpStream::Insecure(ref s) => write!(f, "Insecure stream {:?}", s), }