generated from allenai/python-package-template
-
Notifications
You must be signed in to change notification settings - Fork 397
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Some crazy idea I had to simplify futures and memory limits
- Loading branch information
1 parent
f6ac591
commit a1a4798
Showing
3 changed files
with
220 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import concurrent.futures | ||
import threading | ||
import queue | ||
|
||
class CappedFuture(concurrent.futures.Future):
    """Future returned to callers of ``CappedProcessPoolExecutor.submit``.

    It mirrors the outcome of an "underlying" future produced by the real
    process pool and owns one slot of a shared semaphore for as long as the
    outcome has not been consumed.  The slot is handed over by the executor
    when it calls :meth:`set_underlying_future` and is given back exactly
    once, when the caller retrieves the outcome (``result()`` or
    ``exception()``, whether they return or raise) or when the task is
    cancelled.

    Args:
        semaphore: The semaphore whose slot this future releases.  The
            executor must already have acquired it before calling
            :meth:`set_underlying_future`.
    """

    def __init__(self, semaphore):
        super().__init__()
        self._semaphore = semaphore
        self._result_retrieved = False
        self._underlying_future = None
        # True once the executor has handed an acquired slot to this future.
        self._slot_held = False
        # set_running_or_notify_cancel() may only be called once per future.
        self._cancel_notified = False
        # Dedicated lock: ``Future`` already uses ``self._condition``
        # internally, so that attribute must not be replaced here.
        self._state_lock = threading.RLock()

    def set_underlying_future(self, underlying_future):
        """Attach the pool's future and take ownership of the acquired slot.

        If this future was cancelled while the submission was in flight, the
        underlying task is cancelled on a best-effort basis; either way the
        slot is released as soon as the underlying future settles.
        """
        with self._state_lock:
            self._underlying_future = underlying_future
            self._slot_held = True
            cancelled_meanwhile = self.cancelled()
        if cancelled_meanwhile:
            underlying_future.cancel()
        # Runs immediately if the underlying future is already done.
        underlying_future.add_done_callback(self._transfer_result)

    def _transfer_result(self, underlying_future):
        """Copy the outcome of ``underlying_future`` onto this future."""
        if underlying_future.cancelled():
            self._finish_cancel()
            return
        if self.done():
            # Cancelled by the caller while the task ran anyway: nobody will
            # consume this outcome, so just free the slot.
            self._release_semaphore()
            return
        error = underlying_future.exception()
        try:
            if error is not None:
                self.set_exception(error)
            else:
                self.set_result(underlying_future.result())
        except concurrent.futures.InvalidStateError:
            # Lost a race with cancel(); the outcome is discarded.
            self._release_semaphore()

    def _finish_cancel(self):
        """Mark this future cancelled, wake waiters and free the slot.

        Returns:
            True if the future is (now) cancelled, False if it had already
            finished with a result or an exception.
        """
        if not super().cancel():
            return False
        with self._state_lock:
            notify = not self._cancel_notified
            self._cancel_notified = True
        if notify:
            # Required so concurrent.futures.wait()/as_completed() observe
            # the cancellation instead of blocking forever.
            self.set_running_or_notify_cancel()
        self._release_semaphore()
        return True

    def result(self, timeout=None):
        """Return the task's result and release the slot.

        The slot is also released when the call raises because the task
        failed or was cancelled.  It is kept when the wait merely timed out.
        """
        try:
            return super().result(timeout)
        finally:
            if self.done():
                self._release_semaphore()

    def exception(self, timeout=None):
        """Return the task's exception (or None) and release the slot.

        The slot is also released when the call raises ``CancelledError``.
        It is kept when the wait merely timed out.
        """
        try:
            return super().exception(timeout)
        finally:
            if self.done():
                self._release_semaphore()

    def _release_semaphore(self):
        """Give the slot back once; a no-op if no slot is (still) held."""
        with self._state_lock:
            if not self._slot_held or self._result_retrieved:
                return
            self._result_retrieved = True
        self._semaphore.release()

    def cancel(self):
        """Try to cancel the task.

        Returns:
            True if the task was cancelled (or had been already), False if
            it is running or finished and can no longer be cancelled.
        """
        # Held across the whole operation so that a concurrent
        # set_underlying_future() cannot slip in between the two steps.
        with self._state_lock:
            underlying = self._underlying_future
            if underlying is not None and not underlying.cancel():
                return False
            return self._finish_cancel()

    def running(self):
        """Return True if the task is currently executing in the pool."""
        with self._state_lock:
            underlying = self._underlying_future
        return underlying is not None and underlying.running()
