Skip to content

Never store more than one StdWaker per live Future #2894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightning-background-processor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,8 +854,8 @@ impl BackgroundProcessor {
peer_manager.onion_message_handler().process_pending_events(&event_handler),
gossip_sync, logger, scorer, stop_thread.load(Ordering::Acquire),
{ Sleeper::from_two_futures(
channel_manager.get_event_or_persistence_needed_future(),
chain_monitor.get_update_future()
&channel_manager.get_event_or_persistence_needed_future(),
&chain_monitor.get_update_future()
).wait_timeout(Duration::from_millis(100)); },
|_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur, false,
|| {
Expand Down
152 changes: 109 additions & 43 deletions lightning/src/util/wakers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,33 @@ impl Notifier {
/// Gets a [`Future`] that will get woken up with any waiters
pub(crate) fn get_future(&self) -> Future {
let mut lock = self.notify_pending.lock().unwrap();
let mut self_idx = 0;
if let Some(existing_state) = &lock.1 {
if existing_state.lock().unwrap().callbacks_made {
let mut locked = existing_state.lock().unwrap();
if locked.callbacks_made {
// If the existing `FutureState` has completed and actually made callbacks,
// consider the notification flag to have been cleared and reset the future state.
mem::drop(locked);
lock.1.take();
lock.0 = false;
} else {
self_idx = locked.next_idx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 'fetch and add' could be a method so that (in future, no pun intended) we would never forget to increase the counter?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I played with a constructor a bit trying to make it more robust but didn't really see a decent way to do it without double-locking everywhere. Just defining a method to fetch-and-increment the index doesn't seem like it'll actually prevent a bug cause we'll just forget to use it :)

locked.next_idx += 1;
}
}
if let Some(existing_state) = &lock.1 {
Future { state: Arc::clone(&existing_state) }
Future { state: Arc::clone(&existing_state), self_idx }
} else {
let state = Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(),
complete: lock.0,
callbacks_made: false,
next_idx: 1,
}));
lock.1 = Some(Arc::clone(&state));
Future { state }
Future { state, self_idx: 0 }
}
}

Expand Down Expand Up @@ -109,36 +117,39 @@ define_callback!(Send);
define_callback!();

pub(crate) struct FutureState {
// When we're tracking whether a callback counts as having woken the user's code, we check the
// first bool - set to false if we're just calling a Waker, and true if we're calling an actual
// user-provided function.
callbacks: Vec<(bool, Box<dyn FutureCallback>)>,
callbacks_with_state: Vec<(bool, Box<dyn Fn(&Arc<Mutex<FutureState>>) -> () + Send>)>,
// `callbacks` count as having woken the users' code (as they go direct to the user), but
// `std_future_callbacks` and `callbacks_with_state` do not (as the first just wakes a future,
// we only count it after another `poll()` and the second wakes a `Sleeper` which handles
// setting `callbacks_made` itself).
callbacks: Vec<Box<dyn FutureCallback>>,
std_future_callbacks: Vec<(usize, StdWaker)>,
callbacks_with_state: Vec<Box<dyn Fn(&Arc<Mutex<FutureState>>) -> () + Send>>,
complete: bool,
callbacks_made: bool,
next_idx: usize,
}

