Skip to content

Commit 7627b0f

Browse files
committed
Make the initial Waker provided by block_on_future Sync
As it was pointed out in the review, the initial `Waker` for `block_on_future` was not holding up the `Sync` guarantees. If the `Waker` reference had been passed to another thread and cloned there, the cloned threadsafe `Waker` would have captured the wrong `Thread` handle. This change removes the optimization. A threadsafe `Waker` is now immediately created.
1 parent dcd2a01 commit 7627b0f

File tree

2 files changed

+111
-130
lines changed

2 files changed

+111
-130
lines changed

src/libstd/tests/block_on_future.rs

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,6 @@ impl Future for Yield {
7575
}
7676
}
7777

78-
struct NeverReady {
79-
}
80-
81-
impl NeverReady {
82-
fn new() -> Self {
83-
NeverReady {}
84-
}
85-
}
86-
87-
impl Future for NeverReady {
88-
type Output = ();
89-
90-
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
91-
Poll::Pending
92-
}
93-
}
94-
9578
struct WakerStore {
9679
waker: Option<Waker>,
9780
}
@@ -203,14 +186,6 @@ fn returns_result_from_task() {
203186
assert_eq!(42, result);
204187
}
205188

206-
#[test]
207-
#[should_panic]
208-
fn panics_if_waker_was_not_cloned_and_task_is_not_ready() {
209-
block_on_future(async {
210-
NeverReady::new().await;
211-
});
212-
}
213-
214189
#[test]
215190
fn does_not_panic_if_waker_is_cloned_and_used_a_lot_later() {
216191
let store = Arc::new(Mutex::new(WakerStore {
@@ -227,3 +202,73 @@ fn does_not_panic_if_waker_is_cloned_and_used_a_lot_later() {
227202
WakeFromPreviouslyStoredWakerFuture::new(store).await;
228203
});
229204
}
205+
206+
struct WakeSynchronouslyFromOtherThreadFuture {
207+
was_polled: bool,
208+
use_clone: bool,
209+
}
210+
211+
impl WakeSynchronouslyFromOtherThreadFuture {
212+
fn new(use_clone: bool) -> Self {
213+
WakeSynchronouslyFromOtherThreadFuture {
214+
was_polled: false,
215+
use_clone,
216+
}
217+
}
218+
}
219+
220+
/// This is just a helper to transfer a waker by reference/pointer
221+
/// to another thread without the availability of scoped threads.
222+
struct WakerBox {
223+
waker: *const Waker,
224+
}
225+
226+
unsafe impl Send for WakerBox {}
227+
228+
impl Future for WakeSynchronouslyFromOtherThreadFuture {
229+
type Output = ();
230+
231+
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
232+
if !self.was_polled {
233+
self.was_polled = true;
234+
// This captures the waker by pointer and passes it to the other thread,
235+
// since we don't have a scoped thread API available here.
236+
// The pointer is however guaranteed to be alive when we call it, due to
237+
// joining the thread in this scope.
238+
let waker_box = WakerBox {
239+
waker: cx.waker() as *const Waker,
240+
};
241+
let use_clone = self.use_clone;
242+
spawn(move ||{
243+
let x = waker_box;
244+
unsafe {
245+
if !use_clone {
246+
(*(x.waker as *mut Waker)).wake_by_ref();
247+
} else {
248+
let cloned_waker = (*(x.waker as *mut Waker)).clone();
249+
cloned_waker.wake_by_ref();
250+
}
251+
}
252+
}).join().unwrap();
253+
Poll::Pending
254+
} else {
255+
Poll::Ready(())
256+
}
257+
}
258+
}
259+
260+
#[test]
261+
fn wake_synchronously_by_ref_from_other_thread() {
262+
block_on_future(async {
263+
WakeSynchronouslyFromOtherThreadFuture::new(false).await;
264+
Yield::new(10).await;
265+
})
266+
}
267+
268+
#[test]
269+
fn clone_and_wake_synchronously_from_other_thread() {
270+
block_on_future(async {
271+
WakeSynchronouslyFromOtherThreadFuture::new(true).await;
272+
Yield::new(10).await;
273+
})
274+
}

src/libstd/thread/block_on_future.rs

Lines changed: 41 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -9,44 +9,21 @@ use crate::pin::Pin;
99
use crate::sync::Arc;
1010
use super::{current, park, Inner, Thread};
1111

12-
/// Carries a flag that is used to wakeup the executor.
13-
/// A pointer to this struct is passed to the thread-local waker.
14-
struct LocalWakeState {
15-
is_woken: bool,
16-
waker_was_cloned: bool,
17-
}
18-
19-
/// Returns the vtable that is used for waking up the executor
20-
/// from another thread.
21-
fn threadsafe_waker_vtable() -> &'static RawWakerVTable {
22-
&RawWakerVTable::new(
23-
clone_threadsafe_waker,
24-
wake_threadsafe_waker,
25-
wake_threadsafe_waker_by_ref,
26-
drop_threadsafe_waker,
27-
)
28-
}
29-
3012
/// Returns the vtable that is used for waking up the executor
31-
/// from inside it's execution on the current thread.
32-
fn current_thread_waker_vtable() -> &'static RawWakerVTable {
13+
/// from any thread.
14+
fn waker_vtable() -> &'static RawWakerVTable {
3315
&RawWakerVTable::new(
34-
create_threadsafe_waker,
35-
wake_current_thread,
36-
wake_current_thread_by_ref,
37-
|_| {},
16+
clone_waker,
17+
wake_waker,
18+
wake_waker_by_ref,
19+
drop_waker,
3820
)
3921
}
4022

