diff --git a/src/main/java/rx/Observable.java b/src/main/java/rx/Observable.java index 92e9ea56fa..041b497cdd 100644 --- a/src/main/java/rx/Observable.java +++ b/src/main/java/rx/Observable.java @@ -4061,7 +4061,7 @@ public final Observable concatMapDelayError(Func1ReactiveX operators documentation: FlatMap */ public final Observable concatMapIterable(Func1> collectionSelector) { - return concat(map(OperatorMapPair.convertSelector(collectionSelector))); + return OnSubscribeFlattenIterable.createFrom(this, collectionSelector, RxRingBuffer.SIZE); } /** @@ -5672,7 +5672,7 @@ public final Observable flatMap(final Func1ReactiveX operators documentation: FlatMap */ public final Observable flatMapIterable(Func1> collectionSelector) { - return merge(map(OperatorMapPair.convertSelector(collectionSelector))); + return flatMapIterable(collectionSelector, RxRingBuffer.SIZE); } /** @@ -5702,7 +5702,7 @@ public final Observable flatMapIterable(Func1 Observable flatMapIterable(Func1> collectionSelector, int maxConcurrent) { - return merge(map(OperatorMapPair.convertSelector(collectionSelector)), maxConcurrent); + return OnSubscribeFlattenIterable.createFrom(this, collectionSelector, maxConcurrent); } /** diff --git a/src/main/java/rx/internal/operators/OnSubscribeFlattenIterable.java b/src/main/java/rx/internal/operators/OnSubscribeFlattenIterable.java new file mode 100644 index 0000000000..fb3b9f97bd --- /dev/null +++ b/src/main/java/rx/internal/operators/OnSubscribeFlattenIterable.java @@ -0,0 +1,357 @@ +/** + * Copyright 2016 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rx.internal.operators; + +import java.util.*; +import java.util.concurrent.atomic.*; + +import rx.*; +import rx.Observable; +import rx.Observable.OnSubscribe; +import rx.exceptions.*; +import rx.functions.Func1; +import rx.internal.util.*; +import rx.internal.util.atomic.*; +import rx.internal.util.unsafe.*; + +/** + * Flattens a sequence if Iterable sources, generated via a function, into a single sequence. + * + * @param the input value type + * @param the output value type + */ +public final class OnSubscribeFlattenIterable implements OnSubscribe { + + final Observable source; + + final Func1> mapper; + + final int prefetch; + + /** Protected: use createFrom to handle source-dependent optimizations. */ + protected OnSubscribeFlattenIterable(Observable source, + Func1> mapper, int prefetch) { + this.source = source; + this.mapper = mapper; + this.prefetch = prefetch; + } + + @Override + public void call(Subscriber t) { + final FlattenIterableSubscriber parent = new FlattenIterableSubscriber(t, mapper, prefetch); + + t.add(parent); + t.setProducer(new Producer() { + @Override + public void request(long n) { + parent.requestMore(n); + } + }); + + source.unsafeSubscribe(parent); + } + + public static Observable createFrom(Observable source, + Func1> mapper, int prefetch) { + if (source instanceof ScalarSynchronousObservable) { + T scalar = ((ScalarSynchronousObservable) source).get(); + return Observable.create(new OnSubscribeScalarFlattenIterable(scalar, mapper)); + } + return Observable.create(new OnSubscribeFlattenIterable(source, mapper, prefetch)); + } + + static final class FlattenIterableSubscriber extends Subscriber { + final Subscriber actual; + + final Func1> mapper; + + final long limit; + + final Queue queue; + + final AtomicReference error; + + final AtomicLong requested; + + final AtomicInteger wip; + + final NotificationLite nl; + + volatile boolean done; + + long produced; + + Iterator active; + + public FlattenIterableSubscriber(Subscriber actual, + Func1> mapper, int prefetch) { + this.actual = actual; + this.mapper = mapper; + this.error = new AtomicReference(); + this.wip = new AtomicInteger(); + this.requested = new AtomicLong(); + this.nl = NotificationLite.instance(); + if (prefetch == Integer.MAX_VALUE) { + this.limit = Long.MAX_VALUE; + this.queue = new SpscLinkedArrayQueue(RxRingBuffer.SIZE); + } else { + // limit = prefetch * 75% rounded up + this.limit = prefetch - (prefetch >> 2); + if (UnsafeAccess.isUnsafeAvailable()) { + this.queue = new SpscArrayQueue(prefetch); + } else { + this.queue = new SpscAtomicArrayQueue(prefetch); + } + } + request(prefetch); + } + + @Override + public void onNext(T t) { + if (!queue.offer(nl.next(t))) { + unsubscribe(); + onError(new MissingBackpressureException()); + return; + } + drain(); + } + + @Override + public void onError(Throwable e) { + if (ExceptionsUtils.addThrowable(error, e)) { + done = true; + drain(); + } else { + RxJavaPluginUtils.handleException(e); + } + } + + @Override + public void onCompleted() { + done = true; + drain(); + } + + void requestMore(long n) { + if (n > 0) { + BackpressureUtils.getAndAddRequest(requested, n); + drain(); + } else if (n < 0) { + throw new IllegalStateException("n >= 0 required but it was " + n); + } + } + + void drain() { + if (wip.getAndIncrement() != 0) { + return; + } + + final Subscriber actual = this.actual; + final Queue queue = this.queue; + + int missed = 1; + + for (;;) { + + Iterator it = active; + + if (it == null) { + boolean d = done; + + Object v = queue.poll(); + + boolean empty = v == null; + + if (checkTerminated(d, empty, actual, queue)) { + return; + } + + if (!empty) { + + long p = produced + 1; + if (p == limit) { + produced = 0L; + request(p); + } else { + produced = p; + } + + boolean b; + + try { + Iterable iter = mapper.call(nl.getValue(v)); + + it = iter.iterator(); + + b = it.hasNext(); + } catch (Throwable ex) { + Exceptions.throwIfFatal(ex); + + it = null; + onError(ex); + + continue; + } + + if (!b) { + continue; + } + + active = it; + } + } + + if (it != null) { + long r = requested.get(); + long e = 0L; + + while (e != r) { + if (checkTerminated(done, false, actual, queue)) { + return; + } + + R v; + + try { + v = it.next(); + } catch (Throwable ex) { + Exceptions.throwIfFatal(ex); + it = null; + active = null; + onError(ex); + break; + } + + actual.onNext(v); + + if (checkTerminated(done, false, actual, queue)) { + return; + } + + e++; + + boolean b; + + try { + b = it.hasNext(); + } catch (Throwable ex) { + Exceptions.throwIfFatal(ex); + it = null; + active = null; + onError(ex); + break; + } + + if (!b) { + it = null; + active = null; + break; + } + } + + if (e == r) { + if (checkTerminated(done, queue.isEmpty() && it == null, actual, queue)) { + return; + } + } + + if (e != 0L) { + BackpressureUtils.produced(requested, e); + } + + if (it == null) { + continue; + } + } + + missed = wip.addAndGet(-missed); + if (missed == 0) { + break; + } + } + } + + boolean checkTerminated(boolean d, boolean empty, Subscriber a, Queue q) { + if (a.isUnsubscribed()) { + q.clear(); + active = null; + return true; + } + + if (d) { + Throwable ex = error.get(); + if (ex != null) { + ex = ExceptionsUtils.terminate(error); + unsubscribe(); + q.clear(); + active = null; + + a.onError(ex); + return true; + } else + if (empty) { + + a.onCompleted(); + return true; + } + } + + return false; + } + } + + /** + * A custom flattener that works from a scalar value and computes the iterable + * during subscription time. + * + * @param the scalar's value type + * @param the result value type + */ + static final class OnSubscribeScalarFlattenIterable implements OnSubscribe { + final T value; + + final Func1> mapper; + + public OnSubscribeScalarFlattenIterable(T value, Func1> mapper) { + this.value = value; + this.mapper = mapper; + } + + @Override + public void call(Subscriber t) { + Iterator itor; + boolean b; + try { + Iterable it = mapper.call(value); + + itor = it.iterator(); + + b = itor.hasNext(); + } catch (Throwable ex) { + Exceptions.throwOrReport(ex, t, value); + return; + } + + if (!b) { + t.onCompleted(); + return; + } + + t.setProducer(new OnSubscribeFromIterable.IterableProducer(t, itor)); + } + } +} diff --git a/src/main/java/rx/internal/operators/OnSubscribeFromIterable.java b/src/main/java/rx/internal/operators/OnSubscribeFromIterable.java index b94e35c35c..df91b7e67e 100644 --- a/src/main/java/rx/internal/operators/OnSubscribeFromIterable.java +++ b/src/main/java/rx/internal/operators/OnSubscribeFromIterable.java @@ -49,7 +49,7 @@ public void call(final Subscriber o) { o.setProducer(new IterableProducer(o, it)); } - private static final class IterableProducer extends AtomicLong implements Producer { + static final class IterableProducer extends AtomicLong implements Producer { /** */ private static final long serialVersionUID = -8730475647105475802L; private final Subscriber o; diff --git a/src/test/java/rx/internal/operators/OnSubscribeFlattenIterableTest.java b/src/test/java/rx/internal/operators/OnSubscribeFlattenIterableTest.java new file mode 100644 index 0000000000..7158456a35 --- /dev/null +++ b/src/test/java/rx/internal/operators/OnSubscribeFlattenIterableTest.java @@ -0,0 +1,477 @@ +/** + * Copyright 2016 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rx.internal.operators; + +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.*; + +import rx.Observable; +import rx.exceptions.TestException; +import rx.functions.Func1; +import rx.observers.TestSubscriber; +import rx.subjects.PublishSubject; + +public class OnSubscribeFlattenIterableTest { + + final Func1> mapper = new Func1>() { + @Override + public Iterable call(Integer v) { + return Arrays.asList(v, v + 1); + } + }; + + @Test + public void normal() { + TestSubscriber ts = new TestSubscriber(); + + Observable.range(1, 5).concatMapIterable(mapper) + .subscribe(ts); + + ts.assertValues(1, 2, 2, 3, 3, 4, 4, 5, 5, 6); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void normalBackpressured() { + TestSubscriber ts = new TestSubscriber(0); + + Observable.range(1, 5).concatMapIterable(mapper) + .subscribe(ts); + + ts.assertNoValues(); + ts.assertNoErrors(); + ts.assertNotCompleted(); + + ts.requestMore(1); + + ts.assertValue(1); + ts.assertNoErrors(); + ts.assertNotCompleted(); + + ts.requestMore(2); + + ts.assertValues(1, 2, 2); + ts.assertNoErrors(); + ts.assertNotCompleted(); + + ts.requestMore(7); + + ts.assertValues(1, 2, 2, 3, 3, 4, 4, 5, 5, 6); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void longRunning() { + TestSubscriber ts = new TestSubscriber(); + + int n = 1000 * 1000; + + Observable.range(1, n).concatMapIterable(mapper) + .subscribe(ts); + + ts.assertValueCount(n * 2); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void asIntermediate() { + TestSubscriber ts = new TestSubscriber(); + + int n = 1000 * 1000; + + Observable.range(1, n).concatMapIterable(mapper).concatMap(new Func1>() { + @Override + public Observable call(Integer v) { + return Observable.just(v); + } + }) + .subscribe(ts); + + ts.assertValueCount(n * 2); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void just() { + TestSubscriber ts = new TestSubscriber(); + + Observable.just(1).concatMapIterable(mapper) + .subscribe(ts); + + ts.assertValues(1, 2); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void justHidden() { + TestSubscriber ts = new TestSubscriber(); + + Observable.just(1).asObservable().concatMapIterable(mapper) + .subscribe(ts); + + ts.assertValues(1, 2); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void empty() { + TestSubscriber ts = new TestSubscriber(); + + Observable.empty().concatMapIterable(mapper) + .subscribe(ts); + + ts.assertNoValues(); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void error() { + TestSubscriber ts = new TestSubscriber(); + + Observable.just(1).concatWith(Observable.error(new TestException())) + .concatMapIterable(mapper) + .subscribe(ts); + + ts.assertValues(1, 2); + ts.assertError(TestException.class); + ts.assertNotCompleted(); + } + + @Test + public void iteratorHasNextThrowsImmediately() { + TestSubscriber ts = new TestSubscriber(); + + final Iterable it = new Iterable() { + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + throw new TestException(); + } + + @Override + public Integer next() { + return 1; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }; + + Observable.range(1, 2) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return it; + } + }) + .subscribe(ts); + + ts.assertNoValues(); + ts.assertError(TestException.class); + ts.assertNotCompleted(); + } + + @Test + public void iteratorHasNextThrowsImmediatelyJust() { + TestSubscriber ts = new TestSubscriber(); + + final Iterable it = new Iterable() { + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + throw new TestException(); + } + + @Override + public Integer next() { + return 1; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }; + + Observable.just(1) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return it; + } + }) + .subscribe(ts); + + ts.assertNoValues(); + ts.assertError(TestException.class); + ts.assertNotCompleted(); + } + + @Test + public void iteratorHasNextThrowsSecondCall() { + TestSubscriber ts = new TestSubscriber(); + + final Iterable it = new Iterable() { + @Override + public Iterator iterator() { + return new Iterator() { + int count; + @Override + public boolean hasNext() { + if (++count >= 2) { + throw new TestException(); + } + return true; + } + + @Override + public Integer next() { + return 1; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }; + + Observable.range(1, 2) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return it; + } + }) + .subscribe(ts); + + ts.assertValue(1); + ts.assertError(TestException.class); + ts.assertNotCompleted(); + } + + @Test + public void iteratorNextThrows() { + TestSubscriber ts = new TestSubscriber(); + + final Iterable it = new Iterable() { + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + return true; + } + + @Override + public Integer next() { + throw new TestException(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }; + + Observable.range(1, 2) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return it; + } + }) + .subscribe(ts); + + ts.assertNoValues(); + ts.assertError(TestException.class); + ts.assertNotCompleted(); + } + + @Test + public void iteratorNextThrowsAndUnsubscribes() { + TestSubscriber ts = new TestSubscriber(); + + final Iterable it = new Iterable() { + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + return true; + } + + @Override + public Integer next() { + throw new TestException(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }; + + PublishSubject ps = PublishSubject.create(); + + ps + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return it; + } + }) + .unsafeSubscribe(ts); + + ps.onNext(1); + + ts.assertNoValues(); + ts.assertError(TestException.class); + ts.assertNotCompleted(); + + Assert.assertFalse("PublishSubject has Observers?!", ps.hasObservers()); + } + + @Test + public void mixture() { + TestSubscriber ts = new TestSubscriber(); + + Observable.range(0, 1000) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return (v % 2) == 0 ? Collections.singleton(1) : Collections.emptySet(); + } + }) + .subscribe(ts); + + ts.assertValueCount(500); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void emptyInnerThenSingleBackpressured() { + TestSubscriber ts = new TestSubscriber(1); + + Observable.range(1, 2) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return v == 2 ? Collections.singleton(1) : Collections.emptySet(); + } + }) + .subscribe(ts); + + ts.assertValue(1); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void manyEmptyInnerThenSingleBackpressured() { + TestSubscriber ts = new TestSubscriber(1); + + Observable.range(1, 1000) + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return v == 1000 ? Collections.singleton(1) : Collections.emptySet(); + } + }) + .subscribe(ts); + + ts.assertValue(1); + ts.assertNoErrors(); + ts.assertCompleted(); + } + + @Test + public void hasNextIsNotCalledAfterChildUnsubscribedOnNext() { + TestSubscriber ts = new TestSubscriber(); + + final AtomicInteger counter = new AtomicInteger(); + + final Iterable it = new Iterable() { + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + counter.getAndIncrement(); + return true; + } + + @Override + public Integer next() { + return 1; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }; + + PublishSubject ps = PublishSubject.create(); + + ps + .concatMapIterable(new Func1>() { + @Override + public Iterable call(Integer v) { + return it; + } + }) + .take(1) + .unsafeSubscribe(ts); + + ps.onNext(1); + + ts.assertValue(1); + ts.assertNoErrors(); + ts.assertCompleted(); + + Assert.assertFalse("PublishSubject has Observers?!", ps.hasObservers()); + Assert.assertEquals(1, counter.get()); + } +}