|
||
class CappedProcessPoolExecutor(concurrent.futures.Executor):
    """Process pool that bounds the number of unconsumed tasks.

    ``submit`` never blocks: tasks are queued and a dispatcher thread hands
    them to an internal ``ProcessPoolExecutor``.  Before dispatching a task
    the thread acquires one slot of a semaphore of size ``max_unprocessed``;
    the slot is released by the returned ``CappedFuture``.  At most
    ``max_unprocessed`` tasks are therefore running or holding an
    unretrieved outcome at any time.

    Args:
        max_unprocessed: Maximum number of dispatched tasks whose outcome
            has not been consumed yet.  Must be at least 1.
        max_workers: Passed through to ``ProcessPoolExecutor``.

    Raises:
        ValueError: If ``max_unprocessed`` is smaller than 1.
    """

    # Queue marker telling the dispatcher thread that no more tasks follow.
    _STOP = object()

    def __init__(self, max_unprocessed=100, max_workers=None):
        if max_unprocessed < 1:
            # A zero-sized semaphore would block every task forever.
            raise ValueError('max_unprocessed must be at least 1')
        self._max_unprocessed = max_unprocessed
        self._semaphore = threading.BoundedSemaphore(max_unprocessed)
        self._task_queue = queue.Queue()
        self._shutdown = threading.Event()
        self._shutdown_lock = threading.Lock()
        self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        self._worker_thread = threading.Thread(target=self._worker, daemon=True)
        self._worker_thread.start()

    def submit(self, fn, *args, **kwargs):
        """Queue ``fn(*args, **kwargs)`` and return a ``CappedFuture`` for it.

        Raises:
            RuntimeError: If the executor has been shut down.
        """
        # The lock makes "check flag, then enqueue" atomic with respect to
        # shutdown(), so no task can land behind the stop marker.
        with self._shutdown_lock:
            if self._shutdown.is_set():
                raise RuntimeError('Cannot submit new tasks after shutdown')
            user_future = CappedFuture(self._semaphore)
            self._task_queue.put((user_future, fn, args, kwargs))
        return user_future

    def _worker(self):
        """Dispatcher loop: move queued tasks into the pool, one slot each."""
        while True:
            item = self._task_queue.get()
            if item is self._STOP:
                break
            user_future, fn, args, kwargs = item
            # Blocks while max_unprocessed outcomes are still unconsumed.
            self._semaphore.acquire()
            if user_future.cancelled():
                self._semaphore.release()
                continue
            try:
                underlying_future = self._executor.submit(fn, *args, **kwargs)
                user_future.set_underlying_future(underlying_future)
            except Exception as e:
                self._semaphore.release()
                try:
                    user_future.set_exception(e)
                except concurrent.futures.InvalidStateError:
                    # Cancelled in the meantime; nothing left to report and
                    # the dispatcher must keep serving the other tasks.
                    pass
        # Every queued task has been handed over.  Shutting the pool down
        # from here lets shutdown(wait=False) return without blocking.
        self._executor.shutdown(wait=True)

    def shutdown(self, wait=True):
        """Stop accepting tasks and let already queued ones finish.

        Args:
            wait: If True, block until all queued tasks have been
                dispatched and the process pool has shut down.  Dispatching
                still needs free slots, so outcomes must be consumed for
                this to complete.  If False, return immediately; the
                dispatcher thread finishes the work in the background.
        """
        with self._shutdown_lock:
            if not self._shutdown.is_set():
                self._shutdown.set()
                self._task_queue.put(self._STOP)
        if wait:
            self._worker_thread.join()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import unittest | ||
import time | ||
import concurrent.futures | ||
from concurrent.futures import TimeoutError | ||
|
||
# The CappedProcessPoolExecutor implementation lives in the 'pdelfin.cappedpool' module
from pdelfin.cappedpool import CappedProcessPoolExecutor | ||
|
||
# Define functions at the top level to ensure they are picklable by multiprocessing | ||
|
||
def square(x):
    """Return ``x`` multiplied by itself."""
    product = x * x
    return product
