diff --git a/src/main/java/rx/internal/operators/OnSubscribeRefCount.java b/src/main/java/rx/internal/operators/OnSubscribeRefCount.java index c5898ac1b6..4a34663fbb 100644 --- a/src/main/java/rx/internal/operators/OnSubscribeRefCount.java +++ b/src/main/java/rx/internal/operators/OnSubscribeRefCount.java @@ -129,7 +129,13 @@ void cleanup() { // and set the subscriptionCount to 0 lock.lock(); try { + if (baseSubscription == currentBase) { + // backdoor into the ConnectableObservable to cleanup and reset its state + if (source instanceof Subscription) { + ((Subscription)source).unsubscribe(); + } + baseSubscription.unsubscribe(); baseSubscription = new CompositeSubscription(); subscriptionCount.set(0); @@ -148,7 +154,13 @@ public void call() { lock.lock(); try { if (baseSubscription == current) { + if (subscriptionCount.decrementAndGet() == 0) { + // backdoor into the ConnectableObservable to cleanup and reset its state + if (source instanceof Subscription) { + ((Subscription)source).unsubscribe(); + } + baseSubscription.unsubscribe(); // need a new baseSubscription because once // unsubscribed stays that way diff --git a/src/main/java/rx/internal/operators/OperatorReplay.java b/src/main/java/rx/internal/operators/OperatorReplay.java index a89ac32bd7..6447b2c737 100644 --- a/src/main/java/rx/internal/operators/OperatorReplay.java +++ b/src/main/java/rx/internal/operators/OperatorReplay.java @@ -28,7 +28,7 @@ import rx.schedulers.Timestamped; import rx.subscriptions.Subscriptions; -public final class OperatorReplay extends ConnectableObservable { +public final class OperatorReplay extends ConnectableObservable implements Subscription { /** The source observable. */ final Observable source; /** Holds the current subscriber that is, will be or just was subscribed to the source observable. */ @@ -254,6 +254,17 @@ private OperatorReplay(OnSubscribe onSubscribe, Observable sourc this.bufferFactory = bufferFactory; } + @Override + public void unsubscribe() { + current.lazySet(null); + } + + @Override + public boolean isUnsubscribed() { + ReplaySubscriber ps = current.get(); + return ps == null || ps.isUnsubscribed(); + } + @Override public void connect(Action1 connection) { boolean doConnect; diff --git a/src/test/java/rx/internal/operators/OnSubscribeRefCountTest.java b/src/test/java/rx/internal/operators/OnSubscribeRefCountTest.java index a64f808e78..a824f3191b 100644 --- a/src/test/java/rx/internal/operators/OnSubscribeRefCountTest.java +++ b/src/test/java/rx/internal/operators/OnSubscribeRefCountTest.java @@ -19,6 +19,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; +import java.lang.management.ManagementFactory; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.*; @@ -31,6 +32,7 @@ import rx.Observable.OnSubscribe; import rx.Observer; import rx.functions.*; +import rx.observables.ConnectableObservable; import rx.observers.*; import rx.schedulers.*; import rx.subjects.ReplaySubject; @@ -611,4 +613,155 @@ public void call(Throwable t) { assertNotNull("First subscriber didn't get the error", err1); assertNotNull("Second subscriber didn't get the error", err2); } + + Observable source; + + @Test + public void replayNoLeak() throws Exception { + System.gc(); + Thread.sleep(100); + + long start = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = Observable.fromCallable(new Callable() { + @Override + public Object call() throws Exception { + return new byte[100 * 1000 * 1000]; + } + }) + .replay(1) + .refCount(); + + source.subscribe(); + + System.gc(); + Thread.sleep(100); + + long after = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = null; + assertTrue(String.format("%,3d -> %,3d%n", start, after), start + 20 * 1000 * 1000 > after); + } + + @Test + public void replayNoLeak2() throws Exception { + System.gc(); + Thread.sleep(100); + + long start = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = Observable.fromCallable(new Callable() { + @Override + public Object call() throws Exception { + return new byte[100 * 1000 * 1000]; + } + }).concatWith(Observable.never()) + .replay(1) + .refCount(); + + Subscription s1 = source.subscribe(); + Subscription s2 = source.subscribe(); + + s1.unsubscribe(); + s2.unsubscribe(); + + s1 = null; + s2 = null; + + System.gc(); + Thread.sleep(100); + + long after = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = null; + assertTrue(String.format("%,3d -> %,3d%n", start, after), start + 20 * 1000 * 1000 > after); + } + + static final class ExceptionData extends Exception { + private static final long serialVersionUID = -6763898015338136119L; + + public final Object data; + + public ExceptionData(Object data) { + this.data = data; + } + } + + @Test + public void publishNoLeak() throws Exception { + System.gc(); + Thread.sleep(100); + + long start = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = Observable.fromCallable(new Callable() { + @Override + public Object call() throws Exception { + throw new ExceptionData(new byte[100 * 1000 * 1000]); + } + }) + .publish() + .refCount(); + + Action1 err = Actions.empty(); + source.subscribe(Actions.empty(), err); + + System.gc(); + Thread.sleep(100); + + long after = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = null; + assertTrue(String.format("%,3d -> %,3d%n", start, after), start + 20 * 1000 * 1000 > after); + } + + @Test + public void publishNoLeak2() throws Exception { + System.gc(); + Thread.sleep(100); + + long start = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = Observable.fromCallable(new Callable() { + @Override + public Object call() throws Exception { + return new byte[100 * 1000 * 1000]; + } + }).concatWith(Observable.never()) + .publish() + .refCount(); + + Subscription s1 = source.test(0); + Subscription s2 = source.test(0); + + s1.unsubscribe(); + s2.unsubscribe(); + + s1 = null; + s2 = null; + + System.gc(); + Thread.sleep(100); + + long after = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getUsed(); + + source = null; + assertTrue(String.format("%,3d -> %,3d%n", start, after), start + 20 * 1000 * 1000 > after); + } + + @Test + public void replayIsUnsubscribed() { + ConnectableObservable co = Observable.just(1) + .replay(); + + assertTrue(((Subscription)co).isUnsubscribed()); + + Subscription s = co.connect(); + + assertFalse(((Subscription)co).isUnsubscribed()); + + s.unsubscribe(); + + assertTrue(((Subscription)co).isUnsubscribed()); + } }