diff --git a/src/db.rs b/src/db.rs index 73436dcf75c..94237490974 100644 --- a/src/db.rs +++ b/src/db.rs @@ -31,9 +31,18 @@ impl DieselPool { } } - fn test_conn(conn: PgConnection) -> Self { + pub fn test_conn(conn: PgConnection) -> Self { DieselPool::Test(Arc::new(ReentrantMutex::new(conn))) } + + pub fn unwrap_test_conn(self) -> Result { + match self { + DieselPool::Test(shared_conn) => Arc::try_unwrap(shared_conn) + .map(|c| c.into_inner()) + .map_err(Self::Test), + other => Err(other), + } + } } #[allow(missing_debug_implementations)] diff --git a/src/tasks/update_downloads.rs b/src/tasks/update_downloads.rs index 1cd0d3a3704..ce0dfda578b 100644 --- a/src/tasks/update_downloads.rs +++ b/src/tasks/update_downloads.rs @@ -6,25 +6,34 @@ use crate::{ use diesel::prelude::*; use swirl::PerformError; +#[cfg(not(test))] +const ROWS_PER_BATCH: i64 = 1000; + +#[cfg(test)] +const ROWS_PER_BATCH: i64 = 1; + #[swirl::background_job] pub fn update_downloads(conn: &PgConnection) -> Result<(), PerformError> { - update(&conn)?; - Ok(()) -} - -fn update(conn: &PgConnection) -> QueryResult<()> { use self::version_downloads::dsl::*; use diesel::dsl::now; use diesel::select; - let rows = version_downloads - .filter(processed.eq(false)) - .filter(downloads.ne(counted)) - .load(conn)?; - - println!("Updating {} versions", rows.len()); - collect(conn, &rows)?; - println!("Finished updating versions"); + println!("Enqueuing jobs to count downloads"); + let mut last_id = Some(0); + while let Some(id) = last_id { + let rows = version_downloads + .filter(processed.eq(false)) + .filter(downloads.ne(counted)) + .filter(version_id.gt(id)) + .limit(ROWS_PER_BATCH) + .select(version_id) + .load(conn)?; + last_id = rows.last().copied(); + if let Some(max_id) = last_id { + update_downloads_batch(id, max_id).enqueue(&conn)?; + } + } + println!("Finished enqueuing jobs"); // Anything older than 24 hours ago will be frozen and will not be queried // against again. @@ -43,6 +52,23 @@ fn update(conn: &PgConnection) -> QueryResult<()> { Ok(()) } +#[swirl::background_job] +pub fn update_downloads_batch( + conn: &PgConnection, + min_version_id: i32, + max_version_id: i32, +) -> Result<(), PerformError> { + use self::version_downloads::dsl::*; + + let rows = version_downloads + .filter(processed.eq(false)) + .filter(downloads.ne(counted)) + .filter(version_id.between(min_version_id, max_version_id)) + .load(conn)?; + collect(conn, &rows)?; + Ok(()) +} + fn collect(conn: &PgConnection, rows: &[VersionDownload]) -> QueryResult<()> { use diesel::update; @@ -89,6 +115,24 @@ mod test { }; use std::collections::HashMap; + fn run_update(conn: PgConnection) -> PgConnection { + use crate::db::DieselPool; + use swirl::{Job, Runner}; + + super::update_downloads().enqueue(&conn).unwrap(); + let pool = DieselPool::test_conn(conn); + { + let runner = Runner::builder(()) + .thread_count(1) + .connection_pool(pool.clone()) + .build(); + runner.run_all_pending_jobs().unwrap(); + runner.check_for_failed_jobs().unwrap(); + } + pool.unwrap_test_conn() + .unwrap_or_else(|_| panic!("couldn't unwrap pool")) + } + fn conn() -> PgConnection { let conn = PgConnection::establish(&env("TEST_DATABASE_URL")).unwrap(); conn.begin_test_transaction().unwrap(); @@ -142,7 +186,7 @@ mod test { .execute(&conn) .unwrap(); - super::update(&conn).unwrap(); + let conn = run_update(conn); let version_downloads = versions::table .find(version.id) .select(versions::downloads) @@ -153,7 +197,7 @@ mod test { .select(crates::downloads) .first(&conn); assert_eq!(Ok(1), crate_downloads); - super::update(&conn).unwrap(); + let conn = run_update(conn); let version_downloads = versions::table .find(version.id) .select(versions::downloads) @@ -178,7 +222,7 @@ mod test { )) .execute(&conn) .unwrap(); - super::update(&conn).unwrap(); + let conn = run_update(conn); let processed = version_downloads::table .filter(version_downloads::version_id.eq(version.id)) .select(version_downloads::processed) @@ -202,7 +246,7 @@ mod test { )) .execute(&conn) .unwrap(); - super::update(&conn).unwrap(); + let conn = run_update(conn); let processed = version_downloads::table .filter(version_downloads::version_id.eq(version.id)) .select(version_downloads::processed) @@ -249,7 +293,7 @@ mod test { .filter(crates::id.eq(krate.id)) .first(&conn) .unwrap(); - super::update(&conn).unwrap(); + let conn = run_update(conn); let version2: Version = versions::table.find(version.id).first(&conn).unwrap(); assert_eq!(version2.downloads, 2); assert_eq!(version2.updated_at, version_before.updated_at); @@ -259,7 +303,7 @@ mod test { .unwrap(); assert_eq!(krate2.downloads, 2); assert_eq!(krate2.updated_at, krate_before.updated_at); - super::update(&conn).unwrap(); + let conn = run_update(conn); let version3: Version = versions::table.find(version.id).first(&conn).unwrap(); assert_eq!(version3.downloads, 2); } @@ -291,7 +335,7 @@ mod test { .execute(&conn) .unwrap(); - super::update(&conn).unwrap(); + let conn = run_update(conn); let versions_changed = versions::table .select(versions::updated_at.ne(now - 2.days())) .get_result(&conn);