|
||
def raise_exception():
    """Unconditionally raise a ``ValueError`` with a fixed message."""
    error = ValueError("Test exception")
    raise error
|
||
def sleep_and_return(x, sleep_time):
    """Pause for ``sleep_time`` seconds, then hand ``x`` back unchanged."""
    delay = sleep_time
    time.sleep(delay)
    return x
|
||
def task(counter, max_counter, counter_lock, hold_seconds=0.5):
    """Record how many copies of this task run at the same time.

    Increments ``counter.value`` on entry, raises ``max_counter.value`` to
    the new count if it is a new peak, stays "busy" for ``hold_seconds`` and
    decrements ``counter.value`` again on exit.

    Args:
        counter: Object with an integer ``value`` attribute holding the
            number of tasks currently inside this function.
        max_counter: Object with an integer ``value`` attribute holding the
            highest count seen so far.
        counter_lock: Lock (context manager) guarding both counters.
        hold_seconds: How long the task keeps itself counted as running.

    Returns:
        Always True.
    """
    with counter_lock:
        counter.value += 1
        print(f"Task incrementing counter to {counter.value}")
        if counter.value > max_counter.value:
            max_counter.value = counter.value
    try:
        time.sleep(hold_seconds)
    finally:
        # Always undo the increment so an interrupted sleep cannot leave
        # the shared counter permanently too high.
        with counter_lock:
            counter.value -= 1
    return True
|
||
class TestCappedProcessPoolExecutor(unittest.TestCase):
    """Behavioural tests for ``CappedProcessPoolExecutor``."""

    def test_basic_functionality(self):
        """Test that tasks are executed and results are correct."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            results = [f.result() for f in futures]
        self.assertEqual(results, [i * i for i in range(10)])

    def test_exception_handling(self):
        """Test that exceptions in tasks are properly raised."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(raise_exception)
            with self.assertRaises(ValueError):
                future.result()

    def test_cancellation(self):
        """Test that tasks can be cancelled before execution."""
        # With a cap of one, the second task cannot be dispatched until the
        # first result has been retrieved, so cancelling it is deterministic
        # instead of racing against an idle worker process.
        with CappedProcessPoolExecutor(max_unprocessed=1, max_workers=1) as executor:
            blocker = executor.submit(sleep_and_return, "done", 0.2)
            future = executor.submit(square, 3)
            self.assertTrue(future.cancel())
            self.assertTrue(future.cancelled())
            # Attempt to get result; should raise CancelledError
            with self.assertRaises(concurrent.futures.CancelledError):
                future.result()
            # Retrieving the first result frees the slot so the executor
            # can drain its queue and shut down.
            self.assertEqual(blocker.result(), "done")

    def test_shutdown(self):
        """Test that the executor shuts down properly and does not accept new tasks."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        future = executor.submit(sleep_and_return, "finished", 0.1)
        executor.shutdown(wait=True)
        # shutdown(wait=True) must not return before queued work completed.
        self.assertTrue(future.done())
        self.assertEqual(future.result(), "finished")
        with self.assertRaises(RuntimeError):
            executor.submit(time.sleep, 1)

    def test_capping_behavior(self):
        """Test that the number of concurrent tasks does not exceed max_unprocessed."""
        from multiprocessing import Manager

        max_unprocessed = 3
        # The manager is a context manager so its server process is always
        # stopped, and it outlives every task that uses its proxies.
        with Manager() as manager:
            counter = manager.Value('i', 0)
            max_counter = manager.Value('i', 0)
            counter_lock = manager.Lock()

            with CappedProcessPoolExecutor(max_unprocessed=max_unprocessed, max_workers=10) as executor:
                futures = [executor.submit(task, counter, max_counter, counter_lock) for _ in range(10)]
                results = [f.result() for f in futures]

            self.assertTrue(all(results))
            self.assertLessEqual(max_counter.value, max_unprocessed)

    def test_submit_after_shutdown(self):
        """Test that submitting tasks after shutdown raises an error."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(square, 2)
|
||
|
||
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()