diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index d24e48106c..9ced2e1415 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -37,6 +37,7 @@ import rx.observables.GroupedObservable; import rx.operators.OperationAll; +import rx.operators.OperationCache; import rx.operators.OperationConcat; import rx.operators.OperationDefer; import rx.operators.OperationDematerialize; @@ -2428,6 +2429,28 @@ public R call(T0 t0, T1 t1, T2 t2, T3 t3) { }); } + /** + * Returns an Observable that repeats the original Observable sequence to all subscribers. + * The source Observable is subscribed to at most once. + * + * @param source + * the source Observable + * @return an Observable that repeats the original Observable sequence to all subscribers. + */ + public static Observable cache(Observable source) { + return _create(OperationCache.cache(source)); + } + + /** + * Returns an Observable that repeats the original Observable sequence to all subscribers. + * The source Observable is subscribed to at most once. + * + * @return an Observable that repeats the original Observable sequence to all subscribers. + */ + public Observable cache() { + return cache(this); + } + /** * Filters an Observable by discarding any of its emissions that do not meet some test. *

diff --git a/rxjava-core/src/main/java/rx/operators/OperationCache.java b/rxjava-core/src/main/java/rx/operators/OperationCache.java new file mode 100644 index 0000000000..197b7dd2be --- /dev/null +++ b/rxjava-core/src/main/java/rx/operators/OperationCache.java @@ -0,0 +1,447 @@ +package rx.operators; + +import org.junit.Test; +import org.mockito.Mockito; +import rx.Observable; +import rx.Observer; +import rx.Subscription; +import rx.subscriptions.Subscriptions; +import rx.util.functions.Func1; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.Assert.assertFalse; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public final class OperationCache +{ + public static Func1, Subscription> cache(Observable source) { + return new Cache(source); + } + + private static class Cache implements Func1, Subscription>, Observer + { + + private Observable source; + private boolean isSubscribed = false; + private final Map> subscriptions = new HashMap>(); + private Exception exception = null; + private final List history = Collections.synchronizedList(new ArrayList()); + + public Cache(Observable source) + { + this.source = source; + } + + @Override + public Subscription call(Observer observer) + { + int item = 0; + Subscription subscription; + boolean needSubscribe; + + for (;;) { + while (item < history.size()) { + observer.onNext(history.get(item++)); + } + + synchronized (subscriptions) { + if (item < history.size()) { + continue; + } + + if (exception != null) { + observer.onError(exception); + return Subscriptions.empty(); + } + if (source == null) { + observer.onCompleted(); + return Subscriptions.empty(); + } + + subscription = new CacheSubscription(); + subscriptions.put(subscription, observer); + needSubscribe = !isSubscribed; + if (needSubscribe) { + isSubscribed = true; + } + break; + } + } + + if (needSubscribe) { + source.subscribe(this); + } + + return subscription; + } + + @Override + public void onCompleted() + { + synchronized (subscriptions) { + source = null; + for (Observer observer : new ArrayList>(subscriptions.values())) { + observer.onCompleted(); + } + subscriptions.clear(); + } + } + + @Override + public void onError(Exception e) + { + synchronized (subscriptions) { + source = null; + exception = e; + for (Observer observer : new ArrayList>(subscriptions.values())) { + observer.onError(e); + } + subscriptions.clear(); + } + } + + @Override + public void onNext(T args) + { + synchronized (subscriptions) { + history.add(args); + for (Observer observer : new ArrayList>(subscriptions.values())) { + observer.onNext(args); + } + } + } + + private class CacheSubscription implements Subscription + { + @Override + public void unsubscribe() + { + synchronized (subscriptions) { + subscriptions.remove(this); + } + } + } + } + + public static class UnitTest { + + private final Exception testException = new Exception(); + + @Test + public void testNoSubscription() { + final SynchronousObservableFunc synchronousObservableFunc = new SynchronousObservableFunc(); + Observable cache = Observable.create(cache(Observable.create(synchronousObservableFunc))); + assertFalse("Source observer subscribed", synchronousObservableFunc.isSubscribed.get()); + } + + @Test + public void testSynchronous() { + Observable observable = Observable.create(cache(Observable.create(new SynchronousObservableFunc()))); + + Observer aObserver = mock(Observer.class); + observable.subscribe(aObserver); + assertCompletedObserver(aObserver); + + aObserver = mock(Observer.class); + observable.subscribe(aObserver); + assertCompletedObserver(aObserver); + } + + private void assertCompletedObserver(Observer aObserver) + { + verify(aObserver, times(1)).onNext("one"); + verify(aObserver, times(1)).onNext("two"); + verify(aObserver, times(1)).onNext("three"); + verify(aObserver, Mockito.never()).onError(any(Exception.class)); + verify(aObserver, times(1)).onCompleted(); + } + + private static class SynchronousObservableFunc implements Func1, Subscription> + { + private AtomicBoolean isSubscribed = new AtomicBoolean(false); + + @Override + public Subscription call(Observer observer) + { + assertFalse("Source observer subscribed twice", isSubscribed.getAndSet(true)); + observer.onNext("one"); + observer.onNext("two"); + observer.onNext("three"); + observer.onCompleted(); + return Subscriptions.empty(); + } + } + + @Test + public void testSynchronousError() { + Observable observable = Observable.create(cache(Observable.create(new SynchronousObservableErrorFunc()))); + + Observer aObserver = mock(Observer.class); + observable.subscribe(aObserver); + assertErrorObserver(aObserver); + + aObserver = mock(Observer.class); + observable.subscribe(aObserver); + assertErrorObserver(aObserver); + } + + private void assertErrorObserver(Observer aObserver) + { + verify(aObserver, times(1)).onNext("one"); + verify(aObserver, times(1)).onNext("two"); + verify(aObserver, times(1)).onNext("three"); + verify(aObserver, times(1)).onError(testException); + verify(aObserver, Mockito.never()).onCompleted(); + } + + private class SynchronousObservableErrorFunc implements Func1, Subscription> + { + private AtomicBoolean isSubscribed = new AtomicBoolean(false); + + @Override + public Subscription call(Observer observer) + { + assertFalse("Source observer subscribed twice", isSubscribed.getAndSet(true)); + observer.onNext("one"); + observer.onNext("two"); + observer.onNext("three"); + observer.onError(testException); + return Subscriptions.empty(); + } + } + + @Test + public void testAsync() { + AsyncObservableFunc asyncObservableFunc = new AsyncObservableFunc(); + Observable observable = Observable.create(cache(Observable.create(asyncObservableFunc))); + + Observer aObserver = mock(Observer.class); + observable.subscribe(aObserver); + asyncObservableFunc.waitToFinish(); + assertCompletedObserver(aObserver); + + aObserver = mock(Observer.class); + observable.subscribe(aObserver); + assertCompletedObserver(aObserver); + } + + private static class AsyncObservableFunc implements Func1, Subscription> + { + private AtomicBoolean isSubscribed = new AtomicBoolean(false); + Thread t; + + @Override + public Subscription call(final Observer observer) + { + assertFalse("Source observer subscribed twice", isSubscribed.getAndSet(true)); + t = new Thread(new Runnable() { + @Override + public void run() + { + try { + Thread.sleep(10); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + observer.onNext("one"); + observer.onNext("two"); + observer.onNext("three"); + observer.onCompleted(); + } + }); + t.start(); + + return Subscriptions.empty(); + } + + public void waitToFinish() { + try { + t.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + @Test + public void testAsyncError() { + AsyncObservableErrorFunc asyncObservableErrorFunc = new AsyncObservableErrorFunc(); + Observable observable = Observable.create(cache(Observable.create(asyncObservableErrorFunc))); + + Observer aObserver = mock(Observer.class); + observable.subscribe(aObserver); + asyncObservableErrorFunc.waitToFinish(); + assertErrorObserver(aObserver); + + aObserver = mock(Observer.class); + observable.subscribe(aObserver); + assertErrorObserver(aObserver); + } + + private class AsyncObservableErrorFunc implements Func1, Subscription> + { + private AtomicBoolean isSubscribed = new AtomicBoolean(false); + Thread t; + + @Override + public Subscription call(final Observer observer) + { + assertFalse("Source observer subscribed twice", isSubscribed.getAndSet(true)); + t = new Thread(new Runnable() { + @Override + public void run() + { + try { + Thread.sleep(10); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + observer.onNext("one"); + observer.onNext("two"); + observer.onNext("three"); + observer.onError(testException); + } + }); + t.start(); + + return Subscriptions.empty(); + } + + public void waitToFinish() { + try { + t.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + @Test + public void testSubscribeMidSequence() { + LatchedObservableFunc latchedObservableFunc = new LatchedObservableFunc(); + Observable observable = Observable.create(cache(Observable.create(latchedObservableFunc))); + + Observer aObserver = mock(Observer.class); + observable.subscribe(aObserver); + + latchedObservableFunc.waitToTwo(); + assertObservedUntilTwo(aObserver); + + Observer anotherObserver = mock(Observer.class); + observable.subscribe(anotherObserver); + assertObservedUntilTwo(anotherObserver); + + latchedObservableFunc.waitToFinish(); + assertCompletedObserver(aObserver); + assertCompletedObserver(anotherObserver); + } + + @Test + public void testUnsubscribeFirstObserver() { + LatchedObservableFunc latchedObservableFunc = new LatchedObservableFunc(); + Observable observable = Observable.create(cache(Observable.create(latchedObservableFunc))); + + Observer aObserver = mock(Observer.class); + Subscription subscription = observable.subscribe(aObserver); + + latchedObservableFunc.waitToTwo(); + + subscription.unsubscribe(); + assertObservedUntilTwo(aObserver); + + Observer anotherObserver = mock(Observer.class); + observable.subscribe(anotherObserver); + assertObservedUntilTwo(anotherObserver); + + latchedObservableFunc.waitToFinish(); + assertObservedUntilTwo(aObserver); + assertCompletedObserver(anotherObserver); + } + + private void assertObservedUntilTwo(Observer aObserver) + { + verify(aObserver, times(1)).onNext("one"); + verify(aObserver, times(1)).onNext("two"); + verify(aObserver, Mockito.never()).onNext("three"); + verify(aObserver, Mockito.never()).onError(any(Exception.class)); + verify(aObserver, Mockito.never()).onCompleted(); + } + + private static class LatchedObservableFunc implements Func1, Subscription> + { + private AtomicBoolean isSubscribed = new AtomicBoolean(false); + private Thread t; + private final Object latch = new Object(); + + + @Override + public Subscription call(final Observer observer) + { + assertFalse("Source observer subscribed twice", isSubscribed.getAndSet(true)); + observer.onNext("one"); + t = new Thread(new Runnable() { + @Override + public void run() + { + try { + Thread.sleep(10); + } + catch (InterruptedException e) { + e.printStackTrace(); + } + + observer.onNext("two"); + synchronized (latch) { + latch.notifyAll(); + try { + latch.wait(); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + observer.onNext("three"); + observer.onCompleted(); + } + }); + t.start(); + + return Subscriptions.empty(); + } + + public void waitToTwo() + { + try { + synchronized (latch) { + latch.wait(); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public void waitToFinish() { + try { + synchronized (latch) { + latch.notifyAll(); + } + t.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + } +}