diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index 53327b371d..9f5f399ed0 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -51,6 +51,7 @@ import rx.operators.OperationFirstOrDefault; import rx.operators.OperationGroupBy; import rx.operators.OperationInterval; +import rx.operators.OperationJoin; import rx.operators.OperationLast; import rx.operators.OperationMap; import rx.operators.OperationMaterialize; @@ -5662,5 +5663,25 @@ private boolean isInternalImplementation(Object o) { return isInternal; } } - + /** + * Correlates the elements of two sequences based on overlapping durations. + * @param right The right observable sequence to join elements for. + * @param leftDurationSelector A function to select the duration of each + * element of this observable sequence, used to + * determine overlap. + * @param rightDurationSelector A function to select the duration of each + * element of the right observable sequence, + * used to determine overlap. + * @param resultSelector A function invoked to compute a result element + * for any two overlapping elements of the left and + * right observable sequences. + * @return An observable sequence that contains result elements computed + * from source elements that have an overlapping duration. + * @see MSDN: Observable.Join + */ + public Observable join(Observable right, Func1> leftDurationSelector, + Func1> rightDurationSelector, + Func2 resultSelector) { + return create(new OperationJoin(this, right, leftDurationSelector, rightDurationSelector, resultSelector)); + } } diff --git a/rxjava-core/src/main/java/rx/operators/OperationJoin.java b/rxjava-core/src/main/java/rx/operators/OperationJoin.java new file mode 100644 index 0000000000..b75b8498b0 --- /dev/null +++ b/rxjava-core/src/main/java/rx/operators/OperationJoin.java @@ -0,0 +1,277 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.HashMap; +import java.util.Map; +import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; +import rx.subscriptions.CompositeSubscription; +import rx.subscriptions.SerialSubscription; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +/** + * Correlates the elements of two sequences based on overlapping durations. + */ +public class OperationJoin implements OnSubscribeFunc { + final Observable left; + final Observable right; + final Func1> leftDurationSelector; + final Func1> rightDurationSelector; + final Func2 resultSelector; + public OperationJoin( + Observable left, + Observable right, + Func1> leftDurationSelector, + Func1> rightDurationSelector, + Func2 resultSelector) { + this.left = left; + this.right = right; + this.leftDurationSelector = leftDurationSelector; + this.rightDurationSelector = rightDurationSelector; + this.resultSelector = resultSelector; + } + + @Override + public Subscription onSubscribe(Observer t1) { + SerialSubscription cancel = new SerialSubscription(); + ResultSink result = new ResultSink(t1, cancel); + cancel.setSubscription(result.run()); + return cancel; + } + /** Manage the left and right sources. */ + class ResultSink { + final Object gate = new Object(); + final CompositeSubscription group = new CompositeSubscription(); + boolean leftDone; + int leftId; + final Map leftMap = new HashMap(); + boolean rightDone; + int rightId; + final Map rightMap = new HashMap(); + final Observer observer; + final Subscription cancel; + public ResultSink(Observer observer, Subscription cancel) { + this.observer = observer; + this.cancel = cancel; + } + public Subscription run() { + SerialSubscription leftCancel = new SerialSubscription(); + SerialSubscription rightCancel = new SerialSubscription(); + + group.add(leftCancel); + group.add(rightCancel); + + leftCancel.setSubscription(left.subscribe(new LeftObserver(leftCancel))); + rightCancel.setSubscription(right.subscribe(new RightObserver(rightCancel))); + + return group; + } + /** Observes the left values. */ + class LeftObserver implements Observer { + final Subscription self; + public LeftObserver(Subscription self) { + this.self = self; + } + protected void expire(int id, Subscription resource) { + synchronized (gate) { + if (leftMap.remove(id) != null && leftMap.isEmpty() && leftDone) { + observer.onCompleted(); + cancel.unsubscribe(); + } + } + group.remove(resource); + } + @Override + public void onNext(TLeft args) { + int id; + synchronized (gate) { + id = leftId++; + leftMap.put(id, args); + } + SerialSubscription md = new SerialSubscription(); + group.add(md); + + Observable duration; + try { + duration = leftDurationSelector.call(args); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + + md.setSubscription(duration.subscribe(new LeftDurationObserver(id, md))); + + synchronized (gate) { + for (TRight r : rightMap.values()) { + R result; + try { + result = resultSelector.call(args, r); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + observer.onNext(result); + } + } + } + @Override + public void onError(Throwable e) { + synchronized (gate) { + observer.onError(e); + cancel.unsubscribe(); + } + } + @Override + public void onCompleted() { + synchronized (gate) { + leftDone = true; + if (rightDone || leftMap.isEmpty()) { + observer.onCompleted(); + cancel.unsubscribe(); + } else { + self.unsubscribe(); + } + } + } + /** Observes the left duration. */ + class LeftDurationObserver implements Observer { + final int id; + final Subscription handle; + public LeftDurationObserver(int id, Subscription handle) { + this.id = id; + this.handle = handle; + } + + @Override + public void onNext(TLeftDuration args) { + expire(id, handle); + } + + @Override + public void onError(Throwable e) { + LeftObserver.this.onError(e); + } + + @Override + public void onCompleted() { + expire(id, handle); + } + + } + } + /** Observes the right values. */ + class RightObserver implements Observer { + final Subscription self; + public RightObserver(Subscription self) { + this.self = self; + } + void expire(int id, Subscription resource) { + synchronized (gate) { + if (rightMap.remove(id) != null && rightMap.isEmpty() && rightDone) { + observer.onCompleted(); + cancel.unsubscribe(); + } + } + group.remove(resource); + } + @Override + public void onNext(TRight args) { + int id = 0; + synchronized (gate) { + id = rightId++; + rightMap.put(id, args); + } + SerialSubscription md = new SerialSubscription(); + group.add(md); + + Observable duration; + try { + duration = rightDurationSelector.call(args); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + + md.setSubscription(duration.subscribe(new RightDurationObserver(id, md))); + + synchronized (gate) { + for (TLeft lv : leftMap.values()) { + R result; + try { + result = resultSelector.call(lv, args); + } catch (Throwable t) { + observer.onError(t); + cancel.unsubscribe(); + return; + } + observer.onNext(result); + } + } + } + @Override + public void onError(Throwable e) { + synchronized (gate) { + observer.onError(e); + cancel.unsubscribe(); + } + } + @Override + public void onCompleted() { + synchronized (gate) { + rightDone = true; + if (leftDone || rightMap.isEmpty()) { + observer.onCompleted(); + cancel.unsubscribe(); + } else { + self.unsubscribe(); + } + } + } + /** Observe the right duration. */ + class RightDurationObserver implements Observer { + final int id; + final Subscription handle; + public RightDurationObserver(int id, Subscription handle) { + this.id = id; + this.handle = handle; + } + + @Override + public void onNext(TRightDuration args) { + expire(id, handle); + } + + @Override + public void onError(Throwable e) { + RightObserver.this.onError(e); + } + + @Override + public void onCompleted() { + expire(id, handle); + } + + } + } + } +} diff --git a/rxjava-core/src/test/java/rx/operators/OperationJoinTest.java b/rxjava-core/src/test/java/rx/operators/OperationJoinTest.java new file mode 100644 index 0000000000..8f841ebb50 --- /dev/null +++ b/rxjava-core/src/test/java/rx/operators/OperationJoinTest.java @@ -0,0 +1,302 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.Collection; +import org.junit.Before; +import org.junit.Test; +import static org.mockito.Matchers.any; +import org.mockito.Mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import org.mockito.MockitoAnnotations; +import rx.Observable; +import rx.Observer; +import rx.subjects.PublishSubject; +import rx.util.functions.Action1; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +public class OperationJoinTest { + @Mock + Observer observer; + + Func2 add = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + }; + Func1> just(final Observable observable) { + return new Func1>() { + @Override + public Observable call(Integer t1) { + return observable; + } + }; + } + @Before + public void before() { + MockitoAnnotations.initMocks(this); + } + @Test + public void normal1() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source1.onNext(4); + + source2.onNext(16); + source2.onNext(32); + source2.onNext(64); + + source1.onCompleted(); + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(20); + verify(observer, times(1)).onNext(33); + verify(observer, times(1)).onNext(34); + verify(observer, times(1)).onNext(36); + verify(observer, times(1)).onNext(65); + verify(observer, times(1)).onNext(66); + verify(observer, times(1)).onNext(68); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + } + @Test + public void normal1WithDuration() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + PublishSubject duration1 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(duration1), + just(Observable.never()), add); + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source2.onNext(16); + + duration1.onNext(1); + + source1.onNext(4); + source1.onNext(8); + + source1.onCompleted(); + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(20); + verify(observer, times(1)).onNext(24); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + + } + @Test + public void normal2() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source1.onCompleted(); + + source2.onNext(16); + source2.onNext(32); + source2.onNext(64); + + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(33); + verify(observer, times(1)).onNext(34); + verify(observer, times(1)).onNext(65); + verify(observer, times(1)).onNext(66); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + } + @Test + public void leftThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source2.onNext(1); + source1.onError(new RuntimeException("Forced failure")); + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), add); + + m.subscribe(observer); + + source1.onNext(1); + source2.onError(new RuntimeException("Forced failure")); + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void leftDurationThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable duration1 = Observable.error(new RuntimeException("Forced failure")); + + Observable m = source1.join(source2, + just(duration1), + just(Observable.never()), add); + m.subscribe(observer); + + source1.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightDurationThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable duration1 = Observable.error(new RuntimeException("Forced failure")); + + Observable m = source1.join(source2, + just(Observable.never()), + just(duration1), add); + m.subscribe(observer); + + source2.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void leftDurationSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func1> fail = new Func1>() { + @Override + public Observable call(Integer t1) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.join(source2, + fail, + just(Observable.never()), add); + m.subscribe(observer); + + source1.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightDurationSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func1> fail = new Func1>() { + @Override + public Observable call(Integer t1) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.join(source2, + just(Observable.never()), + fail, add); + m.subscribe(observer); + + source2.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void resultSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func2 fail = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.join(source2, + just(Observable.never()), + just(Observable.never()), fail); + m.subscribe(observer); + + source1.onNext(1); + source2.onNext(2); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } +}