Skip to content

Commit bd9a86c

Browse files
committed
Improve performance by modifying already created PollWakers instead of creating new
1 parent c2641d4 commit bd9a86c

File tree

1 file changed

+53
-28
lines changed

1 file changed

+53
-28
lines changed

futures-util/src/stream/stream/flat_map_unordered.rs

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use futures_core::task::{Context, Poll, Waker};
1313
use futures_sink::Sink;
1414
use futures_task::{waker, ArcWake};
1515
use pin_utils::{unsafe_pinned, unsafe_unpinned};
16+
use core::cell::UnsafeCell;
1617

1718
/// Indicates that there is nothing to poll and stream isn't polled at the
1819
/// moment.
@@ -68,18 +69,24 @@ impl SharedPollState {
6869
/// Waker which will update `poll_state` with `need_to_poll` value on
6970
/// `wake_by_ref` call and then, if there is a need, call `inner_waker`.
7071
struct PollWaker {
71-
inner_waker: Waker,
72+
inner_waker: UnsafeCell<Option<Waker>>,
7273
poll_state: SharedPollState,
7374
need_to_poll: u8,
7475
}
7576

77+
unsafe impl Send for PollWaker {}
78+
79+
unsafe impl Sync for PollWaker {}
80+
7681
impl ArcWake for PollWaker {
7782
fn wake_by_ref(self_arc: &Arc<Self>) {
7883
let poll_state_value = self_arc.poll_state.set_or(self_arc.need_to_poll);
7984
// Only call waker if stream isn't polled because it will be called
8085
// at the end of polling if state was changed.
8186
if poll_state_value & POLLING == NONE {
82-
self_arc.inner_waker.wake_by_ref();
87+
if let Some(Some(inner_waker)) = unsafe { self_arc.inner_waker.get().as_ref() } {
88+
inner_waker.wake_by_ref();
89+
}
8390
}
8491
}
8592
}
@@ -135,6 +142,8 @@ pub struct FlatMapUnordered<St: Stream, U: Stream, F: FnMut(St::Item) -> U> {
135142
stream: Map<St, F>,
136143
limit: Option<NonZeroUsize>,
137144
is_stream_done: bool,
145+
futures_waker: Arc<PollWaker>,
146+
stream_waker: Arc<PollWaker>
138147
}
139148

140149
impl<St, U, F> Unpin for FlatMapUnordered<St, U, F>
@@ -173,16 +182,29 @@ where
173182
unsafe_unpinned!(is_stream_done: bool);
174183
unsafe_unpinned!(limit: Option<NonZeroUsize>);
175184
unsafe_unpinned!(poll_state: SharedPollState);
185+
unsafe_unpinned!(futures_waker: Arc<PollWaker>);
186+
unsafe_unpinned!(stream_waker: Arc<PollWaker>);
176187

177188
pub(super) fn new(stream: St, limit: Option<usize>, f: F) -> FlatMapUnordered<St, U, F> {
189+
// Because to create first future, it needs to get inner
190+
// stream from `stream`
191+
let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
178192
FlatMapUnordered {
179-
// Because to create first future, it needs to get inner
180-
// stream from `stream`
181-
poll_state: SharedPollState::new(NEED_TO_POLL_STREAM),
182193
futures: FuturesUnordered::new(),
183194
stream: Map::new(stream, f),
184195
is_stream_done: false,
185196
limit: limit.and_then(NonZeroUsize::new),
197+
futures_waker: Arc::new(PollWaker {
198+
inner_waker: UnsafeCell::new(None),
199+
poll_state: poll_state.clone(),
200+
need_to_poll: NEED_TO_POLL_FUTURES,
201+
}),
202+
stream_waker: Arc::new(PollWaker {
203+
inner_waker: UnsafeCell::new(None),
204+
poll_state: poll_state.clone(),
205+
need_to_poll: NEED_TO_POLL_STREAM,
206+
}),
207+
poll_state
186208
}
187209
}
188210

@@ -218,28 +240,30 @@ where
218240
self.stream.into_inner()
219241
}
220242

