diff --git a/rxjava-core/src/main/java/rx/concurrency/ExecutorScheduler.java b/rxjava-core/src/main/java/rx/concurrency/ExecutorScheduler.java index cee84eae441..1e35735c67d 100644 --- a/rxjava-core/src/main/java/rx/concurrency/ExecutorScheduler.java +++ b/rxjava-core/src/main/java/rx/concurrency/ExecutorScheduler.java @@ -57,7 +57,7 @@ public void run() { } }, initialDelay, period, unit); - subscriptions.add(Subscriptions.create(f)); + subscriptions.add(Subscriptions.from(f)); return subscriptions; } else { @@ -84,7 +84,7 @@ public void run() { } }, delayTime, unit); // add the ScheduledFuture as a subscription so we can cancel the scheduled action if an unsubscribe happens - subscription.add(Subscriptions.create(f)); + subscription.add(Subscriptions.from(f)); } else { // we are not a ScheduledExecutorService so can't directly schedule if (delayTime == 0) { @@ -106,7 +106,7 @@ public void run() { } }, delayTime, unit); // add the ScheduledFuture as a subscription so we can cancel the scheduled action if an unsubscribe happens - subscription.add(Subscriptions.create(f)); + subscription.add(Subscriptions.from(f)); } } return subscription; @@ -134,7 +134,7 @@ public void run() { // we are an ExecutorService so get a Future back that supports unsubscribe Future f = ((ExecutorService) executor).submit(r); // add the Future as a subscription so we can cancel the scheduled action if an unsubscribe happens - subscription.add(Subscriptions.create(f)); + subscription.add(Subscriptions.from(f)); } else { // we are the lowest common denominator so can't unsubscribe once we execute executor.execute(r); diff --git a/rxjava-core/src/main/java/rx/concurrency/NewThreadScheduler.java b/rxjava-core/src/main/java/rx/concurrency/NewThreadScheduler.java index c33918353b1..036ba621276 100644 --- a/rxjava-core/src/main/java/rx/concurrency/NewThreadScheduler.java +++ b/rxjava-core/src/main/java/rx/concurrency/NewThreadScheduler.java @@ -58,15 +58,22 @@ public Thread newThread(Runnable r) { } @Override - public Subscription schedule(final T state, final Func2 action) { + public Subscription schedule(T state, Func2 action) { + final DiscardableAction discardableAction = new DiscardableAction(state, action); + // all subscriptions that may need to be unsubscribed + final CompositeSubscription subscription = new CompositeSubscription(discardableAction); + final Scheduler _scheduler = this; - return Subscriptions.from(executor.submit(new Runnable() { + subscription.add(Subscriptions.from(executor.submit(new Runnable() { @Override public void run() { - action.call(_scheduler, state); + Subscription s = discardableAction.call(_scheduler); + subscription.add(s); } - })); + }))); + + return subscription; } @Override @@ -89,7 +96,7 @@ public void run() { }, delayTime, unit); // add the ScheduledFuture as a subscription so we can cancel the scheduled action if an unsubscribe happens - subscription.add(Subscriptions.create(f)); + subscription.add(Subscriptions.from(f)); return subscription; } @@ -97,7 +104,7 @@ public void run() { } @Override - public Subscription schedule(final T state, final Func2 action) { + public Subscription schedule(T state, Func2 action) { EventLoopScheduler s = new EventLoopScheduler(); return s.schedule(state, action); } @@ -122,7 +129,7 @@ public void run() { }, delay, unit); // add the ScheduledFuture as a subscription so we can cancel the scheduled action if an unsubscribe happens - subscription.add(Subscriptions.create(f)); + subscription.add(Subscriptions.from(f)); return subscription; } diff --git a/rxjava-core/src/test/java/rx/concurrency/SchedulerUnsubscribeTest.java b/rxjava-core/src/test/java/rx/concurrency/SchedulerUnsubscribeTest.java new file mode 100644 index 00000000000..e99a25eaf38 --- /dev/null +++ b/rxjava-core/src/test/java/rx/concurrency/SchedulerUnsubscribeTest.java @@ -0,0 +1,88 @@ +package rx.concurrency; + +import static org.junit.Assert.*; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; + +import rx.Observable; +import rx.Observer; +import rx.Scheduler; +import rx.operators.SafeObservableSubscription; +import rx.util.functions.Func1; + +public class SchedulerUnsubscribeTest { + + /** + * Bug report: https://github.com/Netflix/RxJava/issues/431 + */ + @Test + public void testUnsubscribeOfNewThread() throws InterruptedException { + testUnSubscribeForScheduler(Schedulers.newThread()); + } + + @Test + public void testUnsubscribeOfThreadPoolForIO() throws InterruptedException { + testUnSubscribeForScheduler(Schedulers.threadPoolForIO()); + } + + @Test + public void testUnsubscribeOfThreadPoolForComputation() throws InterruptedException { + testUnSubscribeForScheduler(Schedulers.threadPoolForComputation()); + } + + @Test + public void testUnsubscribeOfCurrentThread() throws InterruptedException { + testUnSubscribeForScheduler(Schedulers.currentThread()); + } + + public void testUnSubscribeForScheduler(Scheduler scheduler) throws InterruptedException { + + final AtomicInteger countReceived = new AtomicInteger(); + final AtomicInteger countGenerated = new AtomicInteger(); + final SafeObservableSubscription s = new SafeObservableSubscription(); + final CountDownLatch latch = new CountDownLatch(1); + + s.wrap(Observable.interval(50, TimeUnit.MILLISECONDS) + .map(new Func1() { + @Override + public Long call(Long aLong) { + System.out.println("generated " + aLong); + countGenerated.incrementAndGet(); + return aLong; + } + }) + .subscribeOn(scheduler) + .observeOn(Schedulers.currentThread()) + .subscribe(new Observer() { + @Override + public void onCompleted() { + System.out.println("--- completed"); + } + + @Override + public void onError(Throwable e) { + System.out.println("--- onError"); + } + + @Override + public void onNext(Long args) { + if (countReceived.incrementAndGet() == 2) { + s.unsubscribe(); + latch.countDown(); + } + System.out.println("==> Received " + args); + } + })); + + latch.await(1000, TimeUnit.MILLISECONDS); + + System.out.println("----------- it thinks it is finished ------------------ "); + Thread.sleep(100); + + assertEquals(2, countGenerated.get()); + } +}