Skip to content

Commit

Permalink
Properly close async generators
Browse files Browse the repository at this point in the history
Change the type of functions returning async generators to
`AsyncGenerator` and properly close async generators in the hope
that this will address various log errors about generators already
running.
  • Loading branch information
rra committed Feb 6, 2025
1 parent 480af79 commit f1673ff
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 40 deletions.
54 changes: 51 additions & 3 deletions src/mobu/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,61 @@
import asyncio
import contextlib
from asyncio import Task
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import TypeVar
from types import TracebackType
from typing import Literal, TypeVar

T = TypeVar("T")

__all__ = ["schedule_periodic", "wait_first"]
__all__ = [
"aclosing_iter",
"schedule_periodic",
"wait_first",
]


class aclosing_iter[T: AsyncIterator](AbstractAsyncContextManager): # noqa: N801
"""Automatically close async iterators that are generators.
Python supports two ways of writing an async iterator: a true async
iterator, and an async generator. Generators support additional async
context, such as yielding from inside an async context manager, and
therefore require cleanup by calling their `aclose` method once the
generator is no longer needed. This step is done automatically by the
async loop implementation when the generator is garbage-collected, but
this may happen at an arbitrary point and produces pytest warnings
saying that the `aclose` method on the generator was never called.
This class provides a variant of `contextlib.aclosing` that can be
used to close generators masquerading as iterators. Many Python libraries
implement `__aiter__` by returning a generator rather than an iterator,
which is equivalent except for this cleanup behavior. Async iterators do
not require this explicit cleanup step because they don't support async
context managers inside the iteration. Since the library is free to change
from a generator to an iterator at any time, and async iterators don't
require this cleanup and don't have `aclose` methods, the `aclose` method
should be called only if it exists.
"""

def __init__(self, thing: T) -> None:
self.thing = thing

async def __aenter__(self) -> T:
return self.thing

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> Literal[False]:
# Only call aclose if the method is defined, which we take to mean that
# this iterator is actually a generator.
if getattr(self.thing, "aclose", None):
await self.thing.aclose() # type: ignore[attr-defined]
return False


