@@ -13,6 +13,7 @@ use futures_core::task::{Context, Poll, Waker};
1313use futures_sink:: Sink ;
1414use futures_task:: { waker, ArcWake } ;
1515use pin_utils:: { unsafe_pinned, unsafe_unpinned} ;
16+ use core:: cell:: UnsafeCell ;
1617
1718/// Indicates that there is nothing to poll and stream isn't polled at the
1819/// moment.
@@ -68,18 +69,24 @@ impl SharedPollState {
6869/// Waker which will update `poll_state` with `need_to_poll` value on
6970/// `wake_by_ref` call and then, if there is a need, call `inner_waker`.
7071struct PollWaker {
71- inner_waker : Waker ,
72+ inner_waker : UnsafeCell < Option < Waker > > ,
7273 poll_state : SharedPollState ,
7374 need_to_poll : u8 ,
7475}
7576
77+ unsafe impl Send for PollWaker { }
78+
79+ unsafe impl Sync for PollWaker { }
80+
7681impl ArcWake for PollWaker {
7782 fn wake_by_ref ( self_arc : & Arc < Self > ) {
7883 let poll_state_value = self_arc. poll_state . set_or ( self_arc. need_to_poll ) ;
7984 // Only call waker if stream isn't polled because it will be called
8085 // at the end of polling if state was changed.
8186 if poll_state_value & POLLING == NONE {
82- self_arc. inner_waker . wake_by_ref ( ) ;
87+ if let Some ( Some ( inner_waker) ) = unsafe { self_arc. inner_waker . get ( ) . as_ref ( ) } {
88+ inner_waker. wake_by_ref ( ) ;
89+ }
8390 }
8491 }
8592}
@@ -135,6 +142,8 @@ pub struct FlatMapUnordered<St: Stream, U: Stream, F: FnMut(St::Item) -> U> {
135142 stream : Map < St , F > ,
136143 limit : Option < NonZeroUsize > ,
137144 is_stream_done : bool ,
145+ futures_waker : Arc < PollWaker > ,
146+ stream_waker : Arc < PollWaker >
138147}
139148
140149impl < St , U , F > Unpin for FlatMapUnordered < St , U , F >
@@ -173,16 +182,29 @@ where
173182 unsafe_unpinned ! ( is_stream_done: bool ) ;
174183 unsafe_unpinned ! ( limit: Option <NonZeroUsize >) ;
175184 unsafe_unpinned ! ( poll_state: SharedPollState ) ;
185+ unsafe_unpinned ! ( futures_waker: Arc <PollWaker >) ;
186+ unsafe_unpinned ! ( stream_waker: Arc <PollWaker >) ;
176187
177188 pub ( super ) fn new ( stream : St , limit : Option < usize > , f : F ) -> FlatMapUnordered < St , U , F > {
189+ // Because to create first future, it needs to get inner
190+ // stream from `stream`
191+ let poll_state = SharedPollState :: new ( NEED_TO_POLL_STREAM ) ;
178192 FlatMapUnordered {
179- // Because to create first future, it needs to get inner
180- // stream from `stream`
181- poll_state : SharedPollState :: new ( NEED_TO_POLL_STREAM ) ,
182193 futures : FuturesUnordered :: new ( ) ,
183194 stream : Map :: new ( stream, f) ,
184195 is_stream_done : false ,
185196 limit : limit. and_then ( NonZeroUsize :: new) ,
197+ futures_waker : Arc :: new ( PollWaker {
198+ inner_waker : UnsafeCell :: new ( None ) ,
199+ poll_state : poll_state. clone ( ) ,
200+ need_to_poll : NEED_TO_POLL_FUTURES ,
201+ } ) ,
202+ stream_waker : Arc :: new ( PollWaker {
203+ inner_waker : UnsafeCell :: new ( None ) ,
204+ poll_state : poll_state. clone ( ) ,
205+ need_to_poll : NEED_TO_POLL_STREAM ,
206+ } ) ,
207+ poll_state
186208 }
187209 }
188210
@@ -218,28 +240,30 @@ where
218240 self . stream . into_inner ( )
219241 }
220242
221- /// Creates waker with given `need_to_poll` value, which will be used to
222- /// update poll state on `wake_by_ref` call.
223- fn create_waker ( & self , inner_waker : Waker , need_to_poll : u8 ) -> Waker {
224- waker ( Arc :: new ( PollWaker {
225- inner_waker,
226- poll_state : self . poll_state . clone ( ) ,
227- need_to_poll,
228- } ) )
229- }
230-
231243 /// Creates special waker for polling stream which will set poll state
232244 /// to poll `stream` on `wake_by_ref` call. Use only if you need several
233245 /// contexts.
234- fn create_poll_stream_waker ( & self , ctx : & Context < ' _ > ) -> Waker {
235- self . create_waker ( ctx. waker ( ) . clone ( ) , NEED_TO_POLL_STREAM )
246+ ///
247+ /// ## Safety
248+ ///
249+ /// This function will modify current `stream_waker`'s `inner_waker`
250+ /// via `UnsafeCell`, so it should be used only in `POLLING` phase.
251+ unsafe fn create_poll_stream_waker ( mut self : Pin < & mut Self > , ctx : & Context < ' _ > ) -> Waker {
252+ * self . as_mut ( ) . stream_waker . inner_waker . get ( ) = ctx. waker ( ) . clone ( ) . into ( ) ;
253+ waker ( self . stream_waker . clone ( ) )
236254 }
237255
238256 /// Creates special waker for polling futures which willset poll state
239257 /// to poll `futures` on `wake_by_ref` call. Use only if you need several
240- /// contexts.
241- fn create_poll_futures_waker ( & self , ctx : & Context < ' _ > ) -> Waker {
242- self . create_waker ( ctx. waker ( ) . clone ( ) , NEED_TO_POLL_FUTURES )
258+ /// contexts.
259+ ///
260+ /// ## Safety
261+ ///
262+ /// This function will modify current `futures_waker`'s `inner_waker`
263+ /// via `UnsafeCell`, so it should be used only in `POLLING` phase.
264+ unsafe fn create_poll_futures_waker ( mut self : Pin < & mut Self > , ctx : & Context < ' _ > ) -> Waker {
265+ * self . as_mut ( ) . futures_waker . inner_waker . get ( ) = ctx. waker ( ) . clone ( ) . into ( ) ;
266+ waker ( self . futures_waker . clone ( ) )
243267 }
244268
245269 /// Checks if current `futures` size is less than optional limit.
@@ -281,7 +305,8 @@ where
281305 if poll_state_value & NEED_TO_POLL_STREAM != NONE {
282306 if self . not_exceeded_limit ( ) {
283307 match if polling_with_two_wakers {
284- let waker = self . create_poll_stream_waker ( ctx) ;
308+ // Safety: now state is `POLLING`.
309+ let waker = unsafe { self . as_mut ( ) . create_poll_stream_waker ( ctx) } ;
285310 let mut ctx = Context :: from_waker ( & waker) ;
286311 self . as_mut ( ) . stream ( ) . poll_next ( & mut ctx)
287312 } else {
@@ -317,7 +342,8 @@ where
317342
318343 if poll_state_value & NEED_TO_POLL_FUTURES != NONE {
319344 match if polling_with_two_wakers {
320- let waker = self . create_poll_futures_waker ( ctx) ;
345+ // Safety: now state is `POLLING`.
346+ let waker = unsafe { self . as_mut ( ) . create_poll_futures_waker ( ctx) } ;
321347 let mut ctx = Context :: from_waker ( & waker) ;
322348 self . as_mut ( ) . futures ( ) . poll_next ( & mut ctx)
323349 } else {
0 commit comments