diff --git a/Cargo.lock b/Cargo.lock index afdee3d8096..cb679f6de4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2825,9 +2825,21 @@ dependencies = [ "once_cell", "pin-project-lite", "signal-hook-registry", + "tokio-macros", "winapi", ] +[[package]] +name = "tokio-macros" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf7b11a536f46a809a8a9f0bb4237020f70ecbf115b842360afb127ea2fda57" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-native-tls" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index e399bc47dfb..70484d4afad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ sha2 = "0.9" swirl = { git = "https://github.com/sgrif/swirl.git", rev = "e87cf37" } tar = "0.4.16" tempfile = "3" -tokio = { version = "1", features = ["net", "signal", "io-std", "io-util", "rt-multi-thread"]} +tokio = { version = "1.5.0", features = ["net", "signal", "io-std", "io-util", "rt-multi-thread", "macros"]} toml = "0.5" tracing = "0.1" tracing-subscriber = "0.2" @@ -92,7 +92,7 @@ claim = "0.5" conduit-test = "0.9.0-alpha.4" hyper-tls = "0.5" lazy_static = "1.0" -tokio = "1" +tokio = "1.5.0" tower-service = "0.3.0" [build-dependencies] diff --git a/src/app.rs b/src/app.rs index 77ec5a23cae..13f2f4243a9 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,6 +1,7 @@ //! 
Application-wide components in a struct accessible from each request -use crate::{db, Config, Env}; +use crate::db::{ConnectionConfig, DieselPool}; +use crate::{Config, Env}; use std::{sync::Arc, time::Duration}; use crate::downloads_counter::DownloadsCounter; @@ -18,10 +19,10 @@ use scheduled_thread_pool::ScheduledThreadPool; #[allow(missing_debug_implementations)] pub struct App { /// The primary database connection pool - pub primary_database: db::DieselPool, + pub primary_database: DieselPool, /// The read-only replica database connection pool - pub read_only_replica_database: Option, + pub read_only_replica_database: Option, /// GitHub API client pub github: GitHubClient, @@ -103,38 +104,45 @@ impl App { _ => 30, }; - let primary_db_connection_config = db::ConnectionConfig { - statement_timeout: db_connection_timeout, - read_only: config.db_primary_config.read_only_mode, - }; - let thread_pool = Arc::new(ScheduledThreadPool::new(db_helper_threads)); - let primary_db_config = r2d2::Pool::builder() - .max_size(db_pool_size) - .min_idle(db_min_idle) - .connection_timeout(Duration::from_secs(db_connection_timeout)) - .connection_customizer(Box::new(primary_db_connection_config)) - .thread_pool(thread_pool.clone()); - - let primary_database = - db::diesel_pool(&config.db_primary_config.url, config.env, primary_db_config); - - let replica_database = if let Some(url) = config.db_replica_config.as_ref().map(|c| &c.url) - { - let replica_db_connection_config = db::ConnectionConfig { + let primary_database = if config.use_test_database_pool { + DieselPool::new_test(&config.db_primary_config.url) + } else { + let primary_db_connection_config = ConnectionConfig { statement_timeout: db_connection_timeout, - read_only: true, + read_only: config.db_primary_config.read_only_mode, }; - let replica_db_config = r2d2::Pool::builder() + let primary_db_config = r2d2::Pool::builder() .max_size(db_pool_size) .min_idle(db_min_idle) 
.connection_timeout(Duration::from_secs(db_connection_timeout)) - .connection_customizer(Box::new(replica_db_connection_config)) - .thread_pool(thread_pool); + .connection_customizer(Box::new(primary_db_connection_config)) + .thread_pool(thread_pool.clone()); + + DieselPool::new(&config.db_primary_config.url, primary_db_config) + }; - Some(db::diesel_pool(&url, config.env, replica_db_config)) + let replica_database = if let Some(url) = config.db_replica_config.as_ref().map(|c| &c.url) + { + if config.use_test_database_pool { + Some(DieselPool::new_test(url)) + } else { + let replica_db_connection_config = ConnectionConfig { + statement_timeout: db_connection_timeout, + read_only: true, + }; + + let replica_db_config = r2d2::Pool::builder() + .max_size(db_pool_size) + .min_idle(db_min_idle) + .connection_timeout(Duration::from_secs(db_connection_timeout)) + .connection_customizer(Box::new(replica_db_connection_config)) + .thread_pool(thread_pool); + + Some(DieselPool::new(&url, replica_db_config)) + } } else { None }; diff --git a/src/background_jobs.rs b/src/background_jobs.rs index daa34bf463d..0dec23b5f91 100644 --- a/src/background_jobs.rs +++ b/src/background_jobs.rs @@ -2,10 +2,9 @@ use reqwest::blocking::Client; use std::panic::AssertUnwindSafe; use std::sync::{Arc, Mutex, MutexGuard, PoisonError}; -use diesel::r2d2::PoolError; use swirl::PerformError; -use crate::db::{DieselPool, DieselPooledConn}; +use crate::db::{DieselPool, DieselPooledConn, PoolError}; use crate::git::Repository; use crate::uploaders::Uploader; diff --git a/src/config.rs b/src/config.rs index 99ba50e1515..1c0056c6b6a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -22,6 +22,7 @@ pub struct Config { pub downloads_persist_interval_ms: usize, pub ownership_invitations_expiration_days: u64, pub metrics_authorization_token: Option, + pub use_test_database_pool: bool, } #[derive(Debug)] @@ -211,6 +212,7 @@ impl Default for Config { .unwrap_or(60_000), // 1 minute 
ownership_invitations_expiration_days: 30, metrics_authorization_token: dotenv::var("METRICS_AUTHORIZATION_TOKEN").ok(), + use_test_database_pool: false, } } } diff --git a/src/controllers/version/downloads.rs b/src/controllers/version/downloads.rs index 2fda4218187..25604ece264 100644 --- a/src/controllers/version/downloads.rs +++ b/src/controllers/version/downloads.rs @@ -2,53 +2,84 @@ //! //! Crate level functionality is located in `krate::downloads`. +use super::{extract_crate_name_and_semver, version_and_crate}; use crate::controllers::prelude::*; - -use chrono::{Duration, NaiveDate, Utc}; - +use crate::db::PoolError; use crate::models::{Crate, VersionDownload}; use crate::schema::*; use crate::views::EncodableVersionDownload; - -use super::{extract_crate_name_and_semver, version_and_crate}; +use chrono::{Duration, NaiveDate, Utc}; /// Handles the `GET /crates/:crate_id/:version/download` route. /// This returns a URL to the location where the crate is stored. pub fn download(req: &mut dyn RequestExt) -> EndpointResult { + let app = req.app().clone(); let recorder = req.timing_recorder(); - let crate_name = &req.params()["crate_id"]; - let version = &req.params()["version"]; - - let (version_id, canonical_crate_name): (_, String) = { - use self::versions::dsl::*; - - let conn = recorder.record("get_conn", || req.db_conn())?; - - // Returns the crate name as stored in the database, or an error if we could - // not load the version ID from the database. - recorder.record("get_version", || { - versions - .inner_join(crates::table) - .select((id, crates::name)) - .filter(Crate::with_name(crate_name)) - .filter(num.eq(version)) - .first(&*conn) - })? - }; - - // The increment does not happen instantly, but it's deferred to be executed in a batch - // along with other downloads. See crate::downloads_counter for the implementation. 
- req.app().downloads_counter.increment(version_id); + let mut crate_name = req.params()["crate_id"].clone(); + let version = req.params()["version"].as_str(); + + let mut log_metadata = None; + match recorder.record("get_conn", || req.db_conn()) { + Ok(conn) => { + use self::versions::dsl::*; + + // Returns the crate name as stored in the database, or an error if we could + // not load the version ID from the database. + let (version_id, canonical_crate_name) = recorder.record("get_version", || { + versions + .inner_join(crates::table) + .select((id, crates::name)) + .filter(Crate::with_name(&crate_name)) + .filter(num.eq(version)) + .first::<(i32, String)>(&*conn) + })?; + + if canonical_crate_name != crate_name { + app.instance_metrics + .downloads_non_canonical_crate_name_total + .inc(); + log_metadata = Some(("bot", "dl")); + } + crate_name = canonical_crate_name; + + // The increment does not happen instantly, but it's deferred to be executed in a batch + // along with other downloads. See crate::downloads_counter for the implementation. + app.downloads_counter.increment(version_id); + } + Err(PoolError::UnhealthyPool) => { + // The download endpoint is the most critical route in the whole crates.io application, + // as it's relied upon by users and automations to download crates. Keeping it working + // is the most important thing for us. + // + // The endpoint relies on the database to fetch the canonical crate name (with the + // right capitalization and hyphenation), but that's only needed to serve clients who + // don't call the endpoint with the crate's canonical name. + // + // Thankfully Cargo always uses the right name when calling the endpoint, and we can + // keep it working during a full database outage by unconditionally redirecting without + // checking whether the crate exists or the right name is used. Non-Cargo clients might + // get a 404 response instead of a 500, but that's worth it. 
+ // + // Without a working database we also can't count downloads, but that's also less + // critical than keeping Cargo downloads operational. + + app.instance_metrics + .downloads_unconditional_redirects_total + .inc(); + log_metadata = Some(("unconditional_redirect", "true")); + } + Err(err) => return Err(err.into()), + } let redirect_url = req .app() .config .uploader - .crate_location(&canonical_crate_name, version); + .crate_location(&crate_name, &*version); - if &canonical_crate_name != crate_name { - req.log_metadata("bot", "dl"); + if let Some((key, value)) = log_metadata { + req.log_metadata(key, value); } if req.wants_json() { diff --git a/src/db.rs b/src/db.rs index 32822f4ef45..1ddd54492bc 100644 --- a/src/db.rs +++ b/src/db.rs @@ -2,12 +2,11 @@ use conduit::RequestExt; use diesel::prelude::*; use diesel::r2d2::{self, ConnectionManager, CustomizeConnection}; use parking_lot::{ReentrantMutex, ReentrantMutexGuard}; -use std::ops::Deref; use std::sync::Arc; +use std::{ops::Deref, time::Duration}; use url::Url; use crate::middleware::app::RequestApp; -use crate::Env; #[allow(missing_debug_implementations)] #[derive(Clone)] @@ -17,9 +16,33 @@ pub enum DieselPool { } impl DieselPool { - pub fn get(&self) -> Result, r2d2::PoolError> { + pub(crate) fn new( + url: &str, + config: r2d2::Builder>, + ) -> DieselPool { + let manager = ConnectionManager::new(connection_url(url)); + DieselPool::Pool(config.build(manager).unwrap()) + } + + pub(crate) fn new_test(url: &str) -> DieselPool { + let conn = + PgConnection::establish(&connection_url(url)).expect("failed to establish connection"); + conn.begin_test_transaction() + .expect("failed to begin test transaction"); + DieselPool::Test(Arc::new(ReentrantMutex::new(conn))) + } + + pub fn get(&self) -> Result, PoolError> { match self { - DieselPool::Pool(pool) => Ok(DieselPooledConn::Pool(pool.get()?)), + DieselPool::Pool(pool) => { + if let Some(conn) = pool.try_get() { + Ok(DieselPooledConn::Pool(conn)) + } else if 
!self.is_healthy() { + Err(PoolError::UnhealthyPool) + } else { + Ok(DieselPooledConn::Pool(pool.get().map_err(PoolError::R2D2)?)) + } + } DieselPool::Test(conn) => Ok(DieselPooledConn::Test(conn.lock())), } } @@ -40,8 +63,19 @@ impl DieselPool { } } - fn test_conn(conn: PgConnection) -> Self { - DieselPool::Test(Arc::new(ReentrantMutex::new(conn))) + pub fn wait_until_healthy(&self, timeout: Duration) -> Result<(), PoolError> { + match self { + DieselPool::Pool(pool) => match pool.get_timeout(timeout) { + Ok(_) => Ok(()), + Err(_) if !self.is_healthy() => Err(PoolError::UnhealthyPool), + Err(err) => Err(PoolError::R2D2(err)), + }, + DieselPool::Test(_) => Ok(()), + } + } + + fn is_healthy(&self) -> bool { + self.state().connections > 0 } } @@ -83,40 +117,25 @@ pub fn connection_url(url: &str) -> String { url.into_string() } -pub fn diesel_pool( - url: &str, - env: Env, - config: r2d2::Builder>, -) -> DieselPool { - let url = connection_url(url); - if env == Env::Test { - let conn = PgConnection::establish(&url).expect("failed to establish connection"); - DieselPool::test_conn(conn) - } else { - let manager = ConnectionManager::new(url); - DieselPool::Pool(config.build(manager).unwrap()) - } -} - pub trait RequestTransaction { /// Obtain a read/write database connection from the primary pool - fn db_conn(&self) -> Result, r2d2::PoolError>; + fn db_conn(&self) -> Result, PoolError>; /// Obtain a readonly database connection from the replica pool /// /// If there is no replica pool, the primary pool is used instead. 
- fn db_read_only(&self) -> Result, r2d2::PoolError>; + fn db_read_only(&self) -> Result, PoolError>; } impl RequestTransaction for T { - fn db_conn(&self) -> Result, r2d2::PoolError> { - self.app().primary_database.get().map_err(Into::into) + fn db_conn(&self) -> Result, PoolError> { + self.app().primary_database.get() } - fn db_read_only(&self) -> Result, r2d2::PoolError> { + fn db_read_only(&self) -> Result, PoolError> { match &self.app().read_only_replica_database { - Some(pool) => pool.get().map_err(Into::into), - None => self.app().primary_database.get().map_err(Into::into), + Some(pool) => pool.get(), + None => self.app().primary_database.get(), } } } @@ -152,3 +171,20 @@ pub(crate) fn test_conn() -> PgConnection { conn.begin_test_transaction().unwrap(); conn } + +#[derive(Debug)] +pub enum PoolError { + R2D2(r2d2::PoolError), + UnhealthyPool, +} + +impl std::error::Error for PoolError {} + +impl std::fmt::Display for PoolError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PoolError::R2D2(err) => write!(f, "{}", err), + PoolError::UnhealthyPool => write!(f, "unhealthy database pool"), + } + } +} diff --git a/src/metrics/instance.rs b/src/metrics/instance.rs index 83a65c0fbe3..149e80ed6cc 100644 --- a/src/metrics/instance.rs +++ b/src/metrics/instance.rs @@ -32,6 +32,11 @@ metrics! { pub requests_total: IntCounter, /// Number of requests currently being processed pub requests_in_flight: IntGauge, + + /// Number of download requests that were served with an unconditional redirect. + pub downloads_unconditional_redirects_total: IntCounter, + /// Number of download requests with a non-canonical crate name. + pub downloads_non_canonical_crate_name_total: IntCounter, } // All instance metrics will be prefixed with this namespace. 
diff --git a/src/tests/all.rs b/src/tests/all.rs index 939523d59ed..5fbf6eec5b4 100644 --- a/src/tests/all.rs +++ b/src/tests/all.rs @@ -47,6 +47,7 @@ mod schema_details; mod server; mod team; mod token; +mod unhealthy_database; mod user; mod util; mod version; diff --git a/src/tests/dump_db.rs b/src/tests/dump_db.rs index 11187d6b16f..f3336644b6c 100644 --- a/src/tests/dump_db.rs +++ b/src/tests/dump_db.rs @@ -1,8 +1,5 @@ +use crate::util::FreshSchema; use cargo_registry::tasks::dump_db; -use diesel::{ - connection::{Connection, SimpleConnection}, - pg::PgConnection, -}; #[test] fn dump_db_and_reimport_dump() { @@ -13,59 +10,10 @@ fn dump_db_and_reimport_dump() { let directory = dump_db::DumpDirectory::create().unwrap(); directory.populate(&database_url).unwrap(); - let schema = TemporarySchema::create(database_url, "test_db_dump"); - schema.run_migrations(); + let schema = FreshSchema::new(&database_url); let import_script = directory.export_dir.join("import.sql"); - dump_db::run_psql(&import_script, &schema.database_url).unwrap(); + dump_db::run_psql(&import_script, schema.database_url()).unwrap(); // TODO: Consistency checks on the re-imported data? 
} - -struct TemporarySchema { - pub database_url: String, - pub schema_name: String, - pub connection: PgConnection, -} - -impl TemporarySchema { - pub fn create(database_url: String, schema_name: &str) -> Self { - let params = &[("options", format!("--search_path={},public", schema_name))]; - let database_url = url::Url::parse_with_params(&database_url, params) - .unwrap() - .into_string(); - let schema_name = schema_name.to_owned(); - let connection = PgConnection::establish(&database_url).unwrap(); - connection - .batch_execute(&format!( - r#"DROP SCHEMA IF EXISTS "{schema_name}" CASCADE; - CREATE SCHEMA "{schema_name}";"#, - schema_name = schema_name, - )) - .unwrap(); - Self { - database_url, - schema_name, - connection, - } - } - - pub fn run_migrations(&self) { - use diesel_migrations::{find_migrations_directory, run_pending_migrations_in_directory}; - let migrations_dir = find_migrations_directory().unwrap(); - run_pending_migrations_in_directory( - &self.connection, - &migrations_dir, - &mut std::io::sink(), - ) - .unwrap(); - } -} - -impl Drop for TemporarySchema { - fn drop(&mut self) { - self.connection - .batch_execute(&format!(r#"DROP SCHEMA "{}" CASCADE;"#, self.schema_name)) - .unwrap(); - } -} diff --git a/src/tests/unhealthy_database.rs b/src/tests/unhealthy_database.rs new file mode 100644 index 00000000000..7a2b23ef89b --- /dev/null +++ b/src/tests/unhealthy_database.rs @@ -0,0 +1,66 @@ +use crate::{ + builders::CrateBuilder, + util::{MockAnonymousUser, RequestHelper, TestApp}, +}; +use std::time::Duration; + +#[test] +fn download_crate_with_broken_networking_primary_database() { + let (app, anon, _, owner) = TestApp::init().with_slow_real_db_pool().with_token(); + app.db(|conn| { + CrateBuilder::new("crate_name", owner.as_model().user_id) + .version("1.0.0") + .expect_build(conn) + }); + + // When the database connection is healthy downloads are redirected with the proper + // capitalization, and missing crates or versions return a 404. 
+ + assert_checked_redirects(&anon); + + // After networking breaks, preventing new database connections, the download endpoint should + // do an unconditional redirect to the CDN, without checking whether the crate exists or what + // the exact capitalization of crate name is. + + app.db_chaosproxy().break_networking(); + assert_unconditional_redirects(&anon); + + // After restoring the network and waiting for the database pool to get healthy again redirects + // should be checked again. + + app.db_chaosproxy().restore_networking(); + app.as_inner() + .primary_database + .wait_until_healthy(Duration::from_millis(500)) + .expect("the database did not return healthy"); + + assert_checked_redirects(&anon); +} + +fn assert_checked_redirects(anon: &MockAnonymousUser) { + anon.get::<()>("/api/v1/crates/crate_name/1.0.0/download") + .assert_redirect_ends_with("/crate_name/crate_name-1.0.0.crate"); + + anon.get::<()>("/api/v1/crates/Crate-Name/1.0.0/download") + .assert_redirect_ends_with("/crate_name/crate_name-1.0.0.crate"); + + anon.get::<()>("/api/v1/crates/crate_name/2.0.0/download") + .assert_not_found(); + + anon.get::<()>("/api/v1/crates/awesome-project/1.0.0/download") + .assert_not_found(); +} + +fn assert_unconditional_redirects(anon: &MockAnonymousUser) { + anon.get::<()>("/api/v1/crates/crate_name/1.0.0/download") + .assert_redirect_ends_with("/crate_name/crate_name-1.0.0.crate"); + + anon.get::<()>("/api/v1/crates/Crate-Name/1.0.0/download") + .assert_redirect_ends_with("/Crate-Name/Crate-Name-1.0.0.crate"); + + anon.get::<()>("/api/v1/crates/crate_name/2.0.0/download") + .assert_redirect_ends_with("/crate_name/crate_name-2.0.0.crate"); + + anon.get::<()>("/api/v1/crates/awesome-project/1.0.0/download") + .assert_redirect_ends_with("/awesome-project/awesome-project-1.0.0.crate"); +} diff --git a/src/tests/util.rs b/src/tests/util.rs index fdb1afae1ac..f9a379170c5 100644 --- a/src/tests/util.rs +++ b/src/tests/util.rs @@ -33,9 +33,12 @@ use conduit::header; 
use cookie::Cookie; use std::collections::HashMap; +mod chaosproxy; +mod fresh_schema; mod response; mod test_app; +pub(crate) use fresh_schema::FreshSchema; pub use response::Response; pub use test_app::TestApp; diff --git a/src/tests/util/chaosproxy.rs b/src/tests/util/chaosproxy.rs new file mode 100644 index 00000000000..0471ac59411 --- /dev/null +++ b/src/tests/util/chaosproxy.rs @@ -0,0 +1,129 @@ +use anyhow::Error; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpListener, TcpStream, + }, + runtime::Runtime, + sync::broadcast::Sender, +}; + +pub(crate) struct ChaosProxy { + address: SocketAddr, + backend_address: SocketAddr, + + runtime: Runtime, + listener: TcpListener, + + break_networking_send: Sender<()>, + restore_networking_send: Sender<()>, +} + +impl ChaosProxy { + pub(crate) fn new(backend_address: SocketAddr) -> Result, Error> { + let runtime = Runtime::new().expect("failed to create Tokio runtime"); + let listener = runtime.block_on(TcpListener::bind("127.0.0.1:0"))?; + + let (break_networking_send, _) = tokio::sync::broadcast::channel(16); + let (restore_networking_send, _) = tokio::sync::broadcast::channel(16); + + let instance = Arc::new(ChaosProxy { + address: listener.local_addr()?, + backend_address, + + listener, + runtime, + + break_networking_send, + restore_networking_send, + }); + + let instance_clone = instance.clone(); + instance.runtime.spawn(async move { + if let Err(err) = instance_clone.server_loop().await { + eprintln!("ChaosProxy server error: {}", err); + } + }); + + Ok(instance) + } + + pub(crate) fn address(&self) -> SocketAddr { + self.address + } + + pub(crate) fn break_networking(&self) { + self.break_networking_send + .send(()) + .expect("failed to send the break_networking message"); + } + + pub(crate) fn restore_networking(&self) { + self.restore_networking_send + .send(()) + .expect("failed to send the 
restore_networking message"); + } + + async fn server_loop(self: Arc) -> Result<(), Error> { + let mut break_networking_recv = self.break_networking_send.subscribe(); + let mut restore_networking_recv = self.restore_networking_send.subscribe(); + + loop { + let (client_read, client_write) = tokio::select! { + accepted = self.listener.accept() => accepted?.0.into_split(), + + // When networking is broken stop accepting connections until restore_networking() + _ = break_networking_recv.recv() => { + let _ = restore_networking_recv.recv().await; + continue; + }, + }; + let (backend_read, backend_write) = TcpStream::connect(&self.backend_address) + .await? + .into_split(); + + let self_clone = self.clone(); + self.runtime.spawn(async move { + if let Err(err) = self_clone.proxy_data(client_read, backend_write).await { + eprintln!("ChaosProxy connection error: {}", err); + } + }); + + let self_clone = self.clone(); + tokio::spawn(async move { + if let Err(err) = self_clone.proxy_data(backend_read, client_write).await { + eprintln!("ChaosProxy connection error: {}", err); + } + }); + } + } + + async fn proxy_data( + &self, + mut from: OwnedReadHalf, + mut to: OwnedWriteHalf, + ) -> Result<(), Error> { + let mut break_connections_recv = self.break_networking_send.subscribe(); + let mut buf = [0; 1024]; + + loop { + tokio::select! 
{ + len = from.read(&mut buf) => { + let len = len?; + if len == 0 { + // EOF, the socket was closed + return Ok(()); + } + to.write(&buf[0..len]).await?; + } + _ = break_connections_recv.recv() => { + to.shutdown().await?; + return Ok(()); + } + } + } + } +} diff --git a/src/tests/util/fresh_schema.rs b/src/tests/util/fresh_schema.rs new file mode 100644 index 00000000000..3cd67ff7b63 --- /dev/null +++ b/src/tests/util/fresh_schema.rs @@ -0,0 +1,65 @@ +use diesel::connection::SimpleConnection; +use diesel::prelude::*; +use diesel_migrations::{find_migrations_directory, run_pending_migrations_in_directory}; +use rand::Rng; + +pub(crate) struct FreshSchema { + database_url: String, + schema_name: String, + management_conn: PgConnection, +} + +impl FreshSchema { + pub(crate) fn new(database_url: &str) -> Self { + let schema_name = generate_schema_name(); + + let conn = PgConnection::establish(&database_url).expect("can't connect to the test db"); + conn.batch_execute(&format!( + " + DROP SCHEMA IF EXISTS {schema_name} CASCADE; + CREATE SCHEMA {schema_name}; + SET search_path TO {schema_name}, public; + ", + schema_name = schema_name + )) + .expect("failed to initialize schema"); + + let migrations_dir = find_migrations_directory().unwrap(); + run_pending_migrations_in_directory(&conn, &migrations_dir, &mut std::io::sink()) + .expect("failed to run migrations on the test schema"); + + let database_url = url::Url::parse_with_params( + database_url, + &[("options", format!("--search_path={},public", schema_name))], + ) + .unwrap() + .to_string(); + + Self { + database_url, + schema_name, + management_conn: conn, + } + } + + pub(crate) fn database_url(&self) -> &str { + &self.database_url + } +} + +impl Drop for FreshSchema { + fn drop(&mut self) { + self.management_conn + .batch_execute(&format!("DROP SCHEMA {} CASCADE;", self.schema_name)) + .expect("failed to drop the test schema"); + } +} + +fn generate_schema_name() -> String { + let mut rng = rand::thread_rng(); + 
let random_string: String = std::iter::repeat(()) + .map(|_| rng.sample(rand::distributions::Alphanumeric) as char) + .take(16) + .collect(); + format!("cratesio_test_{}", random_string) +} diff --git a/src/tests/util/test_app.rs b/src/tests/util/test_app.rs index dec871a3c16..25b89b5b5fd 100644 --- a/src/tests/util/test_app.rs +++ b/src/tests/util/test_app.rs @@ -1,4 +1,5 @@ use super::{MockAnonymousUser, MockCookieUser, MockTokenUser}; +use crate::util::{chaosproxy::ChaosProxy, fresh_schema::FreshSchema}; use crate::{env, record}; use cargo_registry::{ background_jobs::Environment, @@ -9,7 +10,7 @@ use cargo_registry::{ use std::{rc::Rc, sync::Arc, time::Duration}; use cargo_registry::git::Repository as WorkerRepository; -use diesel::{Connection, PgConnection}; +use diesel::PgConnection; use git2::Repository as UpstreamRepository; use reqwest::{blocking::Client, Proxy}; use swirl::Runner; @@ -22,6 +23,10 @@ struct TestAppInner { middle: conduit_middleware::MiddlewareBuilder, index: Option, runner: Option>, + db_chaosproxy: Option>, + + // Must be the last field of the struct! + _fresh_schema: Option, } impl Drop for TestAppInner { @@ -167,6 +172,12 @@ impl TestApp { pub fn as_middleware(&self) -> &conduit_middleware::MiddlewareBuilder { &self.0.middle } + + pub(crate) fn db_chaosproxy(&self) -> Arc { + self.0.db_chaosproxy.clone().expect( + "ChaosProxy is not enabled on this test, call with_slow_real_db_pool during app init", + ) + } } pub struct TestAppBuilder { @@ -179,9 +190,36 @@ impl TestAppBuilder { /// Create a `TestApp` with an empty database - pub fn empty(self) -> (TestApp, MockAnonymousUser) { + pub fn empty(mut self) -> (TestApp, MockAnonymousUser) { use crate::git; + // Run each test inside a fresh database schema, deleted at the end of the test. + // The schema will be cleared up once the app is dropped. 
+ let (db_chaosproxy, fresh_schema) = if !self.config.use_test_database_pool { + let fresh_schema = FreshSchema::new(&self.config.db_primary_config.url); + self.config.db_primary_config.url = fresh_schema.database_url().into(); + + let mut db_url = + Url::parse(&self.config.db_primary_config.url).expect("invalid db url"); + let backend_addr = db_url + .socket_addrs(|| Some(5432)) + .expect("could not resolve database url") + .get(0) + .copied() + .expect("the database url does not point to any IP"); + + let db_chaosproxy = ChaosProxy::new(backend_addr).unwrap(); + db_url.set_ip_host(db_chaosproxy.address().ip()).unwrap(); + db_url + .set_port(Some(db_chaosproxy.address().port())) + .unwrap(); + self.config.db_primary_config.url = db_url.into_string(); + + (Some(db_chaosproxy), Some(fresh_schema)) + } else { + (None, None) + }; + let (app, middle) = build_app(self.config, self.proxy); let runner = if self.build_job_runner { @@ -211,10 +249,12 @@ impl TestAppBuilder { let test_app_inner = TestAppInner { app, + _fresh_schema: fresh_schema, _bomb: self.bomb, middle, index: self.index, runner, + db_chaosproxy, }; let test_app = TestApp(Rc::new(test_app_inner)); let anon = MockAnonymousUser { @@ -272,6 +312,11 @@ impl TestAppBuilder { self.build_job_runner = true; self } + + pub fn with_slow_real_db_pool(mut self) -> Self { + self.config.use_test_database_pool = false; + self + } } pub fn init_logger() { @@ -321,6 +366,7 @@ fn simple_config() -> Config { downloads_persist_interval_ms: 1000, ownership_invitations_expiration_days: 30, metrics_authorization_token: None, + use_test_database_pool: true, } } @@ -343,7 +389,6 @@ fn build_app( // the application. This will also prevent cluttering the filesystem. app.emails = Emails::new_in_memory(); - assert_ok!(assert_ok!(app.primary_database.get()).begin_test_transaction()); let app = Arc::new(app); let handler = cargo_registry::build_handler(Arc::clone(&app)); (app, handler)