class FileLock:
    """Registry of per-path locks that serialises file operations on one path.

    Calling an instance with a path returns an async context manager. While one
    task is inside ``async with file_lock(path):``, every other task entering
    the context for an equal path waits until the first one has left. This
    prevents reading while writing or writing while reading (the same file).
    Currently, this also prevents concurrent reading of the same file or
    directory, which could be allowed.

    Paths are used as dictionary keys after conversion to ``Path``; two
    spellings that do not compare equal as ``Path`` objects (for example a
    relative and an absolute form of the same file) get independent locks.
    """

    # Maps each currently locked path to the event that is set on release.
    _locks: dict[Path, Event]

    def __init__(self) -> None:
        self._locks = {}

    def __call__(self, path: Path | str) -> _FileLock:
        """Return the async context manager guarding *path*.

        Args:
            path: The file or directory to lock, as a ``Path`` or a string.
        """
        # ``Path()`` accepts both strings and paths, so no isinstance/cast dance.
        return _FileLock(Path(path), self._locks)


class _FileLock:
    """Async context manager that holds the lock of one path in a shared registry."""

    _path: Path
    # The registry shared with the owning ``FileLock`` (and all sibling locks).
    _locks: dict[Path, Event]
    # Event created on acquisition; only exists between __aenter__ and __aexit__.
    _lock: Event

    def __init__(self, path: Path, locks: dict[Path, Event]) -> None:
        self._path = path
        self._locks = locks

    async def __aenter__(self) -> None:
        """Wait until the path is free, then register this holder's event."""
        # Re-check after every wake-up: another waiter may have re-taken the
        # lock between the release and the moment this task resumed.
        while self._path in self._locks:
            await self._locks[self._path].wait()
        # There is no await between the membership check and this assignment,
        # so tasks scheduled cooperatively on one event loop cannot interleave here.
        self._locks[self._path] = self._lock = Event()

    async def __aexit__(self, exc_type, exc_value, exc_tb) -> None:
        """Wake up the waiters and drop the path from the registry."""
        self._lock.set()
        del self._locks[self._path]
async def read_content(
    self, path: Union[str, Path], get_content: bool, file_format: Optional[str] = None
) -> Content:
    """Describe the file or directory at *path* as a ``Content`` model.

    The whole operation runs while holding the file lock of *path*.

    Args:
        path: Location of the item, as a string or ``Path``.
        get_content: When true, also load the payload: the (non-hidden)
            children for a directory, the file data otherwise.
        file_format: ``"base64"`` returns the data base64-encoded, ``"json"``
            returns it parsed as JSON, anything else returns decoded text.

    Raises:
        HTTPException: 404 when the item is neither a directory, a file nor a
            symlink, or when reading/decoding the file fails.
    """
    async with self.file_lock(path):
        if isinstance(path, str):
            path = Path(path)

        content: Optional[Union[str, Dict, List[Dict]]] = None
        if get_content:
            if path.is_dir():
                # Children are listed without their own payload; dot-entries are hidden.
                listing: List[Dict] = []
                for entry in path.iterdir():
                    if entry.name.startswith("."):
                        continue
                    child = await self.read_content(entry, get_content=False)
                    listing.append(child.model_dump())
                content = listing
            elif path.is_file() or path.is_symlink():
                try:
                    async with await open_file(path, mode="rb") as stream:
                        raw = await stream.read()
                    if file_format == "base64":
                        content = base64.b64encode(raw).decode("ascii")
                    elif file_format == "json":
                        content = json.loads(raw)
                    else:
                        content = raw.decode()
                except Exception:
                    raise HTTPException(status_code=404, detail="Item not found")

        if path.is_dir():
            size, kind, fmt, mimetype = None, "directory", "json", None
        elif path.is_file() or path.is_symlink():
            size = get_file_size(path)
            suffix = path.suffix
            if suffix == ".ipynb":
                kind, fmt, mimetype = "notebook", None, None
                if content is not None:
                    notebook: dict
                    if file_format == "json":
                        # Already parsed; cells are patched in place.
                        notebook = cast(Dict, content)
                    else:
                        notebook = json.loads(cast(str, content))
                    for cell in notebook["cells"]:
                        # Every cell is reported as untrusted.
                        cell.setdefault("metadata", {}).update({"trusted": False})
                        # Code cell sources given as a list of lines become one string.
                        if cell["cell_type"] == "code" and not isinstance(cell["source"], str):
                            cell["source"] = "".join(cell["source"])
                    if file_format != "json":
                        content = json.dumps(notebook)
            elif suffix == ".json":
                kind, fmt, mimetype = "json", "text", "application/json"
            else:
                kind, fmt, mimetype = "file", None, "text/plain"
        else:
            raise HTTPException(status_code=404, detail="Item not found")

        return Content(
            name=path.name,
            path=path.as_posix(),
            last_modified=get_file_modification_time(path),
            created=get_file_creation_time(path),
            content=content,
            format=fmt,
            mimetype=mimetype,
            size=size,
            writable=is_file_writable(path),
            type=kind,
        )
