Skip to content

Commit

Permalink
Some crazy idea I had to simplify futures and memory limits
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 23, 2024
1 parent f6ac591 commit a1a4798
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 2 deletions.
116 changes: 116 additions & 0 deletions pdelfin/cappedpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import concurrent.futures
import threading
import queue

class CappedFuture(concurrent.futures.Future):
    """Future that gives a semaphore slot back once its outcome is consumed.

    The owning executor acquires one slot of ``semaphore`` before it submits
    the task and then links the real pool future with
    :meth:`set_underlying_future`.  From that moment this future owns the
    slot and releases it exactly once, when the caller retrieves the outcome
    through :meth:`result` or :meth:`exception`, or when the task is
    cancelled (a cancelled task has no result left to process).

    A future that was never linked (cancelled before submission, or failed
    during submission) does not own a slot and never releases one.

    Args:
        semaphore: The semaphore whose slot is released on retrieval.
    """

    def __init__(self, semaphore):
        super().__init__()
        self._semaphore = semaphore
        # True once the slot owned by this future has been given back.
        self._result_retrieved = False
        self._underlying_future = None
        # True once a submitted task has been linked, i.e. a slot is owned.
        self._holds_slot = False
        # Guards the three fields above.  Deliberately not named
        # ``_condition``: concurrent.futures.Future stores its own state
        # lock under that attribute name and it must not be replaced.
        self._lock = threading.Lock()

    def set_underlying_future(self, underlying_future):
        """Link the pool future whose outcome this future mirrors.

        Must be called only after a semaphore slot was acquired on behalf
        of this future; ownership of that slot moves to this object.
        """
        with self._lock:
            self._underlying_future = underlying_future
            self._holds_slot = True
        # Transfer the result when the underlying future completes.
        underlying_future.add_done_callback(self._transfer_result)
        if self.cancelled():
            # The caller cancelled in the window before the task was
            # linked; try to stop the now-pointless work.  Whether or not
            # this succeeds, _transfer_result frees the slot.
            underlying_future.cancel()

    def _transfer_result(self, underlying_future):
        """Done-callback: copy the pool future's outcome onto this future."""
        if self.done():
            # Already cancelled by the caller: nobody will consume a
            # result, so just give the slot back.
            self._release_semaphore()
            return
        if underlying_future.cancelled():
            super().cancel()
            self._release_semaphore()
            return
        error = underlying_future.exception()
        try:
            if error is not None:
                self.set_exception(error)
            else:
                self.set_result(underlying_future.result())
        except concurrent.futures.InvalidStateError:
            # Lost a race with a concurrent cancel(); treat as cancelled.
            self._release_semaphore()

    def result(self, timeout=None):
        """Return the task's result and free the slot.

        The slot is freed whenever the future is finished, including when
        the task raised or was cancelled.  It is kept when the wait merely
        timed out, because the outcome is still outstanding.
        """
        try:
            return super().result(timeout)
        finally:
            if self.done():
                self._release_semaphore()

    def exception(self, timeout=None):
        """Return the task's exception (or None) and free the slot.

        Slot handling is identical to :meth:`result`.
        """
        try:
            return super().exception(timeout)
        finally:
            if self.done():
                self._release_semaphore()

    def _release_semaphore(self):
        """Release the owned slot at most once; no-op if none is owned."""
        with self._lock:
            if self._result_retrieved or not self._holds_slot:
                return
            self._result_retrieved = True
        self._semaphore.release()

    def cancel(self):
        """Attempt to cancel the task.

        Returns:
            True if the task was cancelled, False if it is already running
            or finished in the pool.
        """
        with self._lock:
            underlying = self._underlying_future
        if underlying is not None and not underlying.cancel():
            return False
        # Either the task has not been submitted yet, or the pool agreed
        # to cancel it; mark this future as cancelled as well.
        cancelled = super().cancel()
        if cancelled:
            self._release_semaphore()
        return cancelled

    def running(self):
        """Return True while the linked pool future reports running."""
        with self._lock:
            underlying = self._underlying_future
        return underlying is not None and underlying.running()

