From 3a6f576602e6995f3bc6e00fc67d1e6b49d3645c Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 24 Feb 2024 20:32:01 +0100 Subject: [PATCH] Feat: Support concurrent batches processing (#11) * Feat: Support concurrent batches processing * update the examples * try to fix python 3.12 tests by updating safe time * Asyncio task.cancelling is not available in python 3.10 --- async_batcher/batcher.py | 150 ++++++++++++++++++++--------- examples/dynamodb/main.py | 7 +- examples/keras/main.py | 3 +- examples/keras_with_grpc/server.py | 2 +- tests/conftest.py | 21 ++++ tests/test_batcher.py | 100 ++++++++++++++++++- 6 files changed, 234 insertions(+), 49 deletions(-) diff --git a/async_batcher/batcher.py b/async_batcher/batcher.py index 24d1b44..78f050a 100644 --- a/async_batcher/batcher.py +++ b/async_batcher/batcher.py @@ -31,16 +31,21 @@ def __init__( *, max_batch_size: int = -1, max_queue_time: float = 0.01, + concurrency: int = 1, ): super().__init__() if max_batch_size is None or 0 <= max_batch_size <= 1: raise ValueError("Valid max_batch_size value is greater than 1 or -1 for infinite") + if concurrency is None or concurrency == 0: + raise ValueError("Valid concurrency value is greater than 0 or -1 for infinite") self.max_batch_size = max_batch_size self.max_queue_time = max_queue_time + self.concurrency = concurrency self._queue = asyncio.Queue() self._current_task: asyncio.Task | None = None - self._should_stop = False - self._force_stop = False + self._running_batches: dict[int, asyncio.Task] = {} + self._concurrency_semaphore = asyncio.Semaphore(concurrency) if concurrency > 0 else None + self._stop = False self._is_running = False @abc.abstractmethod @@ -59,23 +64,23 @@ async def process(self, item: T) -> S: Returns: S: The result of processing the item. 
""" - if self._should_stop: + if self._stop: raise RuntimeError("Batcher is stopped") if self._current_task is None: - self._current_task = asyncio.get_running_loop().create_task(self.batch_run()) + self._current_task = asyncio.get_running_loop().create_task(self.run()) logging.debug(item) future = asyncio.get_running_loop().create_future() await self._queue.put(self.QueueItem(item, future)) await future return future.result() - async def _fill_batch_from_queue(self): + async def _fill_batch_from_queue(self, started_at: float | None) -> list[QueueItem]: try: batch = [await asyncio.wait_for(self._queue.get(), timeout=1.0)] except asyncio.TimeoutError: return [] - - started_at = asyncio.get_running_loop().time() + if started_at is None: + started_at = asyncio.get_event_loop().time() while 1: try: max_wait = self.max_queue_time - (asyncio.get_running_loop().time() - started_at) @@ -90,51 +95,108 @@ async def _fill_batch_from_queue(self): break return batch - async def batch_run(self): + async def _batch_run(self, task_id: int, batch: list[QueueItem]): + started_at = asyncio.get_event_loop().time() + try: + batch_items = [q_item.item for q_item in batch] + if asyncio.iscoroutinefunction(self.process_batch): + results = await self.process_batch(batch=batch_items) + else: + results = await asyncio.get_event_loop().run_in_executor( + None, self.process_batch, batch_items + ) + if results is None: + results = [None] * len(batch) + if len(results) != len(batch): + raise ValueError(f"Expected to get {len(batch)} results, but got {len(results)}.") + except Exception as e: + self.logger.error("Error processing batch", exc_info=True) + for q_item in batch: + q_item.future.set_exception(e) + else: + for q_item, result in zip(batch, results, strict=True): + q_item.future.set_result(result) + elapsed_time = asyncio.get_event_loop().time() - started_at + self.logger.debug(f"Processed batch of {len(batch)} elements" f" in {elapsed_time} seconds.") + 
self._running_batches.pop(task_id) + + async def _concurrent_batch_run(self, task_id: int, batch: list[QueueItem]): + async with self._concurrency_semaphore: + await self._batch_run(task_id, batch) + + async def run(self): """Run the batcher asynchronously.""" self._is_running = True - while not self._should_stop or (not self._force_stop and self._queue.qsize() > 0): - batch = await self._fill_batch_from_queue() - if not batch: - continue - - started_at = asyncio.get_event_loop().time() - try: - batch_items = [q_item.item for q_item in batch] - if asyncio.iscoroutinefunction(self.process_batch): - results = await self.process_batch(batch=batch_items) - else: - results = await asyncio.get_event_loop().run_in_executor( - None, self.process_batch, batch_items + task_id = 0 + if self.concurrency > 0: + started_at = None + while not self._should_stop(): + if started_at is None: + started_at = asyncio.get_event_loop().time() + semaphore_acquired = False + try: + # to check if the batcher should stop, we raise a timeout after 1 second + # if the semaphore is not acquired + await asyncio.wait_for(self._concurrency_semaphore.acquire(), timeout=1.0) + semaphore_acquired = True + # if the queue is empty, we need to let the batch filler create it + batch = await self._fill_batch_from_queue( + started_at=started_at if self._queue.qsize() > 0 else None ) - if results is None: - results = [None] * len(batch) - if len(results) != len(batch): - raise ValueError(f"Expected to get {len(batch)} results, but got {len(results)}.") - except Exception as e: - self.logger.error("Error processing batch", exc_info=True) - for q_item in batch: - q_item.future.set_exception(e) - else: - for q_item, result in zip(batch, results, strict=True): - q_item.future.set_result(result) - elapsed_time = asyncio.get_event_loop().time() - started_at - self.logger.debug(f"Processed batch of {len(batch)} elements" f" in {elapsed_time} seconds.") + if batch: + # create a new task to process the batch + 
self._running_batches[task_id] = asyncio.get_event_loop().create_task( + self._concurrent_batch_run(task_id, batch) + ) + task_id += 1 + started_at = None + except asyncio.TimeoutError: + pass + finally: + if semaphore_acquired: + self._concurrency_semaphore.release() + else: + while not self._should_stop(): + batch = await self._fill_batch_from_queue(started_at=None) + if batch: + self._running_batches[task_id] = asyncio.get_event_loop().create_task( + self._batch_run(task_id, batch) + ) + task_id += 1 self._is_running = False - def stop(self, force: bool = False): + def _should_stop(self): + return self._stop and self._queue.qsize() == 0 + + async def is_running(self): + """Check if the batcher is running. + + Returns: + bool: True if the batcher is running, False otherwise. + """ + return self._is_running + + async def stop(self, force: bool = False, timeout: float | None = None): """Stop the batcher asyncio task. Args: force (bool, optional): Whether to force stop the batcher without waiting for processing - the remaining buffer items. Defaults to False. + the remaining buffer items. If True, it will cancel the current task and all running tasks. + Defaults to False. + timeout (float, optional): The time to wait for the batcher to stop. If None, it will wait + indefinitely. Defaults to None. 
""" if force: - self._force_stop = True - self._should_stop = True - if ( - self._current_task - and not self._current_task.done() - and not self._current_task.get_loop().is_closed() - ): - self._current_task.get_loop().run_until_complete(self._current_task) + if self._current_task and not self._current_task.done(): + self._current_task.cancel() + for task in self._running_batches.values(): + if not task.done(): + task.cancel() + else: + self._stop = True + if ( + self._current_task + and not self._current_task.done() + and not self._current_task.get_loop().is_closed() + ): + await asyncio.wait_for(self._current_task, timeout=timeout) diff --git a/examples/dynamodb/main.py b/examples/dynamodb/main.py index 4554105..1f2ec7c 100644 --- a/examples/dynamodb/main.py +++ b/examples/dynamodb/main.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any import aioboto3 @@ -41,8 +42,10 @@ class GetRequestModel(BaseModel): @app.on_event("shutdown") def shutdown_event(): - get_batcher.stop() - write_batcher.stop() + async def stop_batchers(): + await asyncio.gather(get_batcher.stop(), write_batcher.stop()) + + asyncio.run(stop_batchers()) @app.post("/put") diff --git a/examples/keras/main.py b/examples/keras/main.py index 4fe1482..aae6985 100644 --- a/examples/keras/main.py +++ b/examples/keras/main.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import gc from typing import TYPE_CHECKING @@ -34,7 +35,7 @@ async def startup_event(): @app.on_event("shutdown") def shutdown_event(): - batcher.stop() + asyncio.run(batcher.stop()) @app.post("/predict") diff --git a/examples/keras_with_grpc/server.py b/examples/keras_with_grpc/server.py index 3b0a4be..75bc111 100644 --- a/examples/keras_with_grpc/server.py +++ b/examples/keras_with_grpc/server.py @@ -49,7 +49,7 @@ async def serve() -> None: async def server_graceful_shutdown(): logging.info("Starting graceful shutdown...") - predictor.batcher.stop() + await 
predictor.batcher.stop() await server.stop(30) _cleanup_coroutines.append(server_graceful_shutdown()) diff --git a/tests/conftest.py b/tests/conftest.py index e187a7d..0ad5193 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,22 @@ from __future__ import annotations +import asyncio from unittest import mock import pytest from async_batcher.batcher import AsyncBatcher +@pytest.fixture(scope="session") +def event_loop(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + yield loop + loop.close() + + def pytest_runtest_setup(item): def _has_marker(item, marker_name: str) -> bool: return len(list(item.iter_markers(name=marker_name))) > 0 @@ -23,6 +34,16 @@ async def process_batch(self, *args, **kwargs): return await self.mock_batch_processor(*args, **kwargs) +class SlowAsyncBatcher(MockAsyncBatcher): + def __init__(self, sleep_time: float = 1, **kwargs): + super().__init__(**kwargs) + self.sleep_time = sleep_time + + async def process_batch(self, *args, **kwargs): + await asyncio.sleep(self.sleep_time) + return await super().process_batch(*args, **kwargs) + + @pytest.fixture(scope="function") def mock_async_batcher(): batcher = MockAsyncBatcher( diff --git a/tests/test_batcher.py b/tests/test_batcher.py index b91b515..1752543 100644 --- a/tests/test_batcher.py +++ b/tests/test_batcher.py @@ -1,10 +1,11 @@ from __future__ import annotations import asyncio +import sys import pytest -from tests.conftest import MockAsyncBatcher +from tests.conftest import MockAsyncBatcher, SlowAsyncBatcher class CallsMaker: @@ -76,3 +77,100 @@ async def test_process_batch_with_short_buffering_time(): assert calls_maker1.result == [i * 2 for i in range(5)] assert calls_maker2.result == [i * 2 for i in range(5, 20)] assert calls_maker3.result == [i * 2 for i in range(20, 30)] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "concurrency, expected_execution_time", + [ + # batch 1: from 0 to 1 + # batch 2: from 1 to 2 + 
# batch 3: from 2 to 3 + # batch 4: from 3.2 to 4.2 (< max_batch_size, extra 0.4) + (1, 4.2), + # batch 1: from 0 to 1 + # batch 2: from 0.25 to 1.25 + # batch 3: from 1 to 2 + # batch 4: from 1.45 to 2.45 (< max_batch_size, extra 0.4) + (2, 2.45), + # batch 1: from 0 to 1 + # batch 2: from 0.25 to 1.25 + # batch 3: from 0.4 to 1.4 + # batch 4: from 1.2 to 2.2 (< max_batch_size, extra 0.4) + (3, 2.2), + # batch 1: from 0 to 1 + # batch 2: from 0.25 to 1.25 + # batch 3: from 0.4 to 1.4 + # batch 4: from 0.6 to 1.6 (< max_batch_size, extra 0.4) + (-1, 1.6), + ], +) +async def test_concurrent_process_batch(concurrency, expected_execution_time): + batcher = SlowAsyncBatcher( + sleep_time=1, + max_batch_size=10, + max_queue_time=0.2, + concurrency=concurrency, + ) + batcher.mock_batch_processor.reset_mock() + started_at = asyncio.get_event_loop().time() + calls_maker1 = CallsMaker(batcher, 0, 0, 5) + calls_maker2 = CallsMaker(batcher, 0.25, 5, 20) + calls_maker3 = CallsMaker(batcher, 0.4, 20, 30) + await asyncio.gather(calls_maker1.arun(), calls_maker2.arun(), calls_maker3.arun()) + ended_at = asyncio.get_event_loop().time() + + # we add 0.4 seconds to the expected time to account for the sleep time + # for Python 3.12, we need to add 1 second to the expected time because there + # is a slowness in some asyncio functions + if sys.version_info >= (3, 12): + safe_time = 1 + else: + safe_time = 0.4 + assert expected_execution_time < ended_at - started_at + assert ended_at - started_at < expected_execution_time + safe_time + + assert batcher.mock_batch_processor.call_count == 4 + # the first range of size 5 should be processed in a single batch + assert batcher.mock_batch_processor.mock_calls[0].kwargs["batch"] == list(range(5)) + # the second range of size 15 should be processed in 2 batches + assert batcher.mock_batch_processor.mock_calls[1].kwargs["batch"] == list(range(5, 15)) + # the second part of the second range should be processed with 5 items from the third range 
+ assert batcher.mock_batch_processor.mock_calls[2].kwargs["batch"] == list(range(15, 25)) + # the last 5 items should be processed in a single batch + assert batcher.mock_batch_processor.mock_calls[3].kwargs["batch"] == list(range(25, 30)) + # the results should be correct regardless the number of needed batches + assert calls_maker1.result == [i * 2 for i in range(5)] + assert calls_maker2.result == [i * 2 for i in range(5, 20)] + assert calls_maker3.result == [i * 2 for i in range(20, 30)] + batcher.mock_batch_processor.reset_mock() + + +@pytest.mark.asyncio +async def test_stop_batcher(mock_async_batcher): + await asyncio.gather(*[mock_async_batcher.process(item=i) for i in range(10)]) + + assert await mock_async_batcher.is_running() + await mock_async_batcher.stop() + assert not await mock_async_batcher.is_running() + with pytest.raises(RuntimeError): + await mock_async_batcher.process(item=0) + + +@pytest.mark.asyncio +async def test_force_stop_batcher(): + batcher = SlowAsyncBatcher( + sleep_time=1, + max_batch_size=10, + max_queue_time=0.2, + concurrency=1, + ) + batcher.mock_batch_processor.reset_mock() + await asyncio.gather(*[batcher.process(item=i) for i in range(10)]) + assert await batcher.is_running() + await batcher.stop(force=True) + if sys.version_info >= (3, 11): + assert batcher._current_task.cancelled() or batcher._current_task.cancelling() + for task in batcher._running_batches.values(): + assert task.cancelled() or task.cancelling() + batcher.mock_batch_processor.reset_mock()