async def write_content(self, content: Union[SaveContent, Dict]) -> None:
    """Write *content* to disk without letting cancellation interrupt the write.

    Args:
        content: The payload to save, either a ``SaveContent`` model or a
            plain dict with the same fields.
    """
    # writing can never be cancelled, otherwise it would corrupt the file.
    # The write is awaited directly inside the shielded scope: a task started
    # with ``start_soon`` would run under the task group's scope and would not
    # be covered by a shield that only surrounds the ``start_soon`` call.
    with CancelScope(shield=True):
        await self._write_content(content)

async def _write_content(self, content: Union[SaveContent, Dict]) -> None:
    """Serialise *content* to ``content.path`` while holding that path's file lock.

    Args:
        content: The payload to save, either a ``SaveContent`` model or a
            plain dict with the same fields.
    """
    # Normalise before touching ``content.path``: a plain dict has no such attribute.
    if not isinstance(content, SaveContent):
        content = SaveContent(**content)
    async with self.file_lock(content.path):
        if content.format == "base64":
            async with await open_file(content.path, "wb") as f:
                content.content = cast(str, content.content)
                # NOTE(review): this writes the base64 text itself, not the decoded
                # bytes, while read_content base64-encodes raw bytes — confirm intended.
                content_bytes = content.content.encode("ascii")
                await f.write(content_bytes)
        else:
            async with await open_file(content.path, "wt") as f:
                if content.format == "json":
                    dict_content = cast(Dict, content.content)
                    if content.type == "notebook":
                        # see https://github.com/jupyterlab/jupyterlab/issues/11005
                        if (
                            "metadata" in dict_content
                            and "orig_nbformat" in dict_content["metadata"]
                        ):
                            del dict_content["metadata"]["orig_nbformat"]
                    await f.write(json.dumps(dict_content, indent=2))
                else:
                    content.content = cast(str, content.content)
                    await f.write(content.content)
last_modified: Dict[str, datetime] websocket_server: JupyterWebsocketServer - lock: Lock _task_group: TaskGroup def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup): @@ -179,7 +178,6 @@ def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup self.cleaners = {} # a dictionary of room:task self.last_modified = {} # a dictionary of file_id:last_modification_date self.websocket_server = JupyterWebsocketServer(rooms_ready=False, auto_clean_rooms=False) - self.lock = Lock() async def on_shutdown(self): await self.lifespan.shutdown_request.wait() @@ -203,8 +201,7 @@ async def serve(self, websocket: YWebsocket, permissions) -> None: document = YDOCS.get(file_type, YFILE)(room.ydoc) document.file_id = file_id self.documents[websocket.path] = document - async with self.lock: - model = await self.contents.read_content(file_path, True, file_format) + model = await self.contents.read_content(file_path, True, file_format) assert model.last_modified is not None self.last_modified[file_id] = to_datetime(model.last_modified) if not room.ready: @@ -294,14 +291,12 @@ async def watch_file(self, file_format: str, file_id: str, document: YBaseDoc) - await self.maybe_load_file(file_format, file_path, file_id) async def maybe_load_file(self, file_format: str, file_path: str, file_id: str) -> None: - async with self.lock: - model = await self.contents.read_content(file_path, False) + model = await self.contents.read_content(file_path, False) # do nothing if the file was saved by us assert model.last_modified is not None if self.last_modified[file_id] < to_datetime(model.last_modified): # the file was not saved by us, update the shared document(s) - async with self.lock: - model = await self.contents.read_content(file_path, True, file_format) + model = await self.contents.read_content(file_path, True, file_format) assert model.last_modified is not None documents = [v for k, v in self.documents.items() if k.split(":", 2)[2] == file_id] 
for document in documents: @@ -339,8 +334,7 @@ async def maybe_save_document( except Exception: return assert file_path is not None - async with self.lock: - model = await self.contents.read_content(file_path, True, file_format) + model = await self.contents.read_content(file_path, True, file_format) assert model.last_modified is not None if self.last_modified[file_id] < to_datetime(model.last_modified): # file changed on disk, let's revert @@ -357,9 +351,8 @@ async def maybe_save_document( "path": file_path, "type": file_type, } - async with self.lock: - await self.contents.write_content(content) - model = await self.contents.read_content(file_path, False) + await self.contents.write_content(content) + model = await self.contents.read_content(file_path, False) assert model.last_modified is not None self.last_modified[file_id] = to_datetime(model.last_modified) document.dirty = False