class CappedProcessPoolExecutor(concurrent.futures.Executor):
    """Process pool that caps the number of tasks with unretrieved results.

    ``submit`` never blocks: it queues the task and returns a
    ``CappedFuture`` immediately.  A background dispatcher thread forwards
    queued tasks to an internal ``ProcessPoolExecutor``, acquiring one slot
    of a bounded semaphore per task first.  The slot is given back by the
    returned future, so at most ``max_unprocessed`` tasks are in the pool
    or finished-but-unread at any time.

    NOTE: ``shutdown`` waits for the dispatcher to forward every queued
    task.  If more tasks are queued than there are free slots and nobody
    retrieves results, the dispatcher stays blocked on the semaphore and
    ``shutdown`` (including leaving a ``with`` block) does not return.

    Args:
        max_unprocessed: Maximum number of semaphore slots.
        max_workers: Forwarded to ``ProcessPoolExecutor``.
    """

    # Queue marker telling the dispatcher thread that no more tasks follow.
    _STOP = object()

    def __init__(self, max_unprocessed=100, max_workers=None):
        self._max_unprocessed = max_unprocessed
        self._semaphore = threading.BoundedSemaphore(max_unprocessed)
        self._task_queue = queue.Queue()
        self._shutdown = threading.Event()
        self._shutdown_lock = threading.Lock()
        self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        self._worker_thread = threading.Thread(target=self._worker, daemon=True)
        self._worker_thread.start()

    def submit(self, fn, *args, **kwargs):
        """Queue ``fn(*args, **kwargs)`` and return a future for it.

        Raises:
            RuntimeError: If the executor has been shut down.
        """
        # Check-and-enqueue happens under the shutdown lock so a task can
        # never land in the queue behind the stop marker, where it would be
        # ignored and its future would never complete.
        with self._shutdown_lock:
            if self._shutdown.is_set():
                raise RuntimeError('Cannot submit new tasks after shutdown')
            # Create a CappedFuture to return to the user
            user_future = CappedFuture(self._semaphore)
            self._task_queue.put((user_future, fn, args, kwargs))
        return user_future

    def _worker(self):
        """Dispatcher loop: move queued tasks into the pool, one slot each."""
        while True:
            item = self._task_queue.get()
            if item is self._STOP:
                break
            user_future, fn, args, kwargs = item
            # Blocks until a previously dispatched task gives its slot back.
            self._semaphore.acquire()
            if user_future.cancelled():
                self._semaphore.release()
                continue
            # Submit the task to the underlying executor
            try:
                underlying_future = self._executor.submit(fn, *args, **kwargs)
            except Exception as e:
                # The slot is handed back right here; flag the future so a
                # later result()/exception() call does not release it again.
                user_future._result_retrieved = True
                self._semaphore.release()
                try:
                    user_future.set_exception(e)
                except concurrent.futures.InvalidStateError:
                    # The caller cancelled in the meantime; nothing to report.
                    pass
                continue
            user_future.set_underlying_future(underlying_future)

    def shutdown(self, wait=True):
        """Stop accepting tasks, dispatch what is queued, shut the pool down.

        The dispatcher thread is always joined, so every task queued before
        this call is handed to the pool first.  ``wait`` is forwarded to the
        underlying pool's ``shutdown``.  Calling this more than once is safe.
        """
        with self._shutdown_lock:
            if not self._shutdown.is_set():
                self._shutdown.set()
                self._task_queue.put(self._STOP)
        self._worker_thread.join()
        self._executor.shutdown(wait=wait)
7 changes: 5 additions & 2 deletions scripts/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,16 @@ def sample_mm_requests_qwen2vl(
return_tensors="np",
)

# print(inputs)

tokens = inputs["input_ids"][0]
prompt_len = len(tokens)

result.append((TokensPrompt(
dict(
prompt_token_ids=tokens,
multi_modal_data=dict(image=main_image),
multi_modal_data=dict(image=dict(image_embeds=torch.randn(1036, 3584), image_grid_thw=torch.tensor([[1, 74, 56]]))),
# multi_modal_data=dict(image=main_image)
)
), prompt_len, fixed_output_len))

