Feat: Support concurrent batches processing (#11)
* Feat: Support concurrent batches processing

* Update the examples

* Try to fix Python 3.12 tests by updating the safe time

* Asyncio task.cancelling is not available in Python 3.10
hussein-awala authored Feb 24, 2024
1 parent 6ee877d commit 3a6f576
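
A minimal usage sketch of the new concurrency option, based on the AsyncBatcher API in the diff below; the DoubleBatcher subclass and its numbers are illustrative, not part of the commit:

import asyncio

from async_batcher.batcher import AsyncBatcher


class DoubleBatcher(AsyncBatcher):
    # Illustrative batcher: doubles each queued integer in one batch call.
    async def process_batch(self, batch: list[int]) -> list[int]:
        return [item * 2 for item in batch]


async def main():
    # Up to 3 batches may be processed concurrently; -1 means unbounded.
    batcher = DoubleBatcher(max_batch_size=10, max_queue_time=0.2, concurrency=3)
    results = await asyncio.gather(*[batcher.process(item=i) for i in range(25)])
    print(results)  # [0, 2, 4, ..., 48]
    await batcher.stop()  # stop() is a coroutine as of this commit


asyncio.run(main())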
Showing 6 changed files with 234 additions and 49 deletions.
150 changes: 106 additions & 44 deletions async_batcher/batcher.py
@@ -31,16 +31,21 @@ def __init__(
*,
max_batch_size: int = -1,
max_queue_time: float = 0.01,
concurrency: int = 1,
):
super().__init__()
if max_batch_size is None or 0 <= max_batch_size <= 1:
raise ValueError("Valid max_batch_size value is greater than 1 or -1 for infinite")
if concurrency is None or concurrency == 0:
raise ValueError("Valid concurrency value is greater than 0 or -1 for infinite")
self.max_batch_size = max_batch_size
self.max_queue_time = max_queue_time
self.concurrency = concurrency
self._queue = asyncio.Queue()
self._current_task: asyncio.Task | None = None
self._should_stop = False
self._force_stop = False
self._running_batches: dict[int, asyncio.Task] = {}
self._concurrency_semaphore = asyncio.Semaphore(concurrency) if concurrency > 0 else None
self._stop = False
self._is_running = False

@abc.abstractmethod
@@ -59,23 +64,23 @@ async def process(self, item: T) -> S:
Returns:
S: The result of processing the item.
"""
if self._should_stop:
if self._stop:
raise RuntimeError("Batcher is stopped")
if self._current_task is None:
self._current_task = asyncio.get_running_loop().create_task(self.batch_run())
self._current_task = asyncio.get_running_loop().create_task(self.run())
logging.debug(item)
future = asyncio.get_running_loop().create_future()
await self._queue.put(self.QueueItem(item, future))
await future
return future.result()

async def _fill_batch_from_queue(self):
async def _fill_batch_from_queue(self, started_at: float | None) -> list[QueueItem]:
try:
batch = [await asyncio.wait_for(self._queue.get(), timeout=1.0)]
except asyncio.TimeoutError:
return []

started_at = asyncio.get_running_loop().time()
if started_at is None:
started_at = asyncio.get_event_loop().time()
while 1:
try:
max_wait = self.max_queue_time - (asyncio.get_running_loop().time() - started_at)
@@ -90,51 +95,108 @@ async def _fill_batch_from_queue(self):
break
return batch

async def batch_run(self):
async def _batch_run(self, task_id: int, batch: list[QueueItem]):
started_at = asyncio.get_event_loop().time()
try:
batch_items = [q_item.item for q_item in batch]
if asyncio.iscoroutinefunction(self.process_batch):
results = await self.process_batch(batch=batch_items)
else:
results = await asyncio.get_event_loop().run_in_executor(
None, self.process_batch, batch_items
)
if results is None:
results = [None] * len(batch)
if len(results) != len(batch):
raise ValueError(f"Expected to get {len(batch)} results, but got {len(results)}.")
except Exception as e:
self.logger.error("Error processing batch", exc_info=True)
for q_item in batch:
q_item.future.set_exception(e)
else:
for q_item, result in zip(batch, results, strict=True):
q_item.future.set_result(result)
elapsed_time = asyncio.get_event_loop().time() - started_at
self.logger.debug(f"Processed batch of {len(batch)} elements" f" in {elapsed_time} seconds.")
self._running_batches.pop(task_id)

async def _concurrent_batch_run(self, task_id: int, batch: list[QueueItem]):
async with self._concurrency_semaphore:
await self._batch_run(task_id, batch)

async def run(self):
"""Run the batcher asynchronously."""
self._is_running = True
while not self._should_stop or (not self._force_stop and self._queue.qsize() > 0):
batch = await self._fill_batch_from_queue()
if not batch:
continue

started_at = asyncio.get_event_loop().time()
try:
batch_items = [q_item.item for q_item in batch]
if asyncio.iscoroutinefunction(self.process_batch):
results = await self.process_batch(batch=batch_items)
else:
results = await asyncio.get_event_loop().run_in_executor(
None, self.process_batch, batch_items
task_id = 0
if self.concurrency > 0:
started_at = None
while not self._should_stop():
if started_at is None:
started_at = asyncio.get_event_loop().time()
semaphore_acquired = False
try:
# to check whether the batcher should stop, we time out after 1 second
# if the semaphore is not acquired
await asyncio.wait_for(self._concurrency_semaphore.acquire(), timeout=1.0)
semaphore_acquired = True
# if the queue is empty, let the batch filler create a new start time
batch = await self._fill_batch_from_queue(
started_at=started_at if self._queue.qsize() > 0 else None
)
if results is None:
results = [None] * len(batch)
if len(results) != len(batch):
raise ValueError(f"Expected to get {len(batch)} results, but got {len(results)}.")
except Exception as e:
self.logger.error("Error processing batch", exc_info=True)
for q_item in batch:
q_item.future.set_exception(e)
else:
for q_item, result in zip(batch, results, strict=True):
q_item.future.set_result(result)
elapsed_time = asyncio.get_event_loop().time() - started_at
self.logger.debug(f"Processed batch of {len(batch)} elements" f" in {elapsed_time} seconds.")
if batch:
# create a new task to process the batch
self._running_batches[task_id] = asyncio.get_event_loop().create_task(
self._concurrent_batch_run(task_id, batch)
)
task_id += 1
started_at = None
except asyncio.TimeoutError:
pass
finally:
if semaphore_acquired:
self._concurrency_semaphore.release()
else:
while not self._should_stop():
batch = await self._fill_batch_from_queue(started_at=None)
if batch:
self._running_batches[task_id] = asyncio.get_event_loop().create_task(
self._batch_run(task_id, batch)
)
task_id += 1
self._is_running = False

def stop(self, force: bool = False):
def _should_stop(self):
return self._stop and self._queue.qsize() == 0

async def is_running(self):
"""Check if the batcher is running.
Returns:
bool: True if the batcher is running, False otherwise.
"""
return self._is_running

async def stop(self, force: bool = False, timeout: float | None = None):
"""Stop the batcher asyncio task.
Args:
force (bool, optional): Whether to force stop the batcher without waiting for processing
the remaining buffer items. Defaults to False.
the remaining buffer items. If True, it will cancel the current task and all running tasks.
Defaults to False.
timeout (float, optional): The time to wait for the batcher to stop. If None, it will wait
indefinitely. Defaults to None.
"""
if force:
self._force_stop = True
self._should_stop = True
if (
self._current_task
and not self._current_task.done()
and not self._current_task.get_loop().is_closed()
):
self._current_task.get_loop().run_until_complete(self._current_task)
if self._current_task and not self._current_task.done():
self._current_task.cancel()
for task in self._running_batches.values():
if not task.done():
task.cancel()
else:
self._stop = True
if (
self._current_task
and not self._current_task.done()
and not self._current_task.get_loop().is_closed()
):
await asyncio.wait_for(self._current_task, timeout=timeout)
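
The run() loop above acquires the semaphore under asyncio.wait_for only to probe for free capacity, releases it in the finally block, and lets the spawned _concurrent_batch_run task re-acquire it for the real work; the 1-second timeout exists so the loop can periodically re-check its stop condition. A standalone sketch of that probe-and-dispatch pattern, with illustrative names (worker, dispatcher) not taken from the library:

import asyncio


async def worker(semaphore: asyncio.Semaphore, job_id: int) -> None:
    async with semaphore:  # bounds how many jobs run at once
        await asyncio.sleep(1.0)
        print(f"job {job_id} done")


async def dispatcher(jobs: asyncio.Queue, stop: asyncio.Event, limit: int) -> None:
    semaphore = asyncio.Semaphore(limit)
    running: set[asyncio.Task] = set()
    while not stop.is_set():
        try:
            # Probe for capacity with a timeout so the loop can
            # periodically re-check the stop flag instead of blocking forever.
            await asyncio.wait_for(semaphore.acquire(), timeout=1.0)
        except asyncio.TimeoutError:
            continue
        semaphore.release()  # the worker re-acquires it in `async with`
        try:
            job_id = jobs.get_nowait()
        except asyncio.QueueEmpty:
            continue
        task = asyncio.create_task(worker(semaphore, job_id))
        running.add(task)
        task.add_done_callback(running.discard)
    await asyncio.gather(*running)  # drain in-flight jobs on shutdown

Releasing the probe before dispatching means a burst can momentarily create more worker tasks than the limit, but the semaphore inside worker still caps how many run at once, mirroring the split between run() and _concurrent_batch_run.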
7 changes: 5 additions & 2 deletions examples/dynamodb/main.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from typing import Any

import aioboto3
@@ -41,8 +42,10 @@ class GetRequestModel(BaseModel):

@app.on_event("shutdown")
def shutdown_event():
get_batcher.stop()
write_batcher.stop()
async def stop_batchers():
await asyncio.gather(get_batcher.stop(), write_batcher.stop())

asyncio.run(stop_batchers())
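
Starlette's @app.on_event("shutdown") also accepts async handlers, so an equivalent variant (an assumption, not part of this commit) can await the batchers directly on the application's event loop instead of spinning up a new loop with asyncio.run:

@app.on_event("shutdown")
async def shutdown_event():
    # runs on the app's own event loop, where the batcher tasks live
    await asyncio.gather(get_batcher.stop(), write_batcher.stop())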


@app.post("/put")
3 changes: 2 additions & 1 deletion examples/keras/main.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import gc
from typing import TYPE_CHECKING

@@ -34,7 +35,7 @@ async def startup_event():

@app.on_event("shutdown")
def shutdown_event():
batcher.stop()
asyncio.run(batcher.stop())


@app.post("/predict")
2 changes: 1 addition & 1 deletion examples/keras_with_grpc/server.py
@@ -49,7 +49,7 @@ async def serve() -> None:

async def server_graceful_shutdown():
logging.info("Starting graceful shutdown...")
predictor.batcher.stop()
await predictor.batcher.stop()
await server.stop(30)

_cleanup_coroutines.append(server_graceful_shutdown())
21 changes: 21 additions & 0 deletions tests/conftest.py
@@ -1,11 +1,22 @@
from __future__ import annotations

import asyncio
from unittest import mock

import pytest
from async_batcher.batcher import AsyncBatcher


@pytest.fixture(scope="session")
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()


def pytest_runtest_setup(item):
def _has_marker(item, marker_name: str) -> bool:
return len(list(item.iter_markers(name=marker_name))) > 0
@@ -23,6 +34,16 @@ async def process_batch(self, *args, **kwargs):
return await self.mock_batch_processor(*args, **kwargs)


class SlowAsyncBatcher(MockAsyncBatcher):
def __init__(self, sleep_time: float = 1, **kwargs):
super().__init__(**kwargs)
self.sleep_time = sleep_time

async def process_batch(self, *args, **kwargs):
await asyncio.sleep(self.sleep_time)
return await super().process_batch(*args, **kwargs)


@pytest.fixture(scope="function")
def mock_async_batcher():
batcher = MockAsyncBatcher(
100 changes: 99 additions & 1 deletion tests/test_batcher.py
@@ -1,10 +1,11 @@
from __future__ import annotations

import asyncio
import sys

import pytest

from tests.conftest import MockAsyncBatcher
from tests.conftest import MockAsyncBatcher, SlowAsyncBatcher


class CallsMaker:
@@ -76,3 +77,100 @@ async def test_process_batch_with_short_buffering_time():
assert calls_maker1.result == [i * 2 for i in range(5)]
assert calls_maker2.result == [i * 2 for i in range(5, 20)]
assert calls_maker3.result == [i * 2 for i in range(20, 30)]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"concurrency, expected_execution_time",
[
# batch 1: from 0 to 1
# batch 2: from 1 to 2
# batch 3: from 2 to 3
# batch 4: from 3.2 to 4.2 (< max_batch_size, extra 0.4)
(1, 4.2),
# batch 1: from 0 to 1
# batch 2: from 0.25 to 1.25
# batch 3: from 1 to 2
# batch 4: from 1.45 to 2.45 (< max_batch_size, extra 0.4)
(2, 2.45),
# batch 1: from 0 to 1
# batch 2: from 0.25 to 1.25
# batch 3: from 0.4 to 1.4
# batch 4: from 1.2 to 2.2 (< max_batch_size, extra 0.4)
(3, 2.2),
# batch 1: from 0 to 1
# batch 2: from 0.25 to 1.25
# batch 3: from 0.4 to 1.4
# batch 4: from 0.6 to 1.6 (< max_batch_size, extra 0.4)
(-1, 1.6),
],
)
async def test_concurrent_process_batch(concurrency, expected_execution_time):
batcher = SlowAsyncBatcher(
sleep_time=1,
max_batch_size=10,
max_queue_time=0.2,
concurrency=concurrency,
)
batcher.mock_batch_processor.reset_mock()
started_at = asyncio.get_event_loop().time()
calls_maker1 = CallsMaker(batcher, 0, 0, 5)
calls_maker2 = CallsMaker(batcher, 0.25, 5, 20)
calls_maker3 = CallsMaker(batcher, 0.4, 20, 30)
await asyncio.gather(calls_maker1.arun(), calls_maker2.arun(), calls_maker3.arun())
ended_at = asyncio.get_event_loop().time()

# we add 0.4 seconds of slack to the expected time as a safety margin
# for Python 3.12, we need to add 1 second because some asyncio
# functions are slower
if sys.version_info >= (3, 12):
safe_time = 1
else:
safe_time = 0.4
assert expected_execution_time < ended_at - started_at
assert ended_at - started_at < expected_execution_time + safe_time

assert batcher.mock_batch_processor.call_count == 4
# the first range of size 5 should be processed in a single batch
assert batcher.mock_batch_processor.mock_calls[0].kwargs["batch"] == list(range(5))
# the second range of size 15 should be processed in 2 batches
assert batcher.mock_batch_processor.mock_calls[1].kwargs["batch"] == list(range(5, 15))
# the second part of the second range should be processed with 5 items from the third range
assert batcher.mock_batch_processor.mock_calls[2].kwargs["batch"] == list(range(15, 25))
# the last 5 items should be processed in a single batch
assert batcher.mock_batch_processor.mock_calls[3].kwargs["batch"] == list(range(25, 30))
# the results should be correct regardless of the number of batches needed
assert calls_maker1.result == [i * 2 for i in range(5)]
assert calls_maker2.result == [i * 2 for i in range(5, 20)]
assert calls_maker3.result == [i * 2 for i in range(20, 30)]
batcher.mock_batch_processor.reset_mock()


@pytest.mark.asyncio
async def test_stop_batcher(mock_async_batcher):
await asyncio.gather(*[mock_async_batcher.process(item=i) for i in range(10)])

assert await mock_async_batcher.is_running()
await mock_async_batcher.stop()
assert not await mock_async_batcher.is_running()
with pytest.raises(RuntimeError):
await mock_async_batcher.process(item=0)


@pytest.mark.asyncio
async def test_force_stop_batcher():
batcher = SlowAsyncBatcher(
sleep_time=1,
max_batch_size=10,
max_queue_time=0.2,
concurrency=1,
)
batcher.mock_batch_processor.reset_mock()
await asyncio.gather(*[batcher.process(item=i) for i in range(10)])
assert await batcher.is_running()
await batcher.stop(force=True)
if sys.version_info >= (3, 11):
assert batcher._current_task.cancelled() or batcher._current_task.cancelling()
for task in batcher._running_batches.values():
assert task.cancelled() or task.cancelling()
batcher.mock_batch_processor.reset_mock()
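
The last commit-message bullet notes that Task.cancelling() only exists on Python 3.11+, which is why the assertions above are version-guarded. A small helper capturing the same check, as a sketch (is_being_cancelled is a hypothetical name, not part of the library):

import asyncio
import sys


def is_being_cancelled(task: asyncio.Task) -> bool:
    # Task.cancelling(), added in Python 3.11, counts pending cancel requests;
    # on older versions only a completed cancellation is observable.
    if sys.version_info >= (3, 11):
        return task.cancelled() or task.cancelling() > 0
    return task.cancelled()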
