Skip to content

Commit b39e690

Browse files
committed
Make panic handling better + significantly improve performance by polling stream in a loop
1 parent 9bc4e4e commit b39e690

File tree

3 files changed

+159
-90
lines changed

3 files changed

+159
-90
lines changed

futures-util/src/stream/stream/flatten_unordered.rs

Lines changed: 111 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,47 @@ struct SharedPollState {
5050
}
5151

5252
impl SharedPollState {
53-
/// Constructs new `SharedPollState` with given state.
54-
fn new(state: u8) -> SharedPollState {
55-
SharedPollState { state: Arc::new(AtomicU8::new(state)) }
53+
/// Constructs new `SharedPollState` with the given state.
54+
fn new(value: u8) -> SharedPollState {
55+
SharedPollState { state: Arc::new(AtomicU8::new(value)) }
5656
}
5757

5858
/// Attempts to start polling, returning stored state in case of success.
5959
/// Returns `None` if state some waker is waking at the moment.
60-
fn start_polling(&self) -> Option<u8> {
60+
fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>)> {
6161
self.state
62-
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
63-
if state & WAKING_ANYTHING == NONE {
62+
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
63+
if value & WAKING_ANYTHING == NONE {
6464
Some(POLLING)
6565
} else {
6666
None
6767
}
6868
})
6969
.ok()
70+
.map(|value| {
71+
(
72+
value,
73+
PollStateBomb::new(self, move |state| {
74+
state.stop_polling(NEED_TO_POLL_ALL);
75+
}),
76+
)
77+
})
7078
}
7179

7280
/// Starts the waking process and performs bitwise or with the given value.
73-
fn start_waking(&self, to_poll: u8, waking: u8) -> u8 {
74-
self.state.fetch_or(to_poll | waking, Ordering::SeqCst)
81+
fn start_waking(
82+
&self,
83+
to_poll: u8,
84+
waking: u8,
85+
) -> (u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>) {
86+
let value = self.state.fetch_or(to_poll | waking, Ordering::SeqCst);
87+
88+
(
89+
value,
90+
PollStateBomb::new(self, move |state| {
91+
state.stop_waking(waking);
92+
}),
93+
)
7594
}
7695

7796
/// Toggles state to non-waking, allowing to start polling.
@@ -82,8 +101,8 @@ impl SharedPollState {
82101
/// Sets current state to `!POLLING`, allowing to use wakers.
83102
fn stop_polling(&self, to_poll: u8) -> u8 {
84103
self.state
85-
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
86-
Some((state | to_poll) & !POLLING)
104+
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
105+
Some((value | to_poll) & !POLLING)
87106
})
88107
.unwrap()
89108
}
@@ -98,7 +117,6 @@ struct InnerWaker {
98117
}
99118

100119
unsafe impl Send for InnerWaker {}
101-
102120
unsafe impl Sync for InnerWaker {}
103121

104122
impl InnerWaker {
@@ -115,20 +133,39 @@ impl InnerWaker {
115133
waker(self_arc.clone())
116134
}
117135

118-
// Flags state that walking is started for the walker with the given value.
119-
fn start_waking(&self) -> u8 {
136+
// Flags state that waking is started for the waker with the given value.
137+
fn start_waking(&self) -> (u8, PollStateBomb<'_, impl FnOnce(&SharedPollState)>) {
120138
self.poll_state.start_waking(self.need_to_poll, self.need_to_poll << 3)
121139
}
140+
}
141+
142+
///
143+
struct PollStateBomb<'a, F: FnOnce(&SharedPollState)> {
144+
state: &'a SharedPollState,
145+
drop: Option<F>,
146+
}
147+
148+
impl<'a, F: FnOnce(&SharedPollState)> PollStateBomb<'a, F> {
149+
fn new(state: &'a SharedPollState, drop: F) -> Self {
150+
Self { state, drop: Some(drop) }
151+
}
152+
153+
fn deactivate(mut self) {
154+
self.drop.take();
155+
}
156+
}
122157

123-
// Flags state that walking is finished for the walker with the given value.
124-
fn stop_waking(&self) -> u8 {
125-
self.poll_state.stop_waking(self.need_to_poll << 3)
158+
impl<F: FnOnce(&SharedPollState)> Drop for PollStateBomb<'_, F> {
159+
fn drop(&mut self) {
160+
if let Some(drop) = self.drop.take() {
161+
(drop)(&self.state);
162+
}
126163
}
127164
}
128165

129166
impl ArcWake for InnerWaker {
130167
fn wake_by_ref(self_arc: &Arc<Self>) {
131-
let poll_state_value = self_arc.start_waking();
168+
let (poll_state_value, state_bomb) = self_arc.start_waking();
132169

133170
// Only call waker if stream isn't being polled because of safety reasons.
134171
// Waker will be called at the end of polling if state was changed.
@@ -137,14 +174,11 @@ impl ArcWake for InnerWaker {
137174
unsafe { self_arc.inner_waker.get().as_ref().cloned().flatten() }
138175
{
139176
// First, stop waking to allow polling stream
140-
self_arc.stop_waking();
177+
drop(state_bomb);
141178
// Wake inner waker
142179
inner_waker.wake();
143-
return;
144180
}
145181
}
146-
147-
self_arc.stop_waking();
148182
}
149183
}
150184

@@ -168,30 +202,27 @@ impl<St> PollStreamFut<St> {
168202
}
169203
}
170204

