Skip to content

Commit 7dc39b0

Browse files
committed
Basic StreamExt::flat_map_unordered impl
1 parent c19d43f commit 7dc39b0

File tree

3 files changed

+445
-1
lines changed

3 files changed

+445
-1
lines changed
Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
use super::Map;
2+
use crate::stream::FuturesUnordered;
3+
use core::fmt;
4+
use core::num::NonZeroUsize;
5+
use core::pin::Pin;
6+
use futures_core::future::Future;
7+
use futures_core::stream::FusedStream;
8+
use futures_core::stream::Stream;
9+
use futures_core::task::{Context, Poll, Waker};
10+
#[cfg(feature = "sink")]
11+
use futures_sink::Sink;
12+
use futures_task::{waker, ArcWake};
13+
use pin_utils::{unsafe_pinned, unsafe_unpinned};
14+
use std::sync::atomic::*;
15+
use std::sync::Arc;
16+
17+
/// Indicates that there is nothing to poll and stream isn't polled at the
18+
/// moment.
19+
const NONE: u8 = 0;
20+
21+
/// Indicates that `futures` need to be polled.
22+
const NEED_TO_POLL_FUTURES: u8 = 0b1;
23+
24+
/// Indicates that `stream` needs to be polled.
25+
const NEED_TO_POLL_STREAM: u8 = 0b10;
26+
27+
/// Indicates that we need to poll something.
28+
const NEED_TO_POLL: u8 = NEED_TO_POLL_FUTURES | NEED_TO_POLL_STREAM;
29+
30+
/// Indicates that current stream is polled at the moment.
31+
const POLLING: u8 = 0b100;
32+
33+
/// State which used to determine what needs to be polled,
34+
/// and are we polling stream at the moment or not.
35+
#[derive(Clone, Debug)]
36+
struct SharedPollState {
37+
state: Arc<AtomicU8>,
38+
}
39+
40+
impl SharedPollState {
41+
/// Constructs new `SharedPollState` with given state.
42+
fn new(state: u8) -> Self {
43+
Self {
44+
state: Arc::new(AtomicU8::new(state)),
45+
}
46+
}
47+
48+
/// Swaps state with `POLLING`, returning previous state.
49+
fn begin_polling(&self) -> u8 {
50+
self.state.swap(POLLING, Ordering::AcqRel)
51+
}
52+
53+
/// Performs bitwise or with `to_poll` and given state, returning
54+
/// previous state.
55+
fn set_or(&self, to_poll: u8) -> u8 {
56+
self.state.fetch_or(to_poll, Ordering::AcqRel)
57+
}
58+
59+
/// Performs bitwise or with `to_poll` and current state, stores result
60+
/// with non-`POLLING` state, and returns disjunction result.
61+
fn end_polling(&self, to_poll: u8) -> u8 {
62+
let to_poll = to_poll | self.state.load(Ordering::Acquire);
63+
self.state.store(to_poll & !POLLING, Ordering::Release);
64+
to_poll
65+
}
66+
}
67+
68+
/// Waker which will update `poll_state` with `need_to_poll` value on
69+
/// `wake_by_ref` call and then, if there is a need, call `inner_waker`.
70+
struct PollWaker {
71+
inner_waker: Waker,
72+
poll_state: SharedPollState,
73+
need_to_poll: u8,
74+
}
75+
76+
impl ArcWake for PollWaker {
77+
fn wake_by_ref(self_arc: &Arc<Self>) {
78+
let poll_state_value = self_arc.poll_state.set_or(self_arc.need_to_poll);
79+
// Only call waker if we're not polling because we will call it at the end
80+
// of polling if it needs to poll something.
81+
if poll_state_value & POLLING == NONE {
82+
self_arc.inner_waker.wake_by_ref();
83+
}
84+
}
85+
}
86+
87+
/// Future which contains optional stream. If it's `Some`, it will attempt
88+
/// to call `poll_next` on it, returning `Some((item, stream))` in case of
89+
/// `Poll::Ready(Some(...))` or `None` in case of `Poll::Ready(None)`.
90+
/// If `poll_next` will return `Poll::Pending`, it will be forwared to
91+
/// the future, and current task will be notified by waker.
92+
#[must_use = "futures do nothing unless you `.await` or poll them"]
93+
struct StreamFut<St> {
94+
stream: Option<St>,
95+
}
96+
97+
impl<St> StreamFut<St> {
98+
unsafe_pinned!(stream: Option<St>);
99+
}
100+
101+
impl<St: Stream + Unpin> Unpin for StreamFut<St> {}
102+
103+
impl<St: Stream> Future for StreamFut<St> {
104+
type Output = Option<(St::Item, St)>;
105+
106+
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
107+
let item = if let Some(stream) = self.as_mut().stream().as_pin_mut() {
108+
ready!(stream.poll_next(ctx))
109+
} else {
110+
None
111+
};
112+
113+
Poll::Ready(item.map(|item| {
114+
(item, unsafe {
115+
self.get_unchecked_mut().stream.take().unwrap()
116+
})
117+
}))
118+
}
119+
}
120+
121+
/// Stream for the [`flat_map_unordered`](super::StreamExt::flat_map_unordered)
122+
/// method.
123+
#[must_use = "streams do nothing unless polled"]
124+
pub struct FlatMapUnordered<St: Stream, U: Stream, F: FnMut(St::Item) -> U> {
125+
poll_state: SharedPollState,
126+
futures: FuturesUnordered<StreamFut<U>>,
127+
stream: Map<St, F>,
128+
limit: Option<NonZeroUsize>,
129+
is_stream_done: bool,
130+
}
131+
132+
impl<St, U, F> Unpin for FlatMapUnordered<St, U, F>
133+
where
134+
St: Stream + Unpin,
135+
U: Stream + Unpin,
136+
F: FnMut(St::Item) -> U,
137+
{
138+
}
139+
140+
impl<St, U, F> fmt::Debug for FlatMapUnordered<St, U, F>
141+
where
142+
St: Stream + fmt::Debug,
143+
U: Stream + fmt::Debug,
144+
F: FnMut(St::Item) -> U,
145+
{
146+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147+
f.debug_struct("FlatMapUnordered")
148+
.field("poll_state", &self.poll_state)
149+
.field("futures", &self.futures)
150+
.field("limit", &self.limit)
151+
.field("stream", &self.stream)
152+
.field("is_stream_done", &self.is_stream_done)
153+
.finish()
154+
}
155+
}
156+
157+
impl<St, U, F> FlatMapUnordered<St, U, F>
158+
where
159+
St: Stream,
160+
U: Stream,
161+
F: FnMut(St::Item) -> U,
162+
{
163+
unsafe_pinned!(futures: FuturesUnordered<StreamFut<U>>);
164+
unsafe_pinned!(stream: Map<St, F>);
165+
unsafe_unpinned!(is_stream_done: bool);
166+
unsafe_unpinned!(limit: Option<NonZeroUsize>);
167+
unsafe_unpinned!(poll_state: SharedPollState);
168+
169+
pub(super) fn new(stream: St, limit: Option<usize>, f: F) -> FlatMapUnordered<St, U, F> {
170+
FlatMapUnordered {
171+
// Because to create first future, we need to get inner
172+
// stream from `stream`
173+
poll_state: SharedPollState::new(NEED_TO_POLL_STREAM),
174+
futures: FuturesUnordered::new(),
175+
stream: Map::new(stream, f),
176+
is_stream_done: false,
177+
limit: limit.and_then(NonZeroUsize::new),
178+
}
179+
}
180+
181+
/// Acquires a reference to the underlying stream that this combinator is
182+
/// pulling from.
183+
pub fn get_ref(&self) -> &St {
184+
self.stream.get_ref()
185+
}
186+
187+
/// Acquires a mutable reference to the underlying stream that this
188+
/// combinator is pulling from.
189+
///
190+
/// Note that care must be taken to avoid tampering with the state of the
191+
/// stream which may otherwise confuse this combinator.
192+
pub fn get_mut(&mut self) -> &mut St {
193+
self.stream.get_mut()
194+
}
195+
196+
/// Acquires a pinned mutable reference to the underlying stream that this
197+
/// combinator is pulling from.
198+
///
199+
/// Note that care must be taken to avoid tampering with the state of the
200+
/// stream which may otherwise confuse this combinator.
201+
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut St> {
202+
self.stream().get_pin_mut()
203+
}
204+
205+
/// Consumes this combinator, returning the underlying stream.
206+
///
207+
/// Note that this may discard intermediate state of this combinator, so
208+
/// care should be taken to avoid losing resources when this is called.
209+
pub fn into_inner(self) -> St {
210+
self.stream.into_inner()
211+
}
212+
213+
/// Creates waker with given `need_to_poll` value, which will be used to
214+
/// update poll state on `wake_by_ref` call.
215+
fn create_waker(&self, inner_waker: Waker, need_to_poll: u8) -> Waker {
216+
waker(Arc::new(PollWaker {
217+
inner_waker,
218+
poll_state: self.poll_state.clone(),
219+
need_to_poll,
220+
}))
221+
}
222+
223+
/// Creates special waker for polling stream which will set poll state
224+
/// to poll `stream` on `wake_by_ref` call. Use only if you need several
225+
/// contexts.
226+
fn create_poll_stream_waker(&self, ctx: &Context<'_>) -> Waker {
227+
self.create_waker(ctx.waker().clone(), NEED_TO_POLL_STREAM)
228+
}
229+
230+
/// Creates special waker for polling futures which willset poll state
231+
/// to poll `futures` on `wake_by_ref` call. Use only if you need several
232+
/// contexts.
233+
fn create_poll_futures_waker(&self, ctx: &Context<'_>) -> Waker {
234+
self.create_waker(ctx.waker().clone(), NEED_TO_POLL_FUTURES)
235+
}
236+
237+
/// Checks if current `futures` size is less than optional limit.
238+
fn not_exceeded_limit(&self) -> bool {
239+
self.limit
240+
.map(|limit| self.futures.len() < limit.get())
241+
.unwrap_or(true)
242+
}
243+
}
244+
245+
impl<St, U, F> FusedStream for FlatMapUnordered<St, U, F>
246+
where
247+
St: FusedStream,
248+
U: Unpin + FusedStream,
249+
F: FnMut(St::Item) -> U,
250+
{
251+
fn is_terminated(&self) -> bool {
252+
self.futures.is_empty() && self.stream.is_terminated()
253+
}
254+
}
255+
256+
impl<St, U, F> Stream for FlatMapUnordered<St, U, F>
257+
where
258+
St: Stream,
259+
U: Stream,
260+
F: FnMut(St::Item) -> U,
261+
{
262+
type Item = U::Item;
263+
264+
fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
265+
let mut poll_state_value = self.as_mut().poll_state().begin_polling();
266+
267+
let mut next_item = None;
268+
let mut need_to_poll_next = NONE;
269+
let mut polling_with_two_wakers =
270+
poll_state_value & NEED_TO_POLL == NEED_TO_POLL && self.not_exceeded_limit();
271+
let mut polled_stream = false;
272+
let mut polled_futures = false;
273+
274+
if poll_state_value & NEED_TO_POLL_STREAM != NONE {
275+
if self.not_exceeded_limit() {
276+
polled_stream = true;
277+
match if polling_with_two_wakers {
278+
let waker = self.create_poll_stream_waker(ctx);
279+
let mut ctx = Context::from_waker(&waker);
280+
self.as_mut().stream().poll_next(&mut ctx)
281+
} else {
282+
self.as_mut().stream().poll_next(ctx)
283+
} {
284+
Poll::Ready(Some(inner_stream)) => {
285+
self.as_mut().futures().push(StreamFut {
286+
stream: Some(inner_stream),
287+
});
288+
need_to_poll_next |= NEED_TO_POLL_STREAM;
289+
// Polling futures in current iteration with the same context
290+
// is ok because we already received `Poll::Ready` from
291+
// stream
292+
poll_state_value |= NEED_TO_POLL_FUTURES;
293+
polling_with_two_wakers = false;
294+
}
295+
Poll::Ready(None) => {
296+
*self.as_mut().is_stream_done() = true;
297+
// Polling futures in current iteration with the same context
298+
// is ok because we already received `Poll::Ready` from
299+
// stream
300+
polling_with_two_wakers = false;
301+
}
302+
Poll::Pending => {
303+
if !polling_with_two_wakers {
304+
need_to_poll_next |= NEED_TO_POLL_STREAM;
305+
}
306+
}
307+
}
308+
} else {
309+
need_to_poll_next |= NEED_TO_POLL_STREAM;
310+
}
311+
}
312+
313+
if poll_state_value & NEED_TO_POLL_FUTURES != NONE {
314+
polled_futures = true;
315+
match if polling_with_two_wakers {
316+
let waker = self.create_poll_futures_waker(ctx);
317+
let mut ctx = Context::from_waker(&waker);
318+
self.as_mut().futures().poll_next(&mut ctx)
319+
} else {
320+
self.as_mut().futures().poll_next(ctx)
321+
} {
322+
Poll::Ready(Some(Some((item, stream)))) => {
323+
self.as_mut().futures().push(StreamFut {
324+
stream: Some(stream),
325+
});
326+
next_item = Some(item);
327+
need_to_poll_next |= NEED_TO_POLL_FUTURES;
328+
}
329+
Poll::Ready(Some(None)) => {
330+
need_to_poll_next |= NEED_TO_POLL_FUTURES;
331+
}
332+
Poll::Pending => {
333+
if !polling_with_two_wakers {
334+
need_to_poll_next |= NEED_TO_POLL_FUTURES;
335+
}
336+
}
337+
_ => {
338+
need_to_poll_next &= !NEED_TO_POLL_FUTURES;
339+
}
340+
}
341+
}
342+
343+
let poll_state_value = self.as_mut().poll_state().end_polling(need_to_poll_next);
344+
345+
if poll_state_value & NEED_TO_POLL != NONE {
346+
if !polling_with_two_wakers {
347+
if poll_state_value & NEED_TO_POLL_FUTURES != NONE && !polled_futures
348+
|| poll_state_value & NEED_TO_POLL_STREAM != NONE && !polled_stream
349+
{
350+
ctx.waker().wake_by_ref();
351+
}
352+
} else {
353+
ctx.waker().wake_by_ref();
354+
}
355+
}
356+
357+
if self.futures.is_empty() && self.is_stream_done || next_item.is_some() {
358+
Poll::Ready(next_item)
359+
} else {
360+
Poll::Pending
361+
}
362+
}
363+
}
364+
365+
// Forwarding impl of Sink from the underlying stream
366+
#[cfg(feature = "sink")]
367+
impl<S, U, F, Item> Sink<Item> for FlatMapUnordered<S, U, F>
368+
where
369+
S: Stream + Sink<Item>,
370+
U: Stream,
371+
F: FnMut(S::Item) -> U,
372+
{
373+
type Error = S::Error;
374+
375+
delegate_sink!(stream, Item);
376+
}

0 commit comments

Comments
 (0)