diff --git a/src/main/java/rx/internal/operators/OperatorSwitch.java b/src/main/java/rx/internal/operators/OperatorSwitch.java index eae4d3aa67..afd35e477d 100644 --- a/src/main/java/rx/internal/operators/OperatorSwitch.java +++ b/src/main/java/rx/internal/operators/OperatorSwitch.java @@ -94,15 +94,26 @@ public void request(long n) { synchronized (guard) { localSubscriber = currentSubscriber; if (currentSubscriber == null) { - initialRequested = n; + long r = initialRequested + n; + if (r < 0) { + infinite = true; + } else { + initialRequested = r; + } } else { - // If n == Long.MAX_VALUE, infinite will become true. Then currentSubscriber.requested won't be used. - // Therefore we don't need to worry about overflow. - currentSubscriber.requested += n; + long r = currentSubscriber.requested + n; + if (r < 0) { + infinite = true; + } else { + currentSubscriber.requested = r; + } } } if (localSubscriber != null) { - localSubscriber.requestMore(n); + if (infinite) + localSubscriber.requestMore(Long.MAX_VALUE); + else + localSubscriber.requestMore(n); } } }); @@ -167,7 +178,8 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) { if (queue == null) { queue = new ArrayList(); } - innerSubscriber.requested--; + if (innerSubscriber.requested != Long.MAX_VALUE) + innerSubscriber.requested--; queue.add(value); return; } @@ -183,7 +195,8 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) { if (once) { once = false; synchronized (guard) { - innerSubscriber.requested--; + if (innerSubscriber.requested != Long.MAX_VALUE) + innerSubscriber.requested--; } s.onNext(value); } diff --git a/src/test/java/rx/internal/operators/OperatorSwitchTest.java b/src/test/java/rx/internal/operators/OperatorSwitchTest.java index 0efc388db1..6b5d3a1f79 100644 --- a/src/test/java/rx/internal/operators/OperatorSwitchTest.java +++ b/src/test/java/rx/internal/operators/OperatorSwitchTest.java @@ -15,12 +15,20 @@ */ package rx.internal.operators; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -36,6 +44,7 @@ import rx.Subscriber; import rx.exceptions.TestException; import rx.functions.Action0; +import rx.functions.Action1; import rx.functions.Func1; import rx.observers.TestSubscriber; import rx.schedulers.TestScheduler; @@ -574,4 +583,91 @@ public void onNext(String t) { Assert.assertEquals(250, ts.getOnNextEvents().size()); } + + @Test(timeout = 10000) + public void testInitialRequestsAreAdditive() { + TestSubscriber ts = new TestSubscriber(0); + Observable.switchOnNext( + Observable.interval(100, TimeUnit.MILLISECONDS) + .map( + new Func1>() { + @Override + public Observable call(Long t) { + return Observable.just(1L, 2L, 3L); + } + } + ).take(3)) + .subscribe(ts); + ts.requestMore(Long.MAX_VALUE - 100); + ts.requestMore(1); + ts.awaitTerminalEvent(); + } + + @Test(timeout = 10000) + public void testInitialRequestsDontOverflow() { + TestSubscriber ts = new TestSubscriber(0); + Observable.switchOnNext( + Observable.interval(100, TimeUnit.MILLISECONDS) + .map(new Func1>() { + @Override + public Observable call(Long t) { + return Observable.from(Arrays.asList(1L, 2L, 3L)); + } + }).take(3)).subscribe(ts); + ts.requestMore(Long.MAX_VALUE - 1); + ts.requestMore(2); + ts.awaitTerminalEvent(); + assertTrue(ts.getOnNextEvents().size() > 0); + } + + + @Test(timeout = 10000) + public void testSecondaryRequestsDontOverflow() throws InterruptedException { + TestSubscriber ts = new TestSubscriber(0); + Observable.switchOnNext( + Observable.interval(100, TimeUnit.MILLISECONDS) + .map(new Func1>() { + @Override + public Observable call(Long t) { + return Observable.from(Arrays.asList(1L, 2L, 3L)); + } + }).take(3)).subscribe(ts); + ts.requestMore(1); + //we will miss two of the first observable + Thread.sleep(250); + ts.requestMore(Long.MAX_VALUE - 1); + ts.requestMore(Long.MAX_VALUE - 1); + ts.awaitTerminalEvent(); + ts.assertValueCount(7); + } + + @Test(timeout = 10000) + public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream() throws InterruptedException { + final List requests = new CopyOnWriteArrayList(); + final Action1 addRequest = new Action1() { + + @Override + public void call(Long n) { + requests.add(n); + }}; + TestSubscriber ts = new TestSubscriber(0); + Observable.switchOnNext( + Observable.interval(100, TimeUnit.MILLISECONDS) + .map(new Func1>() { + @Override + public Observable call(Long t) { + return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest(addRequest); + } + }).take(3)).subscribe(ts); + ts.requestMore(1); + //we will miss two of the first observable + Thread.sleep(250); + ts.requestMore(Long.MAX_VALUE - 1); + ts.requestMore(Long.MAX_VALUE - 1); + ts.awaitTerminalEvent(); + assertTrue(ts.getOnNextEvents().size() > 0); + assertEquals(5, (int) requests.size()); + assertEquals(Long.MAX_VALUE, (long) requests.get(3)); + assertEquals(Long.MAX_VALUE, (long) requests.get(4)); + } }