diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index b4c24262737..6b03b365457 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -3452,48 +3452,44 @@ mod tests { #[cfg(feature = "std")] fn test_process_events_multithreaded() { use std::time::{Duration, Instant}; - // Test that `process_events` getting called on multiple threads doesn't generate too many - // loop iterations. + // `process_events` shouldn't block on another thread processing events and instead should + // simply signal the currently processing thread to go around the loop again. + // Here we test that this happens by spawning a few threads and checking that we see one go + // around again at least once. + // // Each time `process_events` goes around the loop we call - // `get_and_clear_pending_msg_events`, which we count using the `TestMessageHandler`. - // Because the loop should go around once more after a call which fails to take the - // single-threaded lock, if we write zero to the counter before calling `process_events` we - // should never observe there having been more than 2 loop iterations. - // Further, because the last thread to exit will call `process_events` before returning, we - // should always have at least one count at the end. + // `get_and_clear_pending_msg_events`, which we count using the `TestMessageHandler`. Thus, + // to test we simply write zero to the counter before calling `process_events` and make + // sure we observe a value greater than one at least once. let cfg = Arc::new(create_peermgr_cfgs(1)); // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }. let peer = Arc::new(create_network(1, unsafe { &*(&*cfg as *const _) as &'static _ }).pop().unwrap()); - let exit_flag = Arc::new(AtomicBool::new(false)); - macro_rules! spawn_thread { () => { { - let thread_cfg = Arc::clone(&cfg); + let end_time = Instant::now() + Duration::from_millis(100); + let observed_loop = Arc::new(AtomicBool::new(false)); + let thread_fn = || { let thread_peer = Arc::clone(&peer); - let thread_exit = Arc::clone(&exit_flag); - std::thread::spawn(move || { - while !thread_exit.load(Ordering::Acquire) { - thread_cfg[0].chan_handler.message_fetch_counter.store(0, Ordering::Release); + let thread_observed_loop = Arc::clone(&observed_loop); + move || { + while Instant::now() < end_time || !thread_observed_loop.load(Ordering::Acquire) { + test_utils::TestChannelMessageHandler::MESSAGE_FETCH_COUNTER.with(|val| val.store(0, Ordering::Relaxed)); thread_peer.process_events(); + if test_utils::TestChannelMessageHandler::MESSAGE_FETCH_COUNTER.with(|val| val.load(Ordering::Relaxed)) > 1 { + thread_observed_loop.store(true, Ordering::Release); + return; + } std::thread::sleep(Duration::from_micros(1)); } - }) - } } } - - let thread_a = spawn_thread!(); - let thread_b = spawn_thread!(); - let thread_c = spawn_thread!(); - - let start_time = Instant::now(); - while start_time.elapsed() < Duration::from_millis(100) { - let val = cfg[0].chan_handler.message_fetch_counter.load(Ordering::Acquire); - assert!(val <= 2); - std::thread::yield_now(); // Winblowz seemingly doesn't ever interrupt threads?! - } + } + }; - exit_flag.store(true, Ordering::Release); + let thread_a = std::thread::spawn(thread_fn()); + let thread_b = std::thread::spawn(thread_fn()); + let thread_c = std::thread::spawn(thread_fn()); + thread_fn()(); thread_a.join().unwrap(); thread_b.join().unwrap(); thread_c.join().unwrap(); - assert!(cfg[0].chan_handler.message_fetch_counter.load(Ordering::Acquire) >= 1); + assert!(observed_loop.load(Ordering::Acquire)); } } diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 3f2226d741d..050c2a3e007 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -759,17 +759,21 @@ pub struct TestChannelMessageHandler { pub pending_events: Mutex>, expected_recv_msgs: Mutex>>>, connected_peers: Mutex>, - pub message_fetch_counter: AtomicUsize, chain_hash: ChainHash, } +impl TestChannelMessageHandler { + thread_local! { + pub static MESSAGE_FETCH_COUNTER: AtomicUsize = AtomicUsize::new(0); + } +} + impl TestChannelMessageHandler { pub fn new(chain_hash: ChainHash) -> Self { TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), expected_recv_msgs: Mutex::new(None), connected_peers: Mutex::new(new_hash_set()), - message_fetch_counter: AtomicUsize::new(0), chain_hash, } } @@ -940,7 +944,7 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { impl events::MessageSendEventsProvider for TestChannelMessageHandler { fn get_and_clear_pending_msg_events(&self) -> Vec { - self.message_fetch_counter.fetch_add(1, Ordering::AcqRel); + Self::MESSAGE_FETCH_COUNTER.with(|val| val.fetch_add(1, Ordering::AcqRel)); let mut pending_events = self.pending_events.lock().unwrap(); let mut ret = Vec::new(); mem::swap(&mut ret, &mut *pending_events);