From 6729a145b80f011a8cb0824302c1279855c7c9ca Mon Sep 17 00:00:00 2001
From: Dan Fuchs
Date: Thu, 13 Feb 2025 13:10:46 -0600
Subject: [PATCH] DM-48941: Only check out notebook repos once

Add a reference-counting process-global notebook repo cache. Use this
in the NotebookRunner business to only clone notebook repos once per
process, not once per monkey.
---
 .../20250217_173118_danfuchs_DM_48941.md      |   3 +
 src/mobu/dependencies/context.py              |   5 +
 src/mobu/dependencies/github.py               |   1 +
 src/mobu/factory.py                           |  13 +-
 src/mobu/handlers/github_refresh_app.py       |   5 +-
 src/mobu/models/repo.py                       |  13 ++
 src/mobu/services/business/notebookrunner.py  |  54 +++---
 src/mobu/services/flock.py                    |   6 +
 src/mobu/services/github_ci/ci_manager.py     |   4 +
 .../services/github_ci/ci_notebook_job.py     |   4 +
 src/mobu/services/manager.py                  |   6 +
 src/mobu/services/monkey.py                   |   6 +
 src/mobu/services/repo.py                     | 169 ++++++++++++++++++
 src/mobu/services/solitary.py                 |   6 +
 tests/business/notebookrunner_test.py         |  34 ++--
 tests/services/ci_manager_test.py             |   3 +
 tests/services/repo_cache_test.py             | 128 +++++++++++++
 17 files changed, 412 insertions(+), 48 deletions(-)
 create mode 100644 changelog.d/20250217_173118_danfuchs_DM_48941.md
 create mode 100644 src/mobu/services/repo.py
 create mode 100644 tests/services/repo_cache_test.py

diff --git a/changelog.d/20250217_173118_danfuchs_DM_48941.md b/changelog.d/20250217_173118_danfuchs_DM_48941.md
new file mode 100644
index 00000000..37271448
--- /dev/null
+++ b/changelog.d/20250217_173118_danfuchs_DM_48941.md
@@ -0,0 +1,3 @@
+### Bug fixes
+
+- Notebook repos are now cloned only once per process (and once per refresh request), instead of once per monkey. This should make NotebookRunner flocks start faster, especially in load-testing use cases.
diff --git a/src/mobu/dependencies/context.py b/src/mobu/dependencies/context.py
index 06bd0ebc..63c403df 100644
--- a/src/mobu/dependencies/context.py
+++ b/src/mobu/dependencies/context.py
@@ -19,6 +19,7 @@
 from ..events import Events
 from ..factory import Factory, ProcessContext
 from ..services.manager import FlockManager
+from ..services.repo import RepoManager

 __all__ = [
     "ContextDependency",
@@ -41,6 +42,9 @@ class RequestContext:
     manager: FlockManager
     """Global singleton flock manager."""

+    repo_manager: RepoManager
+    """Global singleton git repo manager."""
+
     factory: Factory
     """Component factory."""

@@ -80,6 +84,7 @@ async def __call__(
             request=request,
             logger=logger,
             manager=self._process_context.manager,
+            repo_manager=self._process_context.repo_manager,
             factory=Factory(self._process_context, logger),
         )

diff --git a/src/mobu/dependencies/github.py b/src/mobu/dependencies/github.py
index 40ef2161..8e1f38b6 100644
--- a/src/mobu/dependencies/github.py
+++ b/src/mobu/dependencies/github.py
@@ -41,6 +41,7 @@ def initialize(
             scopes=scopes,
             http_client=base_context.process_context.http_client,
             events=base_context.process_context.events,
+            repo_manager=base_context.process_context.repo_manager,
             gafaelfawr_storage=base_context.process_context.gafaelfawr,
             logger=base_context.process_context.logger,
         )
diff --git a/src/mobu/factory.py b/src/mobu/factory.py
index 0e13a686..027ea59a 100644
--- a/src/mobu/factory.py
+++ b/src/mobu/factory.py
@@ -11,6 +11,7 @@
 from .events import Events
 from .models.solitary import SolitaryConfig
 from .services.manager import FlockManager
+from .services.repo import RepoManager
 from .services.solitary import Solitary
 from .storage.gafaelfawr import GafaelfawrStorage

@@ -38,17 +39,25 @@ class ProcessContext:
         Manager for all running
flocks. events Object with attributes for all metrics event publishers. + repo_manager + For efficiently cloning git repos. """ - def __init__(self, http_client: AsyncClient, events: Events) -> None: + def __init__( + self, + http_client: AsyncClient, + events: Events, + ) -> None: self.http_client = http_client self.logger = structlog.get_logger("mobu") self.gafaelfawr = GafaelfawrStorage(self.http_client, self.logger) self.events = events + self.repo_manager = RepoManager(self.logger) self.manager = FlockManager( gafaelfawr_storage=self.gafaelfawr, http_client=self.http_client, logger=self.logger, + repo_manager=self.repo_manager, events=self.events, ) @@ -58,6 +67,7 @@ async def aclose(self) -> None: Called before shutdown to free resources. """ await self.manager.aclose() + self.repo_manager.close() class Factory: @@ -114,6 +124,7 @@ def create_solitary(self, solitary_config: SolitaryConfig) -> Solitary: ), http_client=self._context.http_client, events=self._context.events, + repo_manager=self._context.repo_manager, logger=self._logger, ) diff --git a/src/mobu/handlers/github_refresh_app.py b/src/mobu/handlers/github_refresh_app.py index fbcba713..e9263e08 100644 --- a/src/mobu/handlers/github_refresh_app.py +++ b/src/mobu/handlers/github_refresh_app.py @@ -71,10 +71,7 @@ async def post_webhook( context.logger.debug("Received GitHub webhook", payload=event.data) # Give GitHub some time to reach internal consistency. await asyncio.sleep(GITHUB_WEBHOOK_WAIT_SECONDS) - await gidgethub_router.dispatch( - event=event, - context=context, - ) + await gidgethub_router.dispatch(event=event, context=context) @gidgethub_router.register("push") diff --git a/src/mobu/models/repo.py b/src/mobu/models/repo.py index e261d4ae..c2fdda32 100644 --- a/src/mobu/models/repo.py +++ b/src/mobu/models/repo.py @@ -2,10 +2,14 @@ from __future__ import annotations +from dataclasses import dataclass from pathlib import Path +from tempfile import TemporaryDirectory from pydantic import BaseModel, ConfigDict, Field +__all__ = ["ClonedRepoInfo", "RepoConfig"] + class RepoConfig(BaseModel): """In-repo configuration for mobu behavior. 
@@ -25,3 +29,12 @@ class RepoConfig(BaseModel): ), examples=["some-dir", "some-dir/some-other-dir"], ) + + +@dataclass(frozen=True) +class ClonedRepoInfo: + """Information about a cloned git repo.""" + + dir: TemporaryDirectory + path: Path + hash: str diff --git a/src/mobu/services/business/notebookrunner.py b/src/mobu/services/business/notebookrunner.py index 3b8e77b6..a2cd5d9d 100644 --- a/src/mobu/services/business/notebookrunner.py +++ b/src/mobu/services/business/notebookrunner.py @@ -9,12 +9,10 @@ import contextlib import json import random -import shutil from collections.abc import AsyncGenerator, Iterator from contextlib import asynccontextmanager from datetime import timedelta from pathlib import Path -from tempfile import TemporaryDirectory from typing import Any, override import sentry_sdk @@ -46,7 +44,7 @@ from ...models.user import AuthenticatedUser from ...sentry import capturing_start_span, start_transaction from ...services.business.base import CommonEventAttrs -from ...storage.git import Git +from ...services.repo import RepoManager from .nublado import NubladoBusiness __all__ = ["NotebookRunner"] @@ -83,6 +81,7 @@ def __init__( *, options: NotebookRunnerOptions | ListNotebookRunnerOptions, user: AuthenticatedUser, + repo_manager: RepoManager, events: Events, logger: BoundLogger, flock: str | None, @@ -97,13 +96,13 @@ def __init__( self._config = config_dependency.config self._notebook: Path | None = None self._notebook_paths: list[Path] | None = None - self._repo_dir: Path | None = None + self._repo_path: Path | None = None self._repo_hash: str | None = None self._exclude_paths: set[Path] = set() self._running_code: str | None = None - self._git = Git(logger=logger) self._max_executions: int | None = None self._notebooks_to_run: list[Path] | None = None + self._repo_manager = repo_manager match options: case NotebookRunnerOptions(max_executions=max_executions): @@ -117,22 +116,30 @@ async def startup(self) -> None: await super().startup() async def cleanup(self) -> None: - shutil.rmtree(str(self._repo_dir)) - self._repo_dir = None + if self._repo_hash is not None: + await self._repo_manager.invalidate( + url=self.options.repo_url, + ref=self.options.repo_ref, + repo_hash=self._repo_hash, + ) + self._repo_path = None + self._repo_hash = None self._notebook_filter_results = None async def initialize(self) -> None: """Prepare to run the business. 
-        * Check out the repository
+        * Get notebook repo files from the repo manager
         * Parse the in-repo config
         * Filter the notebooks
         """
-        if self._repo_dir is None:
-            self._repo_dir = Path(TemporaryDirectory(delete=False).name)
-            await self.clone_repo()
+        info = await self._repo_manager.clone(
+            url=self.options.repo_url, ref=self.options.repo_ref
+        )
+        self._repo_path = info.path
+        self._repo_hash = info.hash

-        repo_config_path = self._repo_dir / GITHUB_REPO_CONFIG_PATH
+        repo_config_path = self._repo_path / GITHUB_REPO_CONFIG_PATH
         set_context(
             "repo_info",
             {
@@ -155,7 +162,7 @@
             repo_config = RepoConfig()

         exclude_dirs = repo_config.exclude_dirs
-        self._exclude_paths = {self._repo_dir / path for path in exclude_dirs}
+        self._exclude_paths = {self._repo_path / path for path in exclude_dirs}
         self._notebooks = self.find_notebooks()
         set_context(
             "notebook_filter_info", self._notebooks.model_dump(mode="json")
@@ -168,20 +175,11 @@ async def shutdown(self) -> None:
         await super().shutdown()

     async def refresh(self) -> None:
-        self.logger.info("Recloning notebooks and forcing new execution")
+        self.logger.info("Getting new notebooks and forcing new execution")
         await self.cleanup()
         await self.initialize()
         self.refreshing = False

-    async def clone_repo(self) -> None:
-        url = self.options.repo_url
-        ref = self.options.repo_ref
-        with capturing_start_span(op="clone_repo"):
-            self._git.repo = self._repo_dir
-            await self._git.clone(url, str(self._repo_dir))
-            await self._git.checkout(ref)
-            self._repo_hash = await self._git.repo_hash()
-
     def is_excluded(self, notebook: Path) -> bool:
         # A notebook is excluded if any of its parent directories are excluded
         return bool(set(notebook.parents) & self._exclude_paths)
@@ -207,12 +205,12 @@ def missing_services(self, notebook: Path) -> bool:

     def find_notebooks(self) -> NotebookFilterResults:
         with capturing_start_span(op="find_notebooks"):
-            if self._repo_dir is None:
+            if self._repo_path is None:
                 raise NotebookRepositoryError(
                     "Repository directory must be set", self.user.username
                 )

-            all_notebooks = set(self._repo_dir.glob("**/*.ipynb"))
+            all_notebooks = set(self._repo_path.glob("**/*.ipynb"))
             if not all_notebooks:
-                msg = "No notebooks found in {self._repo_dir}"
+                msg = f"No notebooks found in {self._repo_path}"
                 raise NotebookRepositoryError(msg, self.user.username)
@@ -227,14 +225,14 @@

             if self._notebooks_to_run:
                 requested = {
-                    self._repo_dir / notebook
+                    self._repo_path / notebook
                     for notebook in self._notebooks_to_run
                 }
                 not_found = requested - filter_results.all
                 if not_found:
                     msg = (
                         "Requested notebooks do not exist in"
-                        f" {self._repo_dir}: {not_found}"
+                        f" {self._repo_path}: {not_found}"
                     )
                     raise NotebookRepositoryError(msg, self.user.username)
                 filter_results.excluded_by_requested = (
@@ -348,7 +346,7 @@ async def execute_notebook(
     ) -> None:
         self._notebook = self.next_notebook()
         relative_notebook = str(
-            self._notebook.relative_to(self._repo_dir or "/")
+            self._notebook.relative_to(self._repo_path or "/")
         )
         iteration = f"{count + 1}/{num_executions}"
         msg = f"Notebook {self._notebook.name} iteration {iteration}"
diff --git a/src/mobu/services/flock.py b/src/mobu/services/flock.py
index 75b8ecff..35099ee5 100644
--- a/src/mobu/services/flock.py
+++ b/src/mobu/services/flock.py
@@ -18,6 +18,7 @@
 )
 from ..models.flock import FlockConfig, FlockData, FlockSummary
 from ..models.user import AuthenticatedUser, User, UserSpec
+from ..services.repo import RepoManager
 from ..storage.gafaelfawr import GafaelfawrStorage
 from .monkey
import Monkey @@ -39,6 +40,8 @@ class Flock: Shared HTTP client. events Event publishers. + repo_manager + For efficiently cloning git repos. logger Global logger. """ @@ -51,6 +54,7 @@ def __init__( gafaelfawr_storage: GafaelfawrStorage, http_client: AsyncClient, events: Events, + repo_manager: RepoManager, logger: BoundLogger, ) -> None: self.name = flock_config.name @@ -59,6 +63,7 @@ def __init__( self._gafaelfawr = gafaelfawr_storage self._http_client = http_client self._events = events + self._repo_manager = repo_manager self._logger = logger.bind(flock=self.name) self._monkeys: dict[str, Monkey] = {} self._start_time: datetime | None = None @@ -166,6 +171,7 @@ def _create_monkey(self, user: AuthenticatedUser) -> Monkey: user=user, http_client=self._http_client, events=self._events, + repo_manager=self._repo_manager, logger=self._logger, ) diff --git a/src/mobu/services/github_ci/ci_manager.py b/src/mobu/services/github_ci/ci_manager.py index 9e983b16..04b9f5ef 100644 --- a/src/mobu/services/github_ci/ci_manager.py +++ b/src/mobu/services/github_ci/ci_manager.py @@ -16,6 +16,7 @@ from ...events import Events from ...models.ci_manager import CiManagerSummary, CiWorkerSummary from ...models.user import User +from ...services.repo import RepoManager from ...storage.gafaelfawr import GafaelfawrStorage from ...storage.github import GitHubStorage from .ci_notebook_job import CiNotebookJob @@ -79,6 +80,7 @@ def __init__( users: list[User], http_client: AsyncClient, events: Events, + repo_manager: RepoManager, gafaelfawr_storage: GafaelfawrStorage, logger: BoundLogger, ) -> None: @@ -88,6 +90,7 @@ def __init__( self._gafaelfawr = gafaelfawr_storage self._http_client = http_client self._events = events + self._repo_manager = repo_manager self._logger = logger.bind(ci_manager=True) self._scheduler: Scheduler = Scheduler() self._queue: Queue[QueueItem] = Queue() @@ -257,6 +260,7 @@ async def enqueue( check_run=check_run, http_client=self._http_client, events=self._events, + repo_manager=self._repo_manager, logger=self._logger, gafaelfawr_storage=self._gafaelfawr, ) diff --git a/src/mobu/services/github_ci/ci_notebook_job.py b/src/mobu/services/github_ci/ci_notebook_job.py index 53e8e805..d1ed754f 100644 --- a/src/mobu/services/github_ci/ci_notebook_job.py +++ b/src/mobu/services/github_ci/ci_notebook_job.py @@ -13,6 +13,7 @@ from ...models.ci_manager import CiJobSummary from ...models.solitary import SolitaryConfig from ...models.user import User +from ...services.repo import RepoManager from ...services.solitary import Solitary from ...storage.gafaelfawr import GafaelfawrStorage from ...storage.github import CheckRun, GitHubStorage @@ -44,6 +45,7 @@ def __init__( check_run: CheckRun, http_client: AsyncClient, events: Events, + repo_manager: RepoManager, gafaelfawr_storage: GafaelfawrStorage, logger: BoundLogger, ) -> None: @@ -51,6 +53,7 @@ def __init__( self.check_run = check_run self._http_client = http_client self._events = events + self._repo_manager = repo_manager self._gafaelfawr = gafaelfawr_storage self._logger = logger.bind(ci_job_type="NotebookJob") self._notebooks: list[Path] = [] @@ -101,6 +104,7 @@ async def run(self, user: User, scopes: list[str]) -> None: gafaelfawr_storage=self._gafaelfawr, http_client=self._http_client, events=self._events, + repo_manager=self._repo_manager, logger=self._logger, ) diff --git a/src/mobu/services/manager.py b/src/mobu/services/manager.py index db46ec36..bd1fa970 100644 --- a/src/mobu/services/manager.py +++ b/src/mobu/services/manager.py @@ -12,6 
+12,7 @@
 from ..events import Events
 from ..exceptions import FlockNotFoundError
 from ..models.flock import FlockConfig, FlockSummary
+from ..services.repo import RepoManager
 from ..storage.gafaelfawr import GafaelfawrStorage
 from .flock import Flock

@@ -33,6 +34,8 @@ class FlockManager:
         Shared HTTP client.
     events
         Event publishers.
+    repo_manager
+        For efficiently cloning git repos.
     logger
         Global logger to use for process-wide (not monkey) logging.
     """

@@ -43,12 +46,14 @@ def __init__(
         gafaelfawr_storage: GafaelfawrStorage,
         http_client: AsyncClient,
         events: Events,
+        repo_manager: RepoManager,
         logger: BoundLogger,
     ) -> None:
         self._config = config_dependency.config
         self._gafaelfawr = gafaelfawr_storage
         self._http_client = http_client
         self._events = events
+        self._repo_manager = repo_manager
         self._logger = logger
         self._flocks: dict[str, Flock] = {}
         self._scheduler = Scheduler(limit=None, pending_limit=0)
@@ -87,6 +92,7 @@ async def start_flock(self, flock_config: FlockConfig) -> Flock:
             gafaelfawr_storage=self._gafaelfawr,
             http_client=self._http_client,
             events=self._events,
+            repo_manager=self._repo_manager,
             logger=self._logger,
         )
         if flock.name in self._flocks:
diff --git a/src/mobu/services/monkey.py b/src/mobu/services/monkey.py
index fa7d8928..12a7a608 100644
--- a/src/mobu/services/monkey.py
+++ b/src/mobu/services/monkey.py
@@ -25,6 +25,7 @@
 from ..models.business.tapquerysetrunner import TAPQuerySetRunnerConfig
 from ..models.monkey import MonkeyData, MonkeyState
 from ..models.user import AuthenticatedUser
+from ..services.repo import RepoManager
 from .business.base import Business
 from .business.empty import EmptyLoop
 from .business.gitlfs import GitLFSBusiness
@@ -56,6 +57,8 @@ class Monkey:
         Shared HTTP client.
     events
         Event publishers.
+    repo_manager
+        For efficiently cloning git repos.
     logger
         Global logger.
     """

@@ -69,6 +72,7 @@ def __init__(
         user: AuthenticatedUser,
         http_client: AsyncClient,
         events: Events,
+        repo_manager: RepoManager,
         logger: BoundLogger,
     ) -> None:
         self._config = config_dependency.config
@@ -78,6 +82,7 @@ def __init__(
         self._log_level = business_config.options.log_level
         self._http_client = http_client
         self._events = events
+        self._repo_manager = repo_manager
         self._user = user
         self._state = MonkeyState.IDLE

@@ -125,6 +130,7 @@ def __init__(
                     options=business_config.options,
                     user=user,
                     events=self._events,
+                    repo_manager=self._repo_manager,
                     logger=self._logger,
                     flock=self._flock,
                 )
diff --git a/src/mobu/services/repo.py b/src/mobu/services/repo.py
new file mode 100644
index 00000000..d80eeab6
--- /dev/null
+++ b/src/mobu/services/repo.py
@@ -0,0 +1,169 @@
+"""Helpers for cloning and caching notebook repos."""
+
+import asyncio
+from dataclasses import dataclass
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from structlog.stdlib import BoundLogger
+
+from ..models.repo import ClonedRepoInfo
+from ..sentry import capturing_start_span
+from ..storage.git import Git
+
+__all__ = ["RepoManager"]
+
+
+@dataclass(frozen=True)
+class _Key:
+    """Cache key identifying a repo clone by url and ref."""
+
+    url: str
+    ref: str
+
+
+@dataclass(frozen=True)
+class _Reference:
+    """Key identifying a specific repo checkout by url, ref, and hash."""
+
+    url: str
+    ref: str
+    hash: str
+
+
+@dataclass
+class _ReferenceCount:
+    """A count and a directory to remove when the count reaches 0."""
+
+    count: int
+    dir: TemporaryDirectory
+
+
+class RepoManager:
+    """A reference-counting, caching repo cloner.
+
+    Only the first call to ``clone`` for a given repo url and ref will clone
+    the repo. Subsequent calls will return the location and hash of the
+    already-cloned repo.
+
+    A call to ``invalidate`` ensures that the next call to ``clone`` will
+    re-clone the repo to a different path.
+
+    A call to ``clone`` also increments a reference counter for the url +
+    ref + hash combo of the cloned repo. A call to ``invalidate`` for that
+    combo decrements the counter. ``invalidate`` only deletes the files from
+    the cloned repo when the reference count drops to 0.
+
+    Parameters
+    ----------
+    logger
+        Logger to use.
+    """
+
+    def __init__(self, logger: BoundLogger, *, testing: bool = False) -> None:
+        self._dir = TemporaryDirectory(delete=False, prefix="mobu-notebooks-")
+        self._cache: dict[_Key, ClonedRepoInfo] = {}
+        self._lock = asyncio.Lock()
+        self._logger = logger
+        self._references: dict[_Reference, _ReferenceCount] = {}
+        self._testing = testing
+
+        # This is just for testing
+        self._cloned: list[_Key] = []
+
+    async def clone(self, url: str, ref: str) -> ClonedRepoInfo:
+        """Clone a git repo, or return cached info by url + ref.
+
+        Increments the reference count for the url + ref + hash combo.
+
+        Parameters
+        ----------
+        url
+            The URL of the repo to clone
+        ref
+            The git ref to checkout after the repo is cloned
+        """
+        logger = self._logger.bind(url=url, ref=ref)
+        key = _Key(url=url, ref=ref)
+
+        async with self._lock:
+            # If the notebook repo has already been cloned, return the info
+            if info := self._cache.get(key):
+                logger.info("Notebook repo cached")
+                reference = _Reference(url=url, ref=ref, hash=info.hash)
+                count = self._references[reference]
+                count.count += 1
+                return info
+
+            # If not, clone the repo
+            logger.info("Cloning notebook repo")
+            repo_dir = TemporaryDirectory(delete=False, dir=self._dir.name)
+            with capturing_start_span(op="clone_repo"):
+                git = Git(logger=self._logger)
+                git.repo = Path(repo_dir.name)
+                await git.clone(url, repo_dir.name)
+                await git.checkout(ref)
+                repo_hash = await git.repo_hash()
+
+            # If we're in testing mode, record that we actually did a clone
+            if self._testing:
+                self._cloned.append(key)
+
+            info = ClonedRepoInfo(
+                dir=repo_dir, path=Path(repo_dir.name), hash=repo_hash
+            )
+
+            # Update the cache with the cloned repo's info
+            self._cache[key] = info
+
+            # Update the reference count
+            reference = _Reference(url=url, ref=ref, hash=info.hash)
+            self._references[reference] = _ReferenceCount(
+                count=1, dir=info.dir
+            )
+            return info
+
+    async def invalidate(self, url: str, ref: str, repo_hash: str) -> None:
+        """Invalidate a git repo in the cache by url + ref.
+
+        Decrements the url + ref + hash reference count. If it drops to
+        zero, delete the files from that clone.
+
+        Parameters
+        ----------
+        url
+            The URL of the cloned repo
+        ref
+            The git ref of the cloned repo
+        repo_hash
+            The hash of the cloned repo to remove
+        """
+        logger = self._logger.bind(url=url, ref=ref, hash=repo_hash)
+        key = _Key(url=url, ref=ref)
+        reference = _Reference(url=url, ref=ref, hash=repo_hash)
+
+        # This theoretically doesn't need a lock, but if we add any awaits
+        # here in the future, we'd need to add a lock, and it would be
+        # pretty easy to forget to do it.
+        async with self._lock:
+            info = self._cache.get(key)
+            if info:
+                logger.info("Invalidating repo")
+
+                # Note that this could force an unnecessary clone if any
+                # monkey is calling invalidate in its shutdown method. This
+                # would force a reclone for other monkeys that are using the
+                # same repo at the same hash.
+                del self._cache[key]
+
+            count = self._references.get(reference)
+            if count:
+                count.count -= 1
+                if count.count == 0:
+                    logger.info(f"0 references, deleting: {count.dir.name}")
+                    count.dir.cleanup()
+                    del self._references[reference]
+
+    def close(self) -> None:
+        """Delete all cloned repos and the containing directory."""
+        self._dir.cleanup()
diff --git a/src/mobu/services/solitary.py b/src/mobu/services/solitary.py
index cdd6d174..334ba3b6 100644
--- a/src/mobu/services/solitary.py
+++ b/src/mobu/services/solitary.py
@@ -9,6 +9,7 @@

 from ..events import Events
 from ..models.solitary import SolitaryConfig, SolitaryResult
+from ..services.repo import RepoManager
 from ..storage.gafaelfawr import GafaelfawrStorage
 from .monkey import Monkey

@@ -28,6 +29,8 @@ class Solitary:
         Shared HTTP client.
     events
         Event publishers.
+    repo_manager
+        For efficiently cloning git repos.
     logger
         Global logger.
     """

@@ -39,12 +42,14 @@ def __init__(
         gafaelfawr_storage: GafaelfawrStorage,
         http_client: AsyncClient,
         events: Events,
+        repo_manager: RepoManager,
         logger: BoundLogger,
     ) -> None:
         self._config = solitary_config
         self._gafaelfawr = gafaelfawr_storage
         self._http_client = http_client
         self._events = events
+        self._repo_manager = repo_manager
         self._logger = logger

     async def run(self) -> SolitaryResult:
@@ -64,6 +69,7 @@ async def run(self) -> SolitaryResult:
             user=user,
             http_client=self._http_client,
             events=self._events,
+            repo_manager=self._repo_manager,
             logger=self._logger,
         )
         error = await monkey.run_once()
diff --git a/tests/business/notebookrunner_test.py b/tests/business/notebookrunner_test.py
index a8cb6f97..243444ea 100644
--- a/tests/business/notebookrunner_test.py
+++ b/tests/business/notebookrunner_test.py
@@ -551,6 +551,7 @@ async def test_refresh(
     # Set up git repo
     await setup_git_repo(repo_path)

+    num_monkeys = 5
    # Start a monkey. We have to do this in a try/finally block since the
    # runner will change working directories, which because working
    # directories are process-global may mess up future tests.
    try:
        r = await client.post(
            "/mobu/flocks",
            json={
                "name": "test",
-                "count": 1,
+                "count": num_monkeys,
                "user_spec": {"username_prefix": "bot-mobu-testuser"},
                "scopes": ["exec:notebook"],
                "business": {
@@ -578,10 +579,11 @@
        )
        assert r.status_code == 201

-        # We should see a message from the notebook execution in the logs.
-        assert await wait_for_log_message(
-            client, "bot-mobu-testuser1", msg="This is a test"
-        )
+        # We should see messages from the notebook executions in the logs.
+        for i in range(num_monkeys):
+            assert await wait_for_log_message(
+                client, f"bot-mobu-testuser{i + 1}", msg="This is a test"
+            )

        # Change the notebook and git commit it
        notebook = repo_path / "test-notebook.ipynb"
@@ -596,19 +598,21 @@
        jupyter.expected_session_name = "test-notebook.ipynb"
        jupyter.expected_session_type = "notebook"

-        # Refresh the notebook
+        # Refresh the flock
        r = await client.post("/mobu/flocks/test/refresh")
        assert r.status_code == 202

-        # The refresh should have forced a new execution
-        assert await wait_for_log_message(
-            client, "bot-mobu-testuser1", msg="Deleting lab"
-        )
-
-        # We should see a message from the updated notebook.
-        assert await wait_for_log_message(
-            client, "bot-mobu-testuser1", msg="This is a NEW test"
-        )
+        # The refresh should have forced new executions
+        for i in range(num_monkeys):
+            assert await wait_for_log_message(
+                client, f"bot-mobu-testuser{i + 1}", msg="Deleting lab"
+            )
+
+        # We should see messages from the updated notebook.
+        for i in range(num_monkeys):
+            assert await wait_for_log_message(
+                client, f"bot-mobu-testuser{i + 1}", msg="This is a NEW test"
+            )
    finally:
        os.chdir(cwd)
diff --git a/tests/services/ci_manager_test.py b/tests/services/ci_manager_test.py
index f3d79a2d..77dbfde1 100644
--- a/tests/services/ci_manager_test.py
+++ b/tests/services/ci_manager_test.py
@@ -18,6 +18,7 @@
 from mobu.models.user import User
 from mobu.services.business.base import Business
 from mobu.services.github_ci.ci_manager import CiManager
+from mobu.services.repo import RepoManager
 from mobu.storage.gafaelfawr import GafaelfawrStorage

 from tests.support.constants import TEST_GITHUB_CI_APP_PRIVATE_KEY
@@ -41,11 +42,13 @@ def create_ci_manager(respx_mock: respx.Router, events: Events) -> CiManager:
     http_client = AsyncClient()
     logger = structlog.get_logger()
     gafaelfawr = GafaelfawrStorage(http_client=http_client, logger=logger)
+    repo_manager = RepoManager(logger=logger)

     return CiManager(
         http_client=http_client,
         gafaelfawr_storage=gafaelfawr,
         events=events,
+        repo_manager=repo_manager,
         logger=logger,
         scopes=scopes,
         github_app_id=123,
diff --git a/tests/services/repo_cache_test.py b/tests/services/repo_cache_test.py
new file mode 100644
index 00000000..cda618ae
--- /dev/null
+++ b/tests/services/repo_cache_test.py
@@ -0,0 +1,128 @@
+"""Tests for the RepoManager."""
+
+import shutil
+from asyncio import gather
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mobu.services.repo import RepoManager
+from mobu.storage.git import Git
+
+from ..support.constants import TEST_DATA_DIR
+
+
+async def setup_git_repo(repo_path: Path) -> str:
+    """Initialize and populate a git repo at `repo_path`.
+
+    Returns
+    -------
+    str
+        Commit hash of the created repo
+    """
+    git = Git(repo=repo_path)
+    await git.init("--initial-branch=main")
+    await git.config("user.email", "gituser@example.com")
+    await git.config("user.name", "Git User")
+    for path in repo_path.iterdir():
+        if not path.name.startswith("."):
+            await git.add(str(path))
+    await git.commit("-m", "Initial commit")
+    return await git.repo_hash()
+
+
+@pytest.mark.asyncio
+async def test_cache(
+    tmp_path: Path,
+) -> None:
+    # Set up a notebook repository.
+    source_path = TEST_DATA_DIR / "notebooks"
+    repo_path = tmp_path / "notebooks"
+    repo_ref = "main"
+
+    shutil.copytree(str(source_path), str(repo_path))
+    await setup_git_repo(repo_path)
+
+    mock_logger = MagicMock()
+    manager = RepoManager(logger=mock_logger, testing=True)
+
+    # Clone the same repo and ref a bunch of times concurrently
+    clone_tasks = [
+        manager.clone(url=str(repo_path), ref=repo_ref) for _ in range(100)
+    ]
+    infos = await gather(*clone_tasks)
+
+    # The same info should be returned for every call
+    assert len(set(infos)) == 1
+    original_info = infos[0]
+
+    # The repo should have been cloned
+    contents = (original_info.path / "test-notebook.ipynb").read_text()
+    assert "This is a test" in contents
+    assert "This is a NEW test" not in contents
+
+    # ...once
+    assert len(manager._cloned) == 1
+    manager._cloned = []
+
+    # Change the notebook and git commit it
+    notebook = repo_path / "test-notebook.ipynb"
+    contents = notebook.read_text()
+    new_contents = contents.replace("This is a test", "This is a NEW test")
+    notebook.write_text(new_contents)
+
+    git = Git(repo=repo_path)
+    await git.add(str(notebook))
+    await git.commit("-m", "Updating notebook")
+
+    # The repo should be cached (this makes the reference count 101)
+    cached_info = await manager.clone(url=str(repo_path), ref=repo_ref)
+    assert cached_info == original_info
+    contents = (cached_info.path / "test-notebook.ipynb").read_text()
+    assert "This is a test" in contents
+    assert "This is a NEW test" not in contents
+
+    # Invalidate this URL and ref. This should make the next clone call clone
+    # the repo again, but it should not delete the directory of the old
+    # checkout because there are still 100 references to it.
+    await manager.invalidate(
+        url=str(repo_path), ref=repo_ref, repo_hash=original_info.hash
+    )
+
+    # Clone it again and verify the new checkout
+    clone_tasks = [
+        manager.clone(url=str(repo_path), ref=repo_ref) for _ in range(100)
+    ]
+    infos = await gather(*clone_tasks)
+    assert len(set(infos)) == 1
+    assert len(manager._cloned) == 1
+    updated_info = infos[0]
+
+    # We should get different info because the repo should have been recloned
+    assert updated_info != original_info
+
+    # The repo should be updated
+    contents = (updated_info.path / "test-notebook.ipynb").read_text()
+    assert "This is a test" not in contents
+    assert "This is a NEW test" in contents
+
+    # The original dir should NOT be deleted
+    assert Path(original_info.dir.name).exists()
+
+    # Invalidate the other references
+    remove_tasks = [
+        manager.invalidate(
+            url=str(repo_path), ref=repo_ref, repo_hash=original_info.hash
+        )
+        for _ in range(100)
+    ]
+    await gather(*remove_tasks)
+
+    # The original dir should be deleted
+    assert not Path(original_info.dir.name).exists()
+
+    # The cache should clean up after itself
+    manager.close()
+    assert not original_info.path.exists()
+    assert not updated_info.path.exists()
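---

Reviewer note: a minimal sketch of the clone/invalidate contract this patch
introduces. The RepoManager API is the one added in src/mobu/services/repo.py
above; the repo path, ref, and logger setup are illustrative assumptions, not
part of the change.

    import asyncio

    import structlog

    from mobu.services.repo import RepoManager


    async def main() -> None:
        # Assumed for illustration: a local git repo at /tmp/example-repo
        # with a "main" branch.
        url, ref = "/tmp/example-repo", "main"
        manager = RepoManager(structlog.get_logger("example"))
        try:
            # The first clone() for a url + ref does the actual git clone;
            # concurrent and later calls return the same ClonedRepoInfo and
            # increment the url + ref + hash reference count.
            first = await manager.clone(url=url, ref=ref)
            second = await manager.clone(url=url, ref=ref)
            assert first == second  # same path and hash, one real clone

            # Each clone() must eventually be paired with one invalidate()
            # carrying the hash that clone() returned. The checkout's files
            # are deleted only when its reference count reaches zero.
            await manager.invalidate(url=url, ref=ref, repo_hash=first.hash)
            await manager.invalidate(url=url, ref=ref, repo_hash=second.hash)
        finally:
            # Removes the process-wide temp directory holding all checkouts.
            manager.close()


    asyncio.run(main())

This pairing is what NotebookRunner relies on: initialize() saves the hash
returned by clone(), cleanup() passes it back to invalidate(), so a refresh
re-clones at most once per process while monkeys still running on the old
checkout keep it alive until the last reference is released.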