From 3d7c0c92b90d6a4de97b459fc9ab4c3a307ed0b5 Mon Sep 17 00:00:00 2001 From: akarnokd Date: Wed, 6 Apr 2016 15:44:44 +0200 Subject: [PATCH] 1.x: fix switchMap/switchOnNext producer retention and backpressure --- .../rx/internal/operators/OperatorSwitch.java | 382 ++++++++++-------- .../util/atomic/SpscLinkedArrayQueue.java | 31 +- .../operators/OperatorSwitchTest.java | 91 ++++- 3 files changed, 325 insertions(+), 179 deletions(-) diff --git a/src/main/java/rx/internal/operators/OperatorSwitch.java b/src/main/java/rx/internal/operators/OperatorSwitch.java index 7d706f2a95..bbbcd9879d 100644 --- a/src/main/java/rx/internal/operators/OperatorSwitch.java +++ b/src/main/java/rx/internal/operators/OperatorSwitch.java @@ -16,14 +16,17 @@ package rx.internal.operators; import java.util.*; +import java.util.concurrent.atomic.AtomicLong; import rx.*; import rx.Observable; import rx.Observable.Operator; import rx.exceptions.CompositeException; -import rx.internal.producers.ProducerArbiter; +import rx.functions.Action0; +import rx.internal.util.RxRingBuffer; +import rx.internal.util.atomic.SpscLinkedArrayQueue; import rx.plugins.RxJavaPlugins; -import rx.subscriptions.SerialSubscription; +import rx.subscriptions.*; /** * Transforms an Observable that emits Observables into a single Observable that @@ -45,6 +48,9 @@ private static final class HolderDelayError { static final OperatorSwitch INSTANCE = new OperatorSwitch(true); } /** + * Returns a singleton instance of the operator based on the delayError parameter. + * @param the value type + * @param delayError should the errors of the inner sources delayed until the main sequence completes? * @return a singleton instance of this stateless operator. */ @SuppressWarnings({ "unchecked" }) @@ -72,51 +78,80 @@ public Subscriber> call(final Subscriber extends Subscriber> { final Subscriber child; final SerialSubscription ssub; - final ProducerArbiter arbiter; - final boolean delayError; + final AtomicLong index; + final SpscLinkedArrayQueue queue; + final NotificationLite nl; + + boolean emitting; - long index; + boolean missed; - Throwable error; + long requested; - boolean mainDone; + Producer producer; - List queue; + volatile boolean mainDone; + + Throwable error; boolean innerActive; - boolean emitting; - - boolean missed; + static final Throwable TERMINAL_ERROR = new Throwable("Terminal error"); SwitchSubscriber(Subscriber child, boolean delayError) { this.child = child; - this.arbiter = new ProducerArbiter(); this.ssub = new SerialSubscription(); this.delayError = delayError; + this.index = new AtomicLong(); + this.queue = new SpscLinkedArrayQueue(RxRingBuffer.SIZE); + this.nl = NotificationLite.instance(); } void init() { child.add(ssub); + child.add(Subscriptions.create(new Action0() { + @Override + public void call() { + clearProducer(); + } + })); child.setProducer(new Producer(){ @Override public void request(long n) { - if (n > 0) { - arbiter.request(n); + if (n > 0L) { + childRequested(n); + } else + if (n < 0L) { + throw new IllegalArgumentException("n >= 0 expected but it was " + n); } } }); } - + + void clearProducer() { + synchronized (this) { + producer = null; + } + } + @Override public void onNext(Observable t) { + long id = index.incrementAndGet(); + + Subscription s = ssub.get(); + if (s != null) { + s.unsubscribe(); + } + InnerSubscriber inner; + synchronized (this) { - long id = ++index; inner = new InnerSubscriber(id, this); + innerActive = true; + producer = null; } ssub.set(inner); @@ -125,201 +160,228 @@ public void onNext(Observable t) { @Override public void onError(Throwable e) { + boolean success; + synchronized (this) { - e = updateError(e); + success = updateError(e); + } + if (success) { mainDone = true; - - if (emitting) { - missed = true; - return; - } - if (delayError && innerActive) { - return; - } - emitting = true; + drain(); + } else { + pluginError(e); } - - child.onError(e); } + boolean updateError(Throwable next) { + Throwable e = error; + if (e == TERMINAL_ERROR) { + return false; + } else + if (e == null) { + error = next; + } else + if (e instanceof CompositeException) { + List list = new ArrayList(((CompositeException)e).getExceptions()); + list.add(next); + error = new CompositeException(list); + } else { + error = new CompositeException(e, next); + } + return true; + } + @Override public void onCompleted() { - Throwable ex; + mainDone = true; + drain(); + } + + void emit(T value, InnerSubscriber inner) { synchronized (this) { - mainDone = true; - if (emitting) { - missed = true; + if (index.get() != inner.id) { return; } - if (innerActive) { - return; + + queue.offer(inner, nl.next(value)); + } + drain(); + } + + void error(Throwable e, long id) { + boolean success; + synchronized (this) { + if (index.get() == id) { + success = updateError(e); + innerActive = false; + producer = null; + } else { + success = true; } - emitting = true; - ex = error; } - if (ex == null) { - child.onCompleted(); + if (success) { + drain(); } else { - child.onError(ex); + pluginError(e); } } - Throwable updateError(Throwable e) { - Throwable ex = error; - if (ex == null) { - error = e; - } else - if (ex instanceof CompositeException) { - CompositeException ce = (CompositeException) ex; - List list = new ArrayList(ce.getExceptions()); - list.add(e); - e = new CompositeException(list); - error = e; - } else { - e = new CompositeException(Arrays.asList(ex, e)); - error = e; + void complete(long id) { + synchronized (this) { + if (index.get() != id) { + return; + } + innerActive = false; + producer = null; } - return e; + drain(); + } + + void pluginError(Throwable e) { + RxJavaPlugins.getInstance().getErrorHandler().handleError(e); } - void emit(T value, long id) { + void innerProducer(Producer p, long id) { + long n; synchronized (this) { - if (id != index) { + if (index.get() != id) { return; } - + n = requested; + producer = p; + } + + p.request(n); + } + + void childRequested(long n) { + Producer p; + synchronized (this) { + p = producer; + requested = BackpressureUtils.addCap(requested, n); + } + if (p != null) { + p.request(n); + } + drain(); + } + + void drain() { + boolean localMainDone = mainDone; + boolean localInnerActive; + long localRequested; + Throwable localError; + synchronized (this) { if (emitting) { - List q = queue; - if (q == null) { - q = new ArrayList(4); - queue = q; - } - q.add(value); missed = true; return; } - emitting = true; + localInnerActive = innerActive; + localRequested = requested; + localError = error; + if (localError != null && localError != TERMINAL_ERROR && !delayError) { + error = TERMINAL_ERROR; + } } - - child.onNext(value); - - arbiter.produced(1); - + + final SpscLinkedArrayQueue localQueue = queue; + final AtomicLong localIndex = index; + final Subscriber localChild = child; + for (;;) { - if (child.isUnsubscribed()) { - return; - } - - Throwable localError; - boolean localMainDone; - boolean localActive; - List localQueue; - synchronized (this) { - if (!missed) { - emitting = false; + + long localEmission = 0L; + + while (localEmission != localRequested) { + if (localChild.isUnsubscribed()) { return; } + + boolean empty = localQueue.isEmpty(); - localError = error; - localMainDone = mainDone; - localQueue = queue; - localActive = innerActive; - } - - if (!delayError && localError != null) { - child.onError(localError); - return; - } - - if (localQueue == null && !localActive && localMainDone) { - if (localError != null) { - child.onError(localError); - } else { - child.onCompleted(); + if (checkTerminated(localMainDone, localInnerActive, localError, + localQueue, localChild, empty)) { + return; + } + + if (empty) { + break; + } + + @SuppressWarnings("unchecked") + InnerSubscriber inner = (InnerSubscriber)localQueue.poll(); + T value = nl.getValue(localQueue.poll()); + + if (localIndex.get() == inner.id) { + localChild.onNext(value); + localEmission++; } - return; } - if (localQueue != null) { - int n = 0; - for (T v : localQueue) { - if (child.isUnsubscribed()) { - return; - } - - child.onNext(v); - n++; + if (localEmission == localRequested) { + if (localChild.isUnsubscribed()) { + return; } - arbiter.produced(n); + if (checkTerminated(mainDone, localInnerActive, localError, localQueue, + localChild, localQueue.isEmpty())) { + return; + } } - } - } - - void error(Throwable e, long id) { - boolean drop; - synchronized (this) { - if (id == index) { - innerActive = false; - - e = updateError(e); + + + synchronized (this) { - if (emitting) { - missed = true; - return; + localRequested = requested; + if (localRequested != Long.MAX_VALUE) { + localRequested -= localEmission; + requested = localRequested; } - if (delayError && !mainDone) { + + if (!missed) { + emitting = false; return; } - emitting = true; + missed = false; - drop = false; - } else { - drop = true; + localMainDone = mainDone; + localInnerActive = innerActive; + localError = error; + if (localError != null && localError != TERMINAL_ERROR && !delayError) { + error = TERMINAL_ERROR; + } } } - - if (drop) { - pluginError(e); - } else { - child.onError(e); - } } - - void complete(long id) { - Throwable ex; - synchronized (this) { - if (id != index) { - return; - } - innerActive = false; - - if (emitting) { - missed = true; - return; - } - - ex = error; - if (!mainDone) { - return; + protected boolean checkTerminated(boolean localMainDone, boolean localInnerActive, Throwable localError, + final SpscLinkedArrayQueue localQueue, final Subscriber localChild, boolean empty) { + if (delayError) { + if (localMainDone && !localInnerActive && empty) { + if (localError != null) { + localChild.onError(localError); + } else { + localChild.onCompleted(); + } + return true; } - } - - if (ex != null) { - child.onError(ex); } else { - child.onCompleted(); + if (localError != null) { + localQueue.clear(); + localChild.onError(localError); + return true; + } else + if (localMainDone && !localInnerActive && empty) { + localChild.onCompleted(); + return true; + } } - } - - void pluginError(Throwable e) { - RxJavaPlugins.getInstance().getErrorHandler().handleError(e); + return false; } } - private static final class InnerSubscriber extends Subscriber { + static final class InnerSubscriber extends Subscriber { private final long id; @@ -332,12 +394,12 @@ private static final class InnerSubscriber extends Subscriber { @Override public void setProducer(Producer p) { - parent.arbiter.setProducer(p); + parent.innerProducer(p, id); } @Override public void onNext(T t) { - parent.emit(t, id); + parent.emit(t, this); } @Override diff --git a/src/main/java/rx/internal/util/atomic/SpscLinkedArrayQueue.java b/src/main/java/rx/internal/util/atomic/SpscLinkedArrayQueue.java index 33472a40da..23d8ad7c9e 100644 --- a/src/main/java/rx/internal/util/atomic/SpscLinkedArrayQueue.java +++ b/src/main/java/rx/internal/util/atomic/SpscLinkedArrayQueue.java @@ -30,23 +30,19 @@ /** * A single-producer single-consumer array-backed queue which can allocate new arrays in case the consumer is slower * than the producer. + * + * @param the element type, not null */ public final class SpscLinkedArrayQueue implements Queue { static final int MAX_LOOK_AHEAD_STEP = Integer.getInteger("jctools.spsc.max.lookahead.step", 4096); - protected volatile long producerIndex; - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater PRODUCER_INDEX = - AtomicLongFieldUpdater.newUpdater(SpscLinkedArrayQueue.class, "producerIndex"); + protected final AtomicLong producerIndex; protected int producerLookAheadStep; protected long producerLookAhead; protected int producerMask; protected AtomicReferenceArray producerBuffer; protected int consumerMask; protected AtomicReferenceArray consumerBuffer; - protected volatile long consumerIndex; - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater CONSUMER_INDEX = - AtomicLongFieldUpdater.newUpdater(SpscLinkedArrayQueue.class, "consumerIndex"); + protected final AtomicLong consumerIndex; private static final Object HAS_NEXT = new Object(); public SpscLinkedArrayQueue(final int bufferSize) { @@ -59,7 +55,8 @@ public SpscLinkedArrayQueue(final int bufferSize) { consumerBuffer = buffer; consumerMask = mask; producerLookAhead = mask - 1; // we know it's all empty to start with - soProducerIndex(0L); + producerIndex = new AtomicLong(); + consumerIndex = new AtomicLong(); } /** @@ -219,27 +216,27 @@ private void adjustLookAheadStep(int capacity) { } private long lvProducerIndex() { - return producerIndex; + return producerIndex.get(); } private long lvConsumerIndex() { - return consumerIndex; + return consumerIndex.get(); } private long lpProducerIndex() { - return producerIndex; + return producerIndex.get(); } private long lpConsumerIndex() { - return consumerIndex; + return consumerIndex.get(); } private void soProducerIndex(long v) { - PRODUCER_INDEX.lazySet(this, v); + producerIndex.lazySet(v); } private void soConsumerIndex(long v) { - CONSUMER_INDEX.lazySet(this, v); + consumerIndex.lazySet(v); } private static int calcWrappedOffset(long index, int mask) { @@ -321,11 +318,11 @@ public T element() { *

Don't use the regular offer() with this at all! * @param first * @param second - * @return + * @return always true */ public boolean offer(T first, T second) { final AtomicReferenceArray buffer = producerBuffer; - final long p = producerIndex; + final long p = lvProducerIndex(); final int m = producerMask; int pi = calcWrappedOffset(p + 2, m); diff --git a/src/test/java/rx/internal/operators/OperatorSwitchTest.java b/src/test/java/rx/internal/operators/OperatorSwitchTest.java index 55170ab9ff..b673b56949 100644 --- a/src/test/java/rx/internal/operators/OperatorSwitchTest.java +++ b/src/test/java/rx/internal/operators/OperatorSwitchTest.java @@ -19,6 +19,7 @@ import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; +import java.lang.ref.WeakReference; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; @@ -32,7 +33,7 @@ import rx.exceptions.*; import rx.functions.*; import rx.observers.TestSubscriber; -import rx.schedulers.TestScheduler; +import rx.schedulers.*; import rx.subjects.PublishSubject; public class OperatorSwitchTest { @@ -654,7 +655,7 @@ public Observable call(Long t) { ts.requestMore(Long.MAX_VALUE - 1); ts.awaitTerminalEvent(); assertTrue(ts.getOnNextEvents().size() > 0); - assertEquals(5, requests.size()); + assertEquals(4, requests.size()); // depends on the request pattern assertEquals(Long.MAX_VALUE, (long) requests.get(requests.size()-1)); } @@ -790,4 +791,90 @@ public Observable call(Integer v) { ts.assertCompleted(); } + Object ref; + + @Test + public void producerIsNotRetained() throws Exception { + ref = new Object(); + + WeakReference wr = new WeakReference(ref); + + PublishSubject> ps = PublishSubject.create(); + + Subscriber observer = new Subscriber() { + @Override + public void onCompleted() { + } + + @Override + public void onError(Throwable e) { + } + + @Override + public void onNext(Object t) { + } + }; + + Observable.switchOnNext(ps).subscribe(observer); + + ps.onNext(Observable.just(ref)); + + ref = null; + + System.gc(); + + Thread.sleep(500); + + Assert.assertNotNull(observer); // retain every other referenec in the pipeline + Assert.assertNotNull(ps); + Assert.assertNull("Object retained!", wr.get()); + } + + @Test + public void switchAsyncHeavily() { + for (int i = 1; i < 1024; i *= 2) { + System.out.println("switchAsyncHeavily >> " + i); + + final Queue q = new ConcurrentLinkedQueue(); + + final int j = i; + TestSubscriber ts = new TestSubscriber(i) { + int count; + @Override + public void onNext(Integer t) { + super.onNext(t); + if (++count == j) { + count = 0; + requestMore(j); + } + } + }; + + Observable.range(1, 25000) + .observeOn(Schedulers.computation(), i) + .switchMap(new Func1>() { + @Override + public Observable call(Integer v) { + return Observable.range(1, 1000).observeOn(Schedulers.computation(), j) + .doOnError(new Action1() { + @Override + public void call(Throwable e) { + q.add(e); + } + }); + } + }) + .timeout(10, TimeUnit.SECONDS) + .subscribe(ts); + + ts.awaitTerminalEvent(30, TimeUnit.SECONDS); + if (!q.isEmpty()) { + throw new AssertionError("Dropped exceptions", new CompositeException(q)); + } + ts.assertNoErrors(); + if (ts.getOnCompletedEvents().size() == 0) { + fail("switchAsyncHeavily timed out @ " + j + " (" + ts.getOnNextEvents().size() + " onNexts received)"); + } + } + } }