@@ -13,6 +13,7 @@ use futures_core::task::{Context, Poll, Waker};
1313use futures_sink:: Sink ;
1414use futures_task:: { waker, ArcWake } ;
1515use pin_utils:: { unsafe_pinned, unsafe_unpinned} ;
16+ use core:: cell:: UnsafeCell ;
1617
1718/// Indicates that there is nothing to poll and stream isn't polled at the
1819/// moment.
@@ -68,18 +69,24 @@ impl SharedPollState {
6869/// Waker which will update `poll_state` with `need_to_poll` value on
6970/// `wake_by_ref` call and then, if there is a need, call `inner_waker`.
7071struct PollWaker {
71- inner_waker : Waker ,
72+ inner_waker : UnsafeCell < Option < Waker > > ,
7273 poll_state : SharedPollState ,
7374 need_to_poll : u8 ,
7475}
7576
77+ unsafe impl Send for PollWaker { }
78+
79+ unsafe impl Sync for PollWaker { }
80+
7681impl ArcWake for PollWaker {
7782 fn wake_by_ref ( self_arc : & Arc < Self > ) {
7883 let poll_state_value = self_arc. poll_state . set_or ( self_arc. need_to_poll ) ;
7984 // Only call waker if stream isn't polled because it will be called
8085 // at the end of polling if state was changed.
8186 if poll_state_value & POLLING == NONE {
82- self_arc. inner_waker . wake_by_ref ( ) ;
87+ if let Some ( Some ( inner_waker) ) = unsafe { self_arc. inner_waker . get ( ) . as_ref ( ) } {
88+ inner_waker. wake_by_ref ( ) ;
89+ }
8390 }
8491 }
8592}
@@ -135,6 +142,8 @@ pub struct FlatMapUnordered<St: Stream, U: Stream, F: FnMut(St::Item) -> U> {
135142 stream : Map < St , F > ,
136143 limit : Option < NonZeroUsize > ,
137144 is_stream_done : bool ,
145+ futures_waker : Arc < PollWaker > ,
146+ stream_waker : Arc < PollWaker >
138147}
139148
140149impl < St , U , F > Unpin for FlatMapUnordered < St , U , F >
@@ -173,16 +182,29 @@ where
173182 unsafe_unpinned ! ( is_stream_done: bool ) ;
174183 unsafe_unpinned ! ( limit: Option <NonZeroUsize >) ;
175184 unsafe_unpinned ! ( poll_state: SharedPollState ) ;
185+ unsafe_unpinned ! ( futures_waker: Arc <PollWaker >) ;
186+ unsafe_unpinned ! ( stream_waker: Arc <PollWaker >) ;
176187
177188 pub ( super ) fn new ( stream : St , limit : Option < usize > , f : F ) -> FlatMapUnordered < St , U , F > {
189+ // Because to create first future, it needs to get inner
190+ // stream from `stream`
191+ let poll_state = SharedPollState :: new ( NEED_TO_POLL_STREAM ) ;
178192 FlatMapUnordered {
179- // Because to create first future, it needs to get inner
180- // stream from `stream`
181- poll_state : SharedPollState :: new ( NEED_TO_POLL_STREAM ) ,
182193 futures : FuturesUnordered :: new ( ) ,
183194 stream : Map :: new ( stream, f) ,
184195 is_stream_done : false ,
185196 limit : limit. and_then ( NonZeroUsize :: new) ,
197+ futures_waker : Arc :: new ( PollWaker {
198+ inner_waker : UnsafeCell :: new ( None ) ,
199+ poll_state : poll_state. clone ( ) ,
200+ need_to_poll : NEED_TO_POLL_FUTURES ,
201+ } ) ,
202+ stream_waker : Arc :: new ( PollWaker {
203+ inner_waker : UnsafeCell :: new ( None ) ,
204+ poll_state : poll_state. clone ( ) ,
205+ need_to_poll : NEED_TO_POLL_STREAM ,
206+ } ) ,
207+ poll_state
186208 }
187209 }
188210
@@ -218,28 +240,30 @@ where
218240 self . stream . into_inner ( )
219241 }
220242
221- /// Creates waker with given `need_to_poll` value, which will be used to
222- /// update poll state on `wake_by_ref` call.
223- fn create_waker ( & self , inner_waker : Waker , need_to_poll : u8 ) -> Waker {
224- waker ( Arc :: new ( PollWaker {
225- inner_waker,
226- poll_state : self . poll_state . clone ( ) ,
227- need_to_poll,
228- } ) )
229- }
230-
231243 /// Creates special waker for polling stream which will set poll state
232244 /// to poll `stream` on `wake_by_ref` call. Use only if you need several
233245 /// contexts.
234- fn create_poll_stream_waker ( & self , ctx : & Context < ' _ > ) -> Waker {
235- self . create_waker ( ctx. waker ( ) . clone ( ) , NEED_TO_POLL_STREAM )
246+ ///
247+ /// ## Safety
248+ ///
249+ /// This function will modify current `stream_waker`'s `inner_waker`
250+ /// via `UnsafeCell`, so it should be used only in `POLLING` phase.
251+ unsafe fn create_poll_stream_waker ( mut self : Pin < & mut Self > , ctx : & Context < ' _ > ) -> Waker {
252+ * self . as_mut ( ) . stream_waker . inner_waker . get ( ) = ctx. waker ( ) . clone ( ) . into ( ) ;
253+ waker ( self . stream_waker . clone ( ) )
236254 }
237255
238256 /// Creates special waker for polling futures which willset poll state
239257 /// to poll `futures` on `wake_by_ref` call. Use only if you need several
240- /// contexts.
241- fn create_poll_futures_waker ( & self , ctx : & Context < ' _ > ) -> Waker {
242- self . create_waker ( ctx. waker ( ) . clone ( ) , NEED_TO_POLL_FUTURES )
258+ /// contexts.
259+ ///
260+ /// ## Safety
261+ ///
262+ /// This function will modify current `futures_waker`'s `inner_waker`
263+ /// via `UnsafeCell`, so it should be used only in `POLLING` phase.
264+ unsafe fn create_poll_futures_waker ( mut self : Pin < & mut Self > , ctx : & Context < ' _ > ) -> Waker {
265+ * self . as_mut ( ) . futures_waker . inner_waker . get ( ) = ctx. waker ( ) . clone ( ) . into ( ) ;
266+ waker ( self . futures_waker . clone ( ) )
243267 }
244268
245269 /// Checks if current `futures` size is less than optional limit.
@@ -273,15 +297,15 @@ where
273297 let mut poll_state_value = self . as_mut ( ) . poll_state ( ) . begin_polling ( ) ;
274298 let mut next_item = None ;
275299 let mut need_to_poll_next = NONE ;
276- let mut polling_with_two_wakers =
277- poll_state_value & NEED_TO_POLL == NEED_TO_POLL && self . not_exceeded_limit ( ) ;
278- let mut stream_will_be_woken = false ;
300+ let mut stream_will_be_woken_or_polled_later = !self . not_exceeded_limit ( ) ;
279301 let mut futures_will_be_woken = false ;
302+ let mut polling_with_two_wakers = poll_state_value & NEED_TO_POLL == NEED_TO_POLL && !stream_will_be_woken_or_polled_later;
280303
281304 if poll_state_value & NEED_TO_POLL_STREAM != NONE {
282- if self . not_exceeded_limit ( ) {
305+ if !stream_will_be_woken_or_polled_later {
283306 match if polling_with_two_wakers {
284- let waker = self . create_poll_stream_waker ( ctx) ;
307+ // Safety: now state is `POLLING`.
308+ let waker = unsafe { self . as_mut ( ) . create_poll_stream_waker ( ctx) } ;
285309 let mut ctx = Context :: from_waker ( & waker) ;
286310 self . as_mut ( ) . stream ( ) . poll_next ( & mut ctx)
287311 } else {
@@ -304,7 +328,7 @@ where
304328 polling_with_two_wakers = false ;
305329 }
306330 Poll :: Pending => {
307- stream_will_be_woken = true ;
331+ stream_will_be_woken_or_polled_later = true ;
308332 if !polling_with_two_wakers {
309333 need_to_poll_next |= NEED_TO_POLL_STREAM ;
310334 }
@@ -317,7 +341,8 @@ where
317341
318342 if poll_state_value & NEED_TO_POLL_FUTURES != NONE {
319343 match if polling_with_two_wakers {
320- let waker = self . create_poll_futures_waker ( ctx) ;
344+ // Safety: now state is `POLLING`.
345+ let waker = unsafe { self . as_mut ( ) . create_poll_futures_waker ( ctx) } ;
321346 let mut ctx = Context :: from_waker ( & waker) ;
322347 self . as_mut ( ) . futures ( ) . poll_next ( & mut ctx)
323348 } else {
@@ -348,7 +373,7 @@ where
348373 if poll_state_value & NEED_TO_POLL != NONE
349374 && ( polling_with_two_wakers
350375 || poll_state_value & NEED_TO_POLL_FUTURES != NONE && !futures_will_be_woken
351- || poll_state_value & NEED_TO_POLL_STREAM != NONE && !stream_will_be_woken )
376+ || poll_state_value & NEED_TO_POLL_STREAM != NONE && !stream_will_be_woken_or_polled_later )
352377 {
353378 ctx. waker ( ) . wake_by_ref ( ) ;
354379 }
0 commit comments