41-
/// This method will be called when the waker reference gets cloned,
42-
/// which makes it possible to transfer it to another thread. In this
43-
/// case we have to create a threadsafe `Waker`. In order to to this
44-
/// we retain the thread handle and store it in the new `RawWaker`s
45-
/// data pointer.
46-
unsafe fn create_threadsafe_waker(data: *const()) -> RawWaker {
47-
let wake_state = data as *mut LocalWakeState;
48-
(*wake_state).waker_was_cloned = true;
49-
23+
/// Creates a [`RawWaker`] which captures the current thread handle
24+
/// and allows to wake up the [`block_on_future`] executor from any
25+
/// thread by calling [`Thread::unpark()`].
26+
fn create_threadsafe_raw_waker() -> RawWaker {
5027
// Get the `Arc<Inner>` of a current thread handle and store into in
5128
// the type erased pointer.
5229
//
@@ -61,32 +38,23 @@ unsafe fn create_threadsafe_waker(data: *const()) -> RawWaker {
6138
// `let arc_thread = Arc::new(current());`
6239
let arc_thread_inner = current().inner;
6340
let ptr = Arc::into_raw(arc_thread_inner) as *const ();
64-
RawWaker::new(ptr, threadsafe_waker_vtable())
41+
RawWaker::new(ptr, waker_vtable())
6542
}
6643

67-
unsafe fn clone_threadsafe_waker(data: *const()) -> RawWaker {
44+
unsafe fn clone_waker(data: *const()) -> RawWaker {
6845
increase_refcount(data);
69-
RawWaker::new(data, threadsafe_waker_vtable())
70-
}
71-
72-
fn wake_current_thread(_data: *const()) {
73-
unreachable!("A current thread waker can only be woken by reference");
46+
RawWaker::new(data, waker_vtable())
7447
}
7548

76-
unsafe fn wake_current_thread_by_ref(data: *const()) {
77-
let wake_state = data as *mut LocalWakeState;
78-
(*wake_state).is_woken = true;
79-
}
80-
81-
unsafe fn wake_threadsafe_waker(data: *const ()) {
49+
unsafe fn wake_waker(data: *const ()) {
8250
let arc_thread_inner = Arc::from_raw(data as *const Inner);
8351
let thread = Thread {
8452
inner: arc_thread_inner,
8553
};
8654
thread.unpark();
8755
}
8856

89-
unsafe fn wake_threadsafe_waker_by_ref(data: *const ()) {
57+
unsafe fn wake_waker_by_ref(data: *const ()) {
9058
// Retain `Arc`, but don't touch refcount by wrapping in `ManuallyDrop`
9159
let arc_thread_inner = Arc::from_raw(data as *const Inner);
9260
let thread = mem::ManuallyDrop::new(Thread {
@@ -95,7 +63,7 @@ unsafe fn wake_threadsafe_waker_by_ref(data: *const ()) {
9563
thread.unpark();
9664
}
9765

98-
unsafe fn drop_threadsafe_waker(data: *const ()) {
66+
unsafe fn drop_waker(data: *const ()) {
9967
drop(Thread {
10068
inner: Arc::from_raw(data as *const Inner),
10169
})
@@ -138,52 +106,23 @@ pub fn block_on_future<F: Future>(mut future: F) -> F::Output {
138106
// out of this function again.
139107
let mut future = unsafe { Pin::new_unchecked(&mut future) };
140108

141-
let mut waker_state = LocalWakeState {
142-
is_woken: true,
143-
waker_was_cloned: false,
109+
// Safety: The `Waker` that we create upholds all guarantees that are expected
110+
// from a `Waker`
111+
let waker = unsafe {
112+
Waker::from_raw(create_threadsafe_raw_waker())
144113
};
145114

146-
// Safety: The `Waker` that we create here is references data on the current
147-
// callstack. This is safe, since the polled `Future` only gets a reference
148-
// to this `Waker`. When it tries to clone the `Waker`, a threadsafe and owned
149-
// version is created instead.
150-
unsafe {
151-
let waker = Waker::from_raw(RawWaker::new(
152-
&waker_state as *const LocalWakeState as *const (),
153-
current_thread_waker_vtable()));
154-
155-
let mut cx = Context::from_waker(&waker);
156-
loop {
157-
while waker_state.is_woken {
158-
// Reset is_woken, so that we do not spin if the poll does not
159-
// directly wake us up.
160-
waker_state.is_woken = false;
161-
if let Poll::Ready(task_result) = future.as_mut().poll(&mut cx) {
162-
return task_result;
163-
}
164-
}
165-
166-
// The task is not ready, and the `Waker` had not been woken from the
167-
// current thread. In order for us to proceed we wait until the
168-
// thread gets unparked by another thread. If the `Waker` has not been
169-
// cloned this will never happen and represents a deadlock, which
170-
// gets reported here.
171-
if !waker_state.waker_was_cloned {
172-
panic!("Deadlock: Task is not ready, but the Waker had not been cloned");
173-
// Note: This flag is never reset, since a `Waker` that had been cloned
174-
// once can be cloned more often to wakeup this executor. We don't
175-
// have knowledge on how many clones are around - therefore the
176-
// deadlock detection only works for the case the `Waker` never
177-
// gets cloned.
178-
}
179-
park();
180-
// If thread::park has returned, we have been notified by another
181-
// thread. Therefore we are woken.
182-
// Remark: This flag can not be set by the other thread directly,
183-
// because it may no longer be alive at the point of time when
184-
// wake() is called.
185-
waker_state.is_woken = true;
115+
let mut cx = Context::from_waker(&waker);
116+
loop {
117+
if let Poll::Ready(task_result) = future.as_mut().poll(&mut cx) {
118+
return task_result;
186119
}
120+
121+
// The task is not ready. In order for us to proceed we wait until the
122+
// thread gets unparked. If the `Waker` had been woken inside `.poll()`,
123+
// then `park()` will immediately return, and we will call `.poll()`
124+
// again without any wait period.
125+
park();
187126
}
188127
}
189128

@@ -201,31 +140,28 @@ mod tests {
201140
fn check_refcounts() {
202141
let original = current_thread_refcount();
203142

204-
let waker_state = LocalWakeState {
205-
is_woken: true,
206-
waker_was_cloned: false,
207-
};
208-
209-
let waker = unsafe { Waker::from_raw(RawWaker::new(
210-
&waker_state as *const LocalWakeState as *const (),
211-
current_thread_waker_vtable())) };
143+
let waker = unsafe { Waker::from_raw(create_threadsafe_raw_waker()) };
144+
assert_eq!(original + 1, current_thread_refcount());
212145

213146
waker.wake_by_ref();
214-
assert_eq!(original, current_thread_refcount());
147+
assert_eq!(original + 1, current_thread_refcount());
215148

216149
let clone1 = waker.clone();
217-
assert_eq!(original + 1, current_thread_refcount());
218-
let clone2 = waker.clone();
219150
assert_eq!(original + 2, current_thread_refcount());
220-
let clone3 = clone1.clone();
151+
let clone2 = waker.clone();
221152
assert_eq!(original + 3, current_thread_refcount());
153+
let clone3 = clone1.clone();
154+
assert_eq!(original + 4, current_thread_refcount());
222155

223156
drop(clone1);
224-
assert_eq!(original + 2, current_thread_refcount());
157+
assert_eq!(original + 3, current_thread_refcount());
225158

226159
clone2.wake_by_ref();
227-
assert_eq!(original + 2, current_thread_refcount());
160+
assert_eq!(original + 3, current_thread_refcount());
228161
clone2.wake();
162+
assert_eq!(original + 2, current_thread_refcount());
163+
164+
drop(waker);
229165
assert_eq!(original + 1, current_thread_refcount());
230166

231167
clone3.wake_by_ref();

0 commit comments

Comments
 (0)