fn complete_future(this: &Arc<Mutex<FutureState>>) -> bool {
let mut state_lock = this.lock().unwrap();
let state = &mut *state_lock;
for (counts_as_call, callback) in state.callbacks.drain(..) {
for callback in state.callbacks.drain(..) {
callback.call();
state.callbacks_made |= counts_as_call;
state.callbacks_made = true;
}
for (counts_as_call, callback) in state.callbacks_with_state.drain(..) {
for (_, waker) in state.std_future_callbacks.drain(..) {
waker.0.wake_by_ref();
}
for callback in state.callbacks_with_state.drain(..) {
(callback)(this);
state.callbacks_made |= counts_as_call;
}
state.complete = true;
state.callbacks_made
}

/// A simple future which can complete once, and calls some callback(s) when it does so.
///
/// Clones can be made and all futures cloned from the same source will complete at the same time.
#[derive(Clone)]
pub struct Future {
state: Arc<Mutex<FutureState>>,
self_idx: usize,
}

impl Future {
Expand All @@ -153,7 +164,7 @@ impl Future {
mem::drop(state);
callback.call();
} else {
state.callbacks.push((true, callback));
state.callbacks.push(callback);
}
}

Expand All @@ -169,16 +180,16 @@ impl Future {

/// Waits until this [`Future`] completes.
#[cfg(feature = "std")]
pub fn wait(self) {
Sleeper::from_single_future(self).wait();
pub fn wait(&self) {
Sleeper::from_single_future(&self).wait();
}

/// Waits until this [`Future`] completes or the given amount of time has elapsed.
///
/// Returns true if the [`Future`] completed, false if the time elapsed.
#[cfg(feature = "std")]
pub fn wait_timeout(self, max_wait: Duration) -> bool {
Sleeper::from_single_future(self).wait_timeout(max_wait)
pub fn wait_timeout(&self, max_wait: Duration) -> bool {
Sleeper::from_single_future(&self).wait_timeout(max_wait)
}

#[cfg(test)]
Expand All @@ -191,11 +202,14 @@ impl Future {
}
}

impl Drop for Future {
fn drop(&mut self) {
self.state.lock().unwrap().std_future_callbacks.retain(|(idx, _)| *idx != self.self_idx);
}
}

use core::task::Waker;
struct StdWaker(pub Waker);
impl FutureCallback for StdWaker {
fn call(&self) { self.0.wake_by_ref() }
}

/// This is not exported to bindings users as Rust Futures aren't usable in language bindings.
impl<'a> StdFuture for Future {
Expand All @@ -208,7 +222,8 @@ impl<'a> StdFuture for Future {
Poll::Ready(())
} else {
let waker = cx.waker().clone();
state.callbacks.push((false, Box::new(StdWaker(waker))));
state.std_future_callbacks.retain(|(idx, _)| *idx != self.self_idx);
state.std_future_callbacks.push((self.self_idx, StdWaker(waker)));
Poll::Pending
}
}
Expand All @@ -224,17 +239,17 @@ pub struct Sleeper {
#[cfg(feature = "std")]
impl Sleeper {
/// Constructs a new sleeper from one future, allowing blocking on it.
pub fn from_single_future(future: Future) -> Self {
Self { notifiers: vec![future.state] }
pub fn from_single_future(future: &Future) -> Self {
Self { notifiers: vec![Arc::clone(&future.state)] }
}
/// Constructs a new sleeper from two futures, allowing blocking on both at once.
// Note that this is the common case - a ChannelManager and ChainMonitor.
pub fn from_two_futures(fut_a: Future, fut_b: Future) -> Self {
Self { notifiers: vec![fut_a.state, fut_b.state] }
pub fn from_two_futures(fut_a: &Future, fut_b: &Future) -> Self {
Self { notifiers: vec![Arc::clone(&fut_a.state), Arc::clone(&fut_b.state)] }
}
/// Constructs a new sleeper on many futures, allowing blocking on all at once.
pub fn new(futures: Vec<Future>) -> Self {
Self { notifiers: futures.into_iter().map(|f| f.state).collect() }
Self { notifiers: futures.into_iter().map(|f| Arc::clone(&f.state)).collect() }
}
/// Prepares to go into a wait loop body, creating a condition variable which we can block on
/// and an `Arc<Mutex<Option<_>>>` which gets set to the waking `Future`'s state prior to the
Expand All @@ -251,10 +266,10 @@ impl Sleeper {
*notified_fut_mtx.lock().unwrap() = Some(Arc::clone(&notifier_mtx));
break;
}
notifier.callbacks_with_state.push((false, Box::new(move |notifier_ref| {
notifier.callbacks_with_state.push(Box::new(move |notifier_ref| {
*notified_fut_ref.lock().unwrap() = Some(Arc::clone(notifier_ref));
cv_ref.notify_all();
})));
}));
}
}
(cv, notified_fut_mtx)
Expand Down Expand Up @@ -439,13 +454,15 @@ mod tests {

// Wait on the other thread to finish its sleep, note that the leak only happened if we
// actually have to sleep here, not if we immediately return.
Sleeper::from_two_futures(future_a, future_b).wait();
Sleeper::from_two_futures(&future_a, &future_b).wait();

join_handle.join().unwrap();

// then drop the notifiers and make sure the future states are gone.
mem::drop(notifier_a);
mem::drop(notifier_b);
mem::drop(future_a);
mem::drop(future_b);

assert!(future_state_a.upgrade().is_none() && future_state_b.upgrade().is_none());
}
Expand All @@ -455,10 +472,13 @@ mod tests {
let future = Future {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(),
complete: false,
callbacks_made: false,
}))
next_idx: 1,
})),
self_idx: 0,
};
let callback = Arc::new(AtomicBool::new(false));
let callback_ref = Arc::clone(&callback);
Expand All @@ -475,10 +495,13 @@ mod tests {
let future = Future {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(),
complete: false,
callbacks_made: false,
}))
next_idx: 1,
})),
self_idx: 0,
};
complete_future(&future.state);

