diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index bf08255..98146fe 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -48,15 +48,18 @@ def run(self) -> TextGenerationBenchmark: def _run_sync(self) -> TextGenerationBenchmark: benchmark = TextGenerationBenchmark(mode=self._load_gen_mode.value, rate=None) start_time = time.time() - counter = 0 + requests_counter = 0 for task in self._task_iterator(): benchmark.request_started() res = task.run_sync() benchmark.request_completed(res) - counter += 1 + requests_counter += 1 - if (self._max_requests is not None and counter >= self._max_requests) or ( + if ( + self._max_requests is not None + and requests_counter >= self._max_requests + ) or ( self._max_duration is not None and time.time() - start_time >= self._max_duration ): @@ -65,16 +68,17 @@ def _run_sync(self) -> TextGenerationBenchmark: return benchmark async def _run_async(self) -> TextGenerationBenchmark: - benchmark = TextGenerationBenchmark( + benchmark: TextGenerationBenchmark = TextGenerationBenchmark( mode=self._load_gen_mode.value, rate=self._load_gen_rate ) if not self._load_gen_rate: raise ValueError("Invalid empty value for self._load_gen_rate") load_gen = LoadGenerator(self._load_gen_mode, self._load_gen_rate) - tasks: List[asyncio.tasks.Task] = [] - start_time = time.time() - counter = 0 + tasks: List[Task] = [] + asyncio_tasks: List[asyncio.tasks.Task] = [] + start_time: float = time.time() + requests_counter = 0 try: for _task, task_start_time in zip(self._task_iterator(), load_gen.times()): @@ -83,13 +87,15 @@ async def _run_async(self) -> TextGenerationBenchmark: if pending_time > 0: await asyncio.sleep(pending_time) - tasks.append( + asyncio_tasks.append( asyncio.create_task(self._run_task_async(_task, benchmark)) ) - counter += 1 + tasks.append(_task) + requests_counter += 1 if ( - self._max_requests is not None and counter >= self._max_requests + self._max_requests is not None + and requests_counter >= self._max_requests ) or ( self._max_duration is not None and time.time() - start_time >= self._max_duration @@ -102,23 +108,26 @@ async def _run_async(self) -> TextGenerationBenchmark: if pending_duration > 0: await asyncio.sleep(pending_duration) - breakpoint() # TODO: remove raise asyncio.CancelledError() - await asyncio.gather(*tasks) except asyncio.CancelledError: # Cancel all pending asyncio.Tasks instances - for task in tasks: + for task in asyncio_tasks: if not task.done(): task.cancel() - return benchmark + # Return not fully filled benchmark on error + return benchmark + + else: + # Ensure all the tasks are done + await asyncio.gather(*asyncio_tasks) + return benchmark - async def _run_task_async(self, task: Task, result_set: TextGenerationBenchmark): - result_set.request_started() + async def _run_task_async(self, task: Task, benchmark: TextGenerationBenchmark): + benchmark.request_started() res = await task.run_async() - breakpoint() # TODO: remove - result_set.request_completed(res) + benchmark.request_completed(res) def _task_iterator(self) -> Iterable[Task]: for request in self._request_generator: diff --git a/src/guidellm/scheduler/task.py b/src/guidellm/scheduler/task.py index 88d1e6d..e27454f 100644 --- a/src/guidellm/scheduler/task.py +++ b/src/guidellm/scheduler/task.py @@ -1,6 +1,6 @@ import asyncio import functools -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Coroutine, Dict, Optional from loguru import logger @@ -29,13 +29,23 @@ def __init__( self._func: Callable[..., Any] = func self._params: Dict[str, Any] = params or {} self._err_container: Optional[Callable] = err_container - self._cancel_event: asyncio.Event = asyncio.Event() logger.info( f"Task created with function: {self._func.__name__} and " f"params: {self._params}" ) + async def _check_future_is_done(self, future: asyncio.Future): + """ + Check if the coroutine is done and then release. + """ + + while True: + if future.done(): + return + else: + await asyncio.sleep(0) + async def run_async(self) -> Any: """ Run the task asynchronously. @@ -44,31 +54,33 @@ async def run_async(self) -> Any: :rtype: Any """ logger.info(f"Running task asynchronously with function: {self._func.__name__}") + try: loop = asyncio.get_running_loop() - result = await asyncio.gather( - loop.run_in_executor( - None, - functools.partial( - self._func, - **self._params, - ), + executable: asyncio.futures.Future = loop.run_in_executor( + None, + functools.partial( + self._func, + **self._params, ), - self._cancel_event.wait(), + ) + result = await asyncio.gather( + executable, + self._check_future_is_done(executable), return_exceptions=True, ) + if isinstance(result[0], Exception): raise result[0] - if self._cancel_event.is_set() is True: - raise asyncio.CancelledError("Task was cancelled") - logger.info(f"Task completed with result: {result[0]}") return result[0] + except asyncio.CancelledError as cancel_err: logger.warning("Task was cancelled") + return ( cancel_err if not self._err_container @@ -101,10 +113,3 @@ def run_sync(self) -> Any: if not self._err_container else self._err_container(**self._params, error=err) ) - - def cancel(self) -> None: - """ - Cancel the task. - """ - logger.info("Cancelling task") - self._cancel_event.set() diff --git a/tests/unit/executor/test_report_generation.py b/tests/unit/executor/test_report_generation.py index 7e0f1c7..4d29104 100644 --- a/tests/unit/executor/test_report_generation.py +++ b/tests/unit/executor/test_report_generation.py @@ -54,7 +54,7 @@ def test_executor_openai_single_report_generation_constant_mode( profile_mode=profile_generation_mode, profile_args=profile_generator_kwargs, max_requests=1, - max_duration=120, + max_duration=None, ) report: TextGenerationBenchmarkReport = executor.run()