221-
/// Creates waker with given `need_to_poll` value, which will be used to
222-
/// update poll state on `wake_by_ref` call.
223-
fn create_waker(&self, inner_waker: Waker, need_to_poll: u8) -> Waker {
224-
waker(Arc::new(PollWaker {
225-
inner_waker,
226-
poll_state: self.poll_state.clone(),
227-
need_to_poll,
228-
}))
229-
}
230-
231243
/// Creates special waker for polling stream which will set poll state
232244
/// to poll `stream` on `wake_by_ref` call. Use only if you need several
233245
/// contexts.
234-
fn create_poll_stream_waker(&self, ctx: &Context<'_>) -> Waker {
235-
self.create_waker(ctx.waker().clone(), NEED_TO_POLL_STREAM)
246+
///
247+
/// ## Safety
248+
///
249+
/// This function will modify current `stream_waker`'s `inner_waker`
250+
/// via `UnsafeCell`, so it should be used only in `POLLING` phase.
251+
unsafe fn create_poll_stream_waker(mut self: Pin<&mut Self>, ctx: &Context<'_>) -> Waker {
252+
*self.as_mut().stream_waker.inner_waker.get() = ctx.waker().clone().into();
253+
waker(self.stream_waker.clone())
236254
}
237255

238256
/// Creates special waker for polling futures which willset poll state
239257
/// to poll `futures` on `wake_by_ref` call. Use only if you need several
240-
/// contexts.
241-
fn create_poll_futures_waker(&self, ctx: &Context<'_>) -> Waker {
242-
self.create_waker(ctx.waker().clone(), NEED_TO_POLL_FUTURES)
258+
/// contexts.
259+
///
260+
/// ## Safety
261+
///
262+
/// This function will modify current `futures_waker`'s `inner_waker`
263+
/// via `UnsafeCell`, so it should be used only in `POLLING` phase.
264+
unsafe fn create_poll_futures_waker(mut self: Pin<&mut Self>, ctx: &Context<'_>) -> Waker {
265+
*self.as_mut().futures_waker.inner_waker.get() = ctx.waker().clone().into();
266+
waker(self.futures_waker.clone())
243267
}
244268

245269
/// Checks if current `futures` size is less than optional limit.
@@ -273,15 +297,15 @@ where
273297
let mut poll_state_value = self.as_mut().poll_state().begin_polling();
274298
let mut next_item = None;
275299
let mut need_to_poll_next = NONE;
276-
let mut polling_with_two_wakers =
277-
poll_state_value & NEED_TO_POLL == NEED_TO_POLL && self.not_exceeded_limit();
278-
let mut stream_will_be_woken = false;
300+
let mut stream_will_be_woken_or_polled_later = !self.not_exceeded_limit();
279301
let mut futures_will_be_woken = false;
302+
let mut polling_with_two_wakers = poll_state_value & NEED_TO_POLL == NEED_TO_POLL && !stream_will_be_woken_or_polled_later;
280303

281304
if poll_state_value & NEED_TO_POLL_STREAM != NONE {
282-
if self.not_exceeded_limit() {
305+
if !stream_will_be_woken_or_polled_later {
283306
match if polling_with_two_wakers {
284-
let waker = self.create_poll_stream_waker(ctx);
307+
// Safety: now state is `POLLING`.
308+
let waker = unsafe { self.as_mut().create_poll_stream_waker(ctx) };
285309
let mut ctx = Context::from_waker(&waker);
286310
self.as_mut().stream().poll_next(&mut ctx)
287311
} else {
@@ -304,7 +328,7 @@ where
304328
polling_with_two_wakers = false;
305329
}
306330
Poll::Pending => {
307-
stream_will_be_woken = true;
331+
stream_will_be_woken_or_polled_later = true;
308332
if !polling_with_two_wakers {
309333
need_to_poll_next |= NEED_TO_POLL_STREAM;
310334
}
@@ -317,7 +341,8 @@ where
317341

318342
if poll_state_value & NEED_TO_POLL_FUTURES != NONE {
319343
match if polling_with_two_wakers {
320-
let waker = self.create_poll_futures_waker(ctx);
344+
// Safety: now state is `POLLING`.
345+
let waker = unsafe { self.as_mut().create_poll_futures_waker(ctx) };
321346
let mut ctx = Context::from_waker(&waker);
322347
self.as_mut().futures().poll_next(&mut ctx)
323348
} else {
@@ -348,7 +373,7 @@ where
348373
if poll_state_value & NEED_TO_POLL != NONE
349374
&& (polling_with_two_wakers
350375
|| poll_state_value & NEED_TO_POLL_FUTURES != NONE && !futures_will_be_woken
351-
|| poll_state_value & NEED_TO_POLL_STREAM != NONE && !stream_will_be_woken)
376+
|| poll_state_value & NEED_TO_POLL_STREAM != NONE && !stream_will_be_woken_or_polled_later)
352377
{
353378
ctx.waker().wake_by_ref();
354379
}

0 commit comments

Comments
 (0)