Expand Down Expand Up @@ -467,7 +470,7 @@ def main(args: argparse.Namespace):
else:
# requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
# args.output_len)
requests = sample_mm_requests_molmo(args.dataset, args.num_prompts, tokenizer,
requests = sample_mm_requests_qwen2vl(args.dataset, args.num_prompts, tokenizer,
args.output_len)

if args.backend == "vllm":
Expand Down
99 changes: 99 additions & 0 deletions tests/test_cappedpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import unittest
import time
import concurrent.futures
from concurrent.futures import TimeoutError

# CappedProcessPoolExecutor is defined in pdelfin/cappedpool.py
from pdelfin.cappedpool import CappedProcessPoolExecutor

# Define functions at the top level to ensure they are picklable by multiprocessing

def square(x):
    """Return ``x`` multiplied by itself."""
    squared = x * x
    return squared

def raise_exception():
    """Always fail with a ``ValueError`` carrying a fixed message."""
    failure = ValueError("Test exception")
    raise failure

def sleep_and_return(x, sleep_time):
    """Pause for ``sleep_time`` seconds, then hand ``x`` back unchanged."""
    time.sleep(sleep_time)
    value = x
    return value

def task(counter, max_counter, counter_lock):
    """Record one unit of concurrent work in the shared counters.

    Increments ``counter.value`` under ``counter_lock``, raises
    ``max_counter.value`` to the new count when that is a new peak, sleeps
    for half a second outside the lock, decrements the counter again and
    returns True.
    """
    with counter_lock:
        counter.value = counter.value + 1
        print(f"Task incrementing counter to {counter.value}")
        # Track the highest number of tasks seen in flight at once.
        if max_counter.value < counter.value:
            max_counter.value = counter.value

    # Stay "in flight" for a while so overlapping tasks can be observed.
    time.sleep(0.5)

    with counter_lock:
        counter.value = counter.value - 1
    return True

class TestCappedProcessPoolExecutor(unittest.TestCase):
    """End-to-end tests that run CappedProcessPoolExecutor with real worker processes."""

    def test_basic_functionality(self):
        """Test that tasks are executed and results are correct."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            results = [f.result() for f in futures]
            expected = [i * i for i in range(10)]
            self.assertEqual(results, expected)

    def test_exception_handling(self):
        """Test that exceptions in tasks are properly raised."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(raise_exception)
            with self.assertRaises(ValueError):
                future.result()

    def test_cancellation(self):
        """Test that tasks can be cancelled before execution."""
        # With a cap of one, the second task cannot reach the process pool
        # until the first task's result has been retrieved.  Cancelling it
        # is therefore deterministic rather than a race against the
        # dispatcher thread.
        with CappedProcessPoolExecutor(max_unprocessed=1, max_workers=1) as executor:
            blocker = executor.submit(sleep_and_return, 1, 0.1)
            future = executor.submit(time.sleep, 5)
            cancelled = future.cancel()
            self.assertTrue(cancelled)
            self.assertTrue(future.cancelled())
            # Attempt to get result; should raise CancelledError
            with self.assertRaises(concurrent.futures.CancelledError):
                future.result()
            # Retrieving the first result frees the only slot so the
            # executor can drain its queue and shut down cleanly.
            self.assertEqual(blocker.result(), 1)

    def test_shutdown(self):
        """Test that the executor shuts down properly and does not accept new tasks."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.submit(time.sleep, 1)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(time.sleep, 1)

    def test_capping_behavior(self):
        """Test that the number of concurrent tasks does not exceed max_unprocessed."""
        from multiprocessing import Manager

        max_unprocessed = 3
        # The manager is opened first so it outlives the executor and its
        # server process is shut down when the test ends.
        with Manager() as manager:
            counter = manager.Value('i', 0)
            max_counter = manager.Value('i', 0)
            counter_lock = manager.Lock()

            with CappedProcessPoolExecutor(max_unprocessed=max_unprocessed, max_workers=10) as executor:
                futures = [executor.submit(task, counter, max_counter, counter_lock) for _ in range(10)]

                for index, f in enumerate(futures):
                    print(f"Future {index} returned {f.result()}")

            # Every task has finished by now, so the recorded peak is final.
            print(max_counter.value)
            self.assertLessEqual(max_counter.value, max_unprocessed)

    def test_submit_after_shutdown(self):
        """Test that submitting tasks after shutdown raises an error."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(square, 2)


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit a1a4798

Please sign in to comment.