def schedule_periodic(
Expand Down
4 changes: 2 additions & 2 deletions src/mobu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import json
from collections.abc import AsyncIterator
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from importlib.metadata import metadata, version
Expand Down Expand Up @@ -41,7 +41,7 @@


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
"""Set up and tear down the the base application."""
config = config_dependency.config
if not config.environment_url:
Expand Down
27 changes: 14 additions & 13 deletions src/mobu/services/business/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
from abc import ABCMeta, abstractmethod
from asyncio import Queue, QueueEmpty
from collections.abc import AsyncIterable, AsyncIterator
from collections.abc import AsyncGenerator, AsyncIterable
from datetime import timedelta
from enum import Enum
from typing import Generic, TypedDict, TypeVar
Expand All @@ -14,7 +14,7 @@
from safir.datetime import current_datetime
from structlog.stdlib import BoundLogger

from ...asyncio import wait_first
from ...asyncio import aclosing_iter, wait_first
from ...events import Events
from ...models.business.base import BusinessData, BusinessOptions
from ...models.user import AuthenticatedUser
Expand Down Expand Up @@ -267,7 +267,7 @@ async def pause(self, interval: timedelta) -> bool:

async def iter_with_timeout(
self, iterable: AsyncIterable[U], timeout: timedelta
) -> AsyncIterator[U]:
) -> AsyncGenerator[U]:
"""Run an iterator with a timeout.
Returns the next element of the iterator on success and ends the
Expand Down Expand Up @@ -316,16 +316,17 @@ async def iter_next() -> U:
return await iterator.__anext__()

start = current_datetime(microseconds=True)
while True:
now = current_datetime(microseconds=True)
remaining = timeout - (now - start)
if remaining < timedelta(seconds=0):
break
pause = self._pause_no_return(timeout)
result = await wait_first(iter_next(), pause)
if result is None or self.stopping:
break
yield result
async with aclosing_iter(iterator):
while True:
now = current_datetime(microseconds=True)
remaining = timeout - (now - start)
if remaining < timedelta(seconds=0):
break
pause = self._pause_no_return(timeout)
result = await wait_first(iter_next(), pause)
if result is None or self.stopping:
break
yield result

def dump(self) -> BusinessData:
return BusinessData(
Expand Down
4 changes: 2 additions & 2 deletions src/mobu/services/business/notebookrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
import random
import shutil
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncGenerator, Iterator
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
Expand Down Expand Up @@ -308,7 +308,7 @@ def read_notebook(self, notebook: Path) -> list[dict[str, Any]]:
@asynccontextmanager
async def open_session(
self, notebook_name: str | None = None
) -> AsyncIterator[JupyterLabSession]:
) -> AsyncGenerator[JupyterLabSession]:
"""Override to add the notebook name."""
if not notebook_name:
notebook_name = self._notebook.name if self._notebook else None
Expand Down
32 changes: 17 additions & 15 deletions src/mobu/services/business/nublado.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import re
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from random import SystemRandom
Expand Down Expand Up @@ -268,18 +268,20 @@ async def _spawn_lab(self, span: Span) -> bool:
# Watch the progress API until the lab has spawned.
log_messages = []
progress = self._client.watch_spawn_progress()
try:
async for message in self.iter_with_timeout(progress, timeout):
log_messages.append(ProgressLogMessage(message.message))
if message.ready:
return True
except:
log = "\n".join([str(m) for m in log_messages])
sentry_sdk.get_current_scope().add_attachment(
filename="spawn_log.txt",
bytes=self.remove_ansi_escapes(log).encode(),
)
raise
progress_generator = self.iter_with_timeout(progress, timeout)
async with aclosing(progress_generator):
try:
async for message in progress_generator:
log_messages.append(ProgressLogMessage(message.message))
if message.ready:
return True
except:
log = "\n".join([str(m) for m in log_messages])
sentry_sdk.get_current_scope().add_attachment(
filename="spawn_log.txt",
bytes=self.remove_ansi_escapes(log).encode(),
)
raise

# We only fall through if the spawn failed, timed out, or if we're
# stopping the business.
Expand All @@ -305,7 +307,7 @@ async def lab_login(self) -> None:
@asynccontextmanager
async def open_session(
self, notebook: str | None = None
) -> AsyncIterator[JupyterLabSession]:
) -> AsyncGenerator[JupyterLabSession]:
self.logger.info("Creating lab session")
opts = {"max_websocket_size": self.options.max_websocket_message_size}
create_session_cm = capturing_start_span(op="create_session")
Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import AsyncIterator, Generator, Iterator
from collections.abc import AsyncGenerator, Generator, Iterator
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -134,7 +134,7 @@ def _enable_github_refresh_app(


@pytest_asyncio.fixture
async def app(jupyter: MockJupyter) -> AsyncIterator[FastAPI]:
async def app(jupyter: MockJupyter) -> AsyncGenerator[FastAPI]:
"""Return a configured test application.
Wraps the application in a lifespan manager so that startup and shutdown
Expand Down Expand Up @@ -184,7 +184,7 @@ async def client(
app: FastAPI,
test_user: User,
jupyter: MockJupyter,
) -> AsyncIterator[AsyncClient]:
) -> AsyncGenerator[AsyncClient]:
"""Return an ``httpx.AsyncClient`` configured to talk to the test app."""
async with AsyncClient(
transport=ASGITransport(app=app),
Expand All @@ -198,7 +198,7 @@ async def client(


@pytest_asyncio.fixture
async def anon_client(app: FastAPI) -> AsyncIterator[AsyncClient]:
async def anon_client(app: FastAPI) -> AsyncGenerator[AsyncClient]:
"""Return an anonymous ``httpx.AsyncClient`` configured to talk to the test
app.
"""
Expand Down Expand Up @@ -228,7 +228,7 @@ async def mock_connect(
extra_headers: dict[str, str],
max_size: int | None,
open_timeout: int,
) -> AsyncIterator[MockJupyterWebSocket]:
) -> AsyncGenerator[MockJupyterWebSocket]:
yield mock_jupyter_websocket(url, extra_headers, jupyter_mock)

with patch("rubin.nublado.client.nubladoclient.websocket_connect") as mock:
Expand Down

0 comments on commit f1673ff

Please sign in to comment.