Expand Down Expand Up @@ -514,12 +537,15 @@ mod tests {
let mut future = Future {
state: Arc::new(Mutex::new(FutureState {
callbacks: Vec::new(),
std_future_callbacks: Vec::new(),
callbacks_with_state: Vec::new(),
complete: false,
callbacks_made: false,
}))
next_idx: 2,
})),
self_idx: 0,
};
let mut second_future = Future { state: Arc::clone(&future.state) };
let mut second_future = Future { state: Arc::clone(&future.state), self_idx: 1 };

let (woken, waker) = create_waker();
assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
Expand Down Expand Up @@ -638,18 +664,18 @@ mod tests {
// Set both notifiers as woken without sleeping yet.
notifier_a.notify();
notifier_b.notify();
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future()).wait();

// One future has woken us up, but the other should still have a pending notification.
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future()).wait();

// However once we've slept twice, we should no longer have any pending notifications
assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future())
assert!(!Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future())
.wait_timeout(Duration::from_millis(10)));

// Test ordering somewhat more.
notifier_a.notify();
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future()).wait();
}

#[test]
Expand All @@ -667,7 +693,7 @@ mod tests {

// After sleeping one future (not guaranteed which one, however) will have its notification
// bit cleared.
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future()).wait();

// By registering a callback on the futures for both notifiers, one will complete
// immediately, but one will remain tied to the notifier, and will complete once the
Expand All @@ -686,8 +712,48 @@ mod tests {
notifier_b.notify();

assert!(callback_a.load(Ordering::SeqCst) && callback_b.load(Ordering::SeqCst));
Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait();
assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future())
Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future()).wait();
assert!(!Sleeper::from_two_futures(&notifier_a.get_future(), &notifier_b.get_future())
.wait_timeout(Duration::from_millis(10)));
}

#[test]
#[cfg(feature = "std")]
fn multi_poll_stores_single_waker() {
// When a `Future` is `poll()`ed multiple times, only the last `Waker` should be called,
// but previously we'd store all `Waker`s until they're all woken at once. This tests a few
// cases to ensure `Future`s avoid storing an endless set of `Waker`s.
let notifier = Notifier::new();
let future_state = Arc::clone(&notifier.get_future().state);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0);

// Test that simply polling a future twice doesn't result in two pending `Waker`s.
let mut future_a = notifier.get_future();
assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1);
assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1);

// If we poll a second future, however, that will store a second `Waker`.
let mut future_b = notifier.get_future();
assert_eq!(Pin::new(&mut future_b).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 2);

// but when we drop the `Future`s, the pending Wakers will also be dropped.
mem::drop(future_a);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1);
mem::drop(future_b);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0);

// Further, after polling a future twice, if the notifier is woken all Wakers are dropped.
let mut future_a = notifier.get_future();
assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1);
assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending);
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1);
notifier.notify();
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0);
assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Ready(()));
assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0);
}
}