171-
impl<St: Stream> Future for PollStreamFut<St> {
205+
impl<St: Stream + Unpin> Future for PollStreamFut<St> {
172206
type Output = Option<(St::Item, PollStreamFut<St>)>;
173207

174208
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
175209
let mut stream = self.project().stream;
176-
177210
let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
178211
ready!(stream.poll_next(cx))
179212
} else {
180213
None
181214
};
215+
let out = item.map(|item| (item, PollStreamFut::new(stream.get_mut().take())));
182216

183-
Poll::Ready(
184-
item.map(|item| {
185-
(item, PollStreamFut::new(unsafe { stream.get_unchecked_mut().take() }))
186-
}),
187-
)
217+
Poll::Ready(out)
188218
}
189219
}
190220

191221
pin_project! {
192222
/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
193223
/// method.
194224
#[must_use = "streams do nothing unless polled"]
225+
#[project = FlattenUnorderedProj]
195226
pub struct FlattenUnordered<St, U> {
196227
#[pin]
197228
inner_streams: FuturesUnordered<PollStreamFut<U>>,
@@ -224,7 +255,7 @@ where
224255
impl<St> FlattenUnordered<St, St::Item>
225256
where
226257
St: Stream,
227-
St::Item: Stream,
258+
St::Item: Stream + Unpin,
228259
{
229260
pub(super) fn new(stream: St, limit: Option<usize>) -> FlattenUnordered<St, St::Item> {
230261
let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
@@ -248,18 +279,24 @@ where
248279
}
249280
}
250281

282+
delegate_access_inner!(stream, St, ());
283+
}
284+
285+
impl<St> FlattenUnorderedProj<'_, St, St::Item>
286+
where
287+
St: Stream,
288+
{
251289
/// Checks if current `inner_streams` size is less than optional limit.
252290
fn is_exceeded_limit(&self) -> bool {
253-
self.limit.map(|limit| self.inner_streams.len() >= limit.get()).unwrap_or(false)
291+
self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
254292
}
255-
256-
delegate_access_inner!(stream, St, ());
257293
}
258294

259295
impl<St> FusedStream for FlattenUnordered<St, St::Item>
260296
where
261297
St: FusedStream,
262-
St::Item: FusedStream,
298+
St::Item: FusedStream + Unpin,
299+
<St::Item as Stream>::Item: core::fmt::Debug,
263300
{
264301
fn is_terminated(&self) -> bool {
265302
self.stream.is_terminated() && self.inner_streams.is_empty()
@@ -269,18 +306,18 @@ where
269306
impl<St> Stream for FlattenUnordered<St, St::Item>
270307
where
271308
St: Stream,
272-
St::Item: Stream,
309+
St::Item: Stream + Unpin,
310+
<St::Item as Stream>::Item: core::fmt::Debug,
273311
{
274312
type Item = <St::Item as Stream>::Item;
275313

276314
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277315
let mut next_item = None;
278316
let mut need_to_poll_next = NONE;
279-
let limit_exceeded = self.is_exceeded_limit();
280317

281318
let mut this = self.as_mut().project();
282319

283-
let mut poll_state_value = match this.poll_state.start_polling() {
320+
let (mut poll_state_value, state_bomb) = match this.poll_state.start_polling() {
284321
Some(value) => value,
285322
_ => {
286323
// Waker was called, just wait for the next poll
@@ -289,44 +326,48 @@ where
289326
};
290327

291328
let mut polling_with_two_wakers =
292-
!limit_exceeded && poll_state_value & NEED_TO_POLL_ALL == NEED_TO_POLL_ALL;
329+
!this.is_exceeded_limit() && poll_state_value & NEED_TO_POLL_ALL == NEED_TO_POLL_ALL;
293330

294331
if poll_state_value & NEED_TO_POLL_STREAM != NONE {
295-
if !limit_exceeded && !*this.is_stream_done {
296-
match if polling_with_two_wakers {
297-
// Safety: now state is `POLLING`.
298-
let waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) };
299-
let mut cx = Context::from_waker(&waker);
300-
this.stream.as_mut().poll_next(&mut cx)
332+
loop {
333+
if this.is_exceeded_limit() || *this.is_stream_done {
334+
polling_with_two_wakers = false;
335+
need_to_poll_next |= NEED_TO_POLL_STREAM;
336+
337+
break;
301338
} else {
302-
this.stream.as_mut().poll_next(cx)
303-
} {
304-
Poll::Ready(Some(inner_stream)) => {
305-
this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream));
306-
need_to_poll_next |= NEED_TO_POLL_STREAM;
307-
// Polling inner streams in current iteration with the same context
308-
// is ok because we already received `Poll::Ready` from
309-
// stream
310-
poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
311-
polling_with_two_wakers = false;
312-
*this.is_stream_done = false;
313-
}
314-
Poll::Ready(None) => {
315-
// Polling inner streams in current iteration with the same context
316-
// is ok because we already received `Poll::Ready` from
317-
// stream
318-
polling_with_two_wakers = false;
319-
*this.is_stream_done = true;
320-
}
321-
Poll::Pending => {
322-
if !polling_with_two_wakers {
339+
match if polling_with_two_wakers {
340+
// Safety: now state is `POLLING`.
341+
let waker = unsafe { InnerWaker::replace_waker(this.stream_waker, cx) };
342+
let mut cx = Context::from_waker(&waker);
343+
this.stream.as_mut().poll_next(&mut cx)
344+
} else {
345+
this.stream.as_mut().poll_next(cx)
346+
} {
347+
Poll::Ready(Some(inner_stream)) => {
348+
this.inner_streams.as_mut().push(PollStreamFut::new(inner_stream));
323349
need_to_poll_next |= NEED_TO_POLL_STREAM;
350+
// Polling inner streams in current iteration with the same context
351+
// is ok because we already received `Poll::Ready` from
352+
// stream
353+
poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
354+
*this.is_stream_done = false;
355+
}
356+
Poll::Ready(None) => {
357+
// Polling inner streams in current iteration with the same context
358+
// is ok because we already received `Poll::Ready` from
359+
// stream
360+
*this.is_stream_done = true;
361+
}
362+
Poll::Pending => {
363+
if !polling_with_two_wakers {
364+
need_to_poll_next |= NEED_TO_POLL_STREAM;
365+
}
366+
*this.is_stream_done = false;
367+
break;
324368
}
325-
*this.is_stream_done = false;
326369
}
327370
}
328-
} else {
329-
need_to_poll_next |= NEED_TO_POLL_STREAM;
330371
}
331372
}
332373

@@ -345,7 +386,7 @@ where
345386
need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
346387
}
347388
Poll::Ready(Some(None)) => {
348-
need_to_poll_next |= NEED_TO_POLL_ALL;
389+
need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
349390
}
350391
Poll::Pending => {
351392
if !polling_with_two_wakers {
@@ -358,14 +399,15 @@ where
358399
}
359400
}
360401

402+
state_bomb.deactivate();
361403
poll_state_value = this.poll_state.stop_polling(need_to_poll_next);
362404
let is_done = *this.is_stream_done && this.inner_streams.is_empty();
363405

364406
if next_item.is_some() || is_done {
365407
Poll::Ready(next_item)
366408
} else {
367409
if poll_state_value & NEED_TO_POLL_ALL != NONE
368-
|| !self.is_exceeded_limit() && need_to_poll_next & NEED_TO_POLL_STREAM != NONE
410+
|| !this.is_exceeded_limit() && need_to_poll_next & NEED_TO_POLL_STREAM != NONE
369411
{
370412
cx.waker().wake_by_ref();
371413
}

futures-util/src/stream/stream/mod.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ delegate_all!(
210210
FlattenUnordered<St>(
211211
flatten_unordered::FlattenUnordered<St, St::Item>
212212
): Debug + Sink + Stream + FusedStream + AccessInner[St, (.)] + New[|x: St, limit: Option<usize>| flatten_unordered::FlattenUnordered::new(x, limit)]
213-
where St: Stream, St::Item: Stream
213+
where St: Stream, St::Item: Stream, St::Item: Unpin
214214
);
215215

216216
#[cfg(not(futures_no_atomic_cas))]
@@ -220,7 +220,7 @@ delegate_all!(
220220
FlatMapUnordered<St, U, F>(
221221
FlattenUnordered<Map<St, F>>
222222
): Debug + Sink + Stream + FusedStream + AccessInner[St, (. .)] + New[|x: St, limit: Option<usize>, f: F| FlattenUnordered::new(Map::new(x, f), limit)]
223-
where St: Stream, U: Stream, F: FnMut(St::Item) -> U
223+
where St: Stream, U: Stream, U: Unpin, F: FnMut(St::Item) -> U
224224
);
225225

226226
#[cfg(not(futures_no_atomic_cas))]
@@ -790,11 +790,11 @@ pub trait StreamExt: Stream {
790790
/// assert_eq!(output, vec![1, 2, 3, 4]);
791791
/// # });
792792
/// ```
793-
#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
793+
#[cfg(not(futures_no_atomic_cas))]
794794
#[cfg(feature = "alloc")]
795795
fn flatten_unordered(self, limit: impl Into<Option<usize>>) -> FlattenUnordered<Self>
796796
where
797-
Self::Item: Stream,
797+
Self::Item: Stream + Unpin,
798798
Self: Sized,
799799
{
800800
FlattenUnordered::new(self, limit.into())
@@ -871,15 +871,16 @@ pub trait StreamExt: Stream {
871871
///
872872
/// assert_eq!(vec![1usize, 2, 2, 3, 3, 3, 4, 4, 4, 4], values);
873873
/// # });
874-
#[cfg_attr(feature = "cfg-target-has-atomic", cfg(target_has_atomic = "ptr"))]
874+
/// ```
875+
#[cfg(not(futures_no_atomic_cas))]
875876
#[cfg(feature = "alloc")]
876877
fn flat_map_unordered<U, F>(
877878
self,
878879
limit: impl Into<Option<usize>>,
879880
f: F,
880881
) -> FlatMapUnordered<Self, U, F>
881882
where
882-
U: Stream,
883+
U: Stream + Unpin,
883884
F: FnMut(Self::Item) -> U,
884885
Self: Sized,
885886
{

0 commit comments

Comments
 (0)