Commit

🚧 WIP
Dmytro Parfeniuk committed Jul 15, 2024
1 parent 2959efc commit cc8e9a4
Showing 3 changed files with 53 additions and 39 deletions.
45 changes: 27 additions & 18 deletions src/guidellm/scheduler/scheduler.py
@@ -48,15 +48,18 @@ def run(self) -> TextGenerationBenchmark:
     def _run_sync(self) -> TextGenerationBenchmark:
         benchmark = TextGenerationBenchmark(mode=self._load_gen_mode.value, rate=None)
         start_time = time.time()
-        counter = 0
+        requests_counter = 0
 
         for task in self._task_iterator():
             benchmark.request_started()
             res = task.run_sync()
             benchmark.request_completed(res)
-            counter += 1
+            requests_counter += 1
 
-            if (self._max_requests is not None and counter >= self._max_requests) or (
+            if (
+                self._max_requests is not None
+                and requests_counter >= self._max_requests
+            ) or (
                 self._max_duration is not None
                 and time.time() - start_time >= self._max_duration
             ):
@@ -65,16 +68,17 @@ def _run_sync(self) -> TextGenerationBenchmark:
         return benchmark
 
     async def _run_async(self) -> TextGenerationBenchmark:
-        benchmark = TextGenerationBenchmark(
+        benchmark: TextGenerationBenchmark = TextGenerationBenchmark(
             mode=self._load_gen_mode.value, rate=self._load_gen_rate
         )
         if not self._load_gen_rate:
             raise ValueError("Invalid empty value for self._load_gen_rate")
         load_gen = LoadGenerator(self._load_gen_mode, self._load_gen_rate)
 
-        tasks: List[asyncio.tasks.Task] = []
-        start_time = time.time()
-        counter = 0
+        tasks: List[Task] = []
+        asyncio_tasks: List[asyncio.tasks.Task] = []
+        start_time: float = time.time()
+        requests_counter = 0
 
         try:
             for _task, task_start_time in zip(self._task_iterator(), load_gen.times()):
@@ -83,13 +87,15 @@ async def _run_async(self) -> TextGenerationBenchmark:
                 if pending_time > 0:
                     await asyncio.sleep(pending_time)
 
-                tasks.append(
+                asyncio_tasks.append(
                     asyncio.create_task(self._run_task_async(_task, benchmark))
                 )
-                counter += 1
+                tasks.append(_task)
+                requests_counter += 1
 
                 if (
-                    self._max_requests is not None and counter >= self._max_requests
+                    self._max_requests is not None
+                    and requests_counter >= self._max_requests
                 ) or (
                     self._max_duration is not None
                     and time.time() - start_time >= self._max_duration
@@ -102,23 +108,26 @@
                     if pending_duration > 0:
                         await asyncio.sleep(pending_duration)
 
-                    breakpoint()  # TODO: remove
                     raise asyncio.CancelledError()
 
-            await asyncio.gather(*tasks)
         except asyncio.CancelledError:
             # Cancel all pending asyncio.Tasks instances
-            for task in tasks:
+            for task in asyncio_tasks:
                 if not task.done():
                     task.cancel()
 
-            return benchmark
+            # Return not fully filled benchmark on error
+            return benchmark
+
+        else:
+            # Ensure all the tasks are done
+            await asyncio.gather(*asyncio_tasks)
+            return benchmark
 
-    async def _run_task_async(self, task: Task, result_set: TextGenerationBenchmark):
-        result_set.request_started()
+    async def _run_task_async(self, task: Task, benchmark: TextGenerationBenchmark):
+        benchmark.request_started()
         res = await task.run_async()
-        breakpoint()  # TODO: remove
-        result_set.request_completed(res)
+        benchmark.request_completed(res)
 
     def _task_iterator(self) -> Iterable[Task]:
         for request in self._request_generator:
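
For readers skimming the diff, the control flow that `_run_async` relies on is the try/except/else idiom around `asyncio.create_task`: raise `CancelledError` once a limit is hit, cancel leftover tasks in the `except` branch, and only `gather` everything in `else` on a clean exit. Below is a minimal, self-contained sketch of that idiom; the names (`run_with_limit`, `fake_request`) are illustrative and are not part of guidellm.

import asyncio
from typing import List


async def fake_request(i: int) -> int:
    # Stand-in for a single text-generation request.
    await asyncio.sleep(0.05)
    return i


async def run_with_limit(max_requests: int) -> List[int]:
    results: List[int] = []
    scheduled: List[asyncio.Task] = []

    async def run_one(i: int) -> None:
        results.append(await fake_request(i))

    try:
        for i in range(20):
            scheduled.append(asyncio.create_task(run_one(i)))
            await asyncio.sleep(0.01)  # pacing, like the load-generator sleep
            if len(scheduled) >= max_requests:
                raise asyncio.CancelledError()
    except asyncio.CancelledError:
        # Limit hit: cancel whatever is still pending, keep partial results.
        for task in scheduled:
            if not task.done():
                task.cancel()
        return results
    else:
        # Clean exit: wait for every scheduled task before returning.
        await asyncio.gather(*scheduled)
        return results


if __name__ == "__main__":
    print(asyncio.run(run_with_limit(max_requests=5)))
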
45 changes: 25 additions & 20 deletions src/guidellm/scheduler/task.py
@@ -1,6 +1,6 @@
 import asyncio
 import functools
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Coroutine, Dict, Optional
 
 from loguru import logger
 
@@ -29,13 +29,23 @@ def __init__(
         self._func: Callable[..., Any] = func
         self._params: Dict[str, Any] = params or {}
         self._err_container: Optional[Callable] = err_container
-        self._cancel_event: asyncio.Event = asyncio.Event()
 
         logger.info(
             f"Task created with function: {self._func.__name__} and "
             f"params: {self._params}"
         )
 
+    async def _check_future_is_done(self, future: asyncio.Future):
+        """
+        Check if the coroutine is done and then release.
+        """
+
+        while True:
+            if future.done():
+                return
+            else:
+                await asyncio.sleep(0)
+
     async def run_async(self) -> Any:
         """
         Run the task asynchronously.
@@ -44,31 +54,33 @@ async def run_async(self) -> Any:
         :rtype: Any
         """
         logger.info(f"Running task asynchronously with function: {self._func.__name__}")
+
         try:
             loop = asyncio.get_running_loop()
 
-            result = await asyncio.gather(
-                loop.run_in_executor(
-                    None,
-                    functools.partial(
-                        self._func,
-                        **self._params,
-                    ),
+            executable: asyncio.futures.Future = loop.run_in_executor(
+                None,
+                functools.partial(
+                    self._func,
+                    **self._params,
                 ),
-                self._cancel_event.wait(),
+            )
+            result = await asyncio.gather(
+                executable,
+                self._check_future_is_done(executable),
                 return_exceptions=True,
             )
 
             if isinstance(result[0], Exception):
                 raise result[0]
 
-            if self._cancel_event.is_set() is True:
-                raise asyncio.CancelledError("Task was cancelled")
-
             logger.info(f"Task completed with result: {result[0]}")
+
             return result[0]
+
         except asyncio.CancelledError as cancel_err:
             logger.warning("Task was cancelled")
+
             return (
                 cancel_err
                 if not self._err_container
@@ -101,10 +113,3 @@ def run_sync(self) -> Any:
                 if not self._err_container
                 else self._err_container(**self._params, error=err)
             )
-
-    def cancel(self) -> None:
-        """
-        Cancel the task.
-        """
-        logger.info("Cancelling task")
-        self._cancel_event.set()
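
The core technique in `run_async` above is offloading a blocking callable onto the default executor and awaiting its future together with a small polling coroutine. Here is a self-contained sketch of that pattern under assumed names (`run_blocking`, `slow_square` are illustrative, not guidellm's API):

import asyncio
import functools
import time
from typing import Any, Callable


async def check_future_is_done(future: asyncio.Future) -> None:
    # Yield control to the event loop until the executor future finishes.
    while not future.done():
        await asyncio.sleep(0)


async def run_blocking(func: Callable[..., Any], **params: Any) -> Any:
    loop = asyncio.get_running_loop()
    future = loop.run_in_executor(None, functools.partial(func, **params))
    result = await asyncio.gather(
        future,
        check_future_is_done(future),
        return_exceptions=True,
    )
    if isinstance(result[0], Exception):
        raise result[0]
    return result[0]


def slow_square(x: int) -> int:
    time.sleep(0.2)  # stands in for a blocking HTTP call to the model server
    return x * x


if __name__ == "__main__":
    print(asyncio.run(run_blocking(slow_square, x=7)))  # prints 49

Awaiting the executor future on its own would also complete it; the polling companion simply mirrors the shape of the `cancel_event.wait()` awaitable that this commit removes from the `gather` call.
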
2 changes: 1 addition & 1 deletion tests/unit/executor/test_report_generation.py
@@ -54,7 +54,7 @@ def test_executor_openai_single_report_generation_constant_mode(
         profile_mode=profile_generation_mode,
         profile_args=profile_generator_kwargs,
         max_requests=1,
-        max_duration=120,
+        max_duration=None,
     )
 
     report: TextGenerationBenchmarkReport = executor.run()
