
Commit

Add contents' file lock
davidbrochart committed Apr 17, 2024
1 parent 029aacd commit 01eb0ce
Showing 3 changed files with 158 additions and 110 deletions.
49 changes: 48 additions & 1 deletion jupyverse_api/jupyverse_api/contents/__init__.py
@@ -1,7 +1,10 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

from anyio import Event
from fastapi import APIRouter, Depends, Request, Response

from jupyverse_api import Router
@@ -36,9 +39,13 @@ def unwatch(self, path: str, watcher):


class Contents(Router, ABC):
file_lock: FileLock

def __init__(self, app: App, auth: Auth):
super().__init__(app=app)

self.file_lock = FileLock()

router = APIRouter()

@router.post(
@@ -194,3 +201,43 @@ async def rename_content(
user: User,
) -> Content:
...


class FileLock:
"""FileLock ensures that no file operation is done concurrently on the same file,
in order to prevent reading while writing or writing while reading (the same file).
Currently, this also prevents concurrent reading of the same file or directory,
which could be allowed.
"""
_locks: dict[Path, Event]

def __init__(self):
self._locks = {}

def __call__(self, path: Path | str):
if isinstance(path, str):
path = Path(path)
path = cast(Path, path)
return _FileLock(path, self._locks)


class _FileLock:
_path: Path
_locks: dict[Path, Event]
_lock: Event

    def __init__(self, path: Path, locks: dict[Path, Event]):
self._path = path
self._locks = locks

async def __aenter__(self):
while True:
if self._path in self._locks:
await self._locks[self._path].wait()
else:
break
self._locks[self._path] = self._lock = Event()

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._lock.set()
del self._locks[self._path]
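
A minimal sketch of how the lock is meant to be used (not part of the diff; the file names and payloads below are made up, while FileLock comes from the jupyverse_api.contents module shown above). Operations on the same path are serialized; operations on distinct paths run concurrently:

import anyio

from jupyverse_api.contents import FileLock

async def main() -> None:
    file_lock = FileLock()

    async def write(path: str, data: bytes) -> None:
        # writes to the same path run one at a time; writes to
        # different paths are not blocked by each other
        async with file_lock(path):
            async with await anyio.open_file(path, "wb") as f:
                await f.write(data)

    async with anyio.create_task_group() as tg:
        tg.start_soon(write, "a.txt", b"first")
        tg.start_soon(write, "a.txt", b"second")  # serialized with the write above
        tg.start_soon(write, "b.txt", b"other")   # not blocked by "a.txt"

anyio.run(main)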
198 changes: 103 additions & 95 deletions plugins/contents/fps_contents/routes.py
@@ -9,7 +9,7 @@
from pathlib import Path
from typing import Dict, List, Optional, Union, cast

from anyio import open_file
from anyio import CancelScope, create_task_group, open_file
from fastapi import HTTPException, Response
from starlette.requests import Request

@@ -145,107 +145,115 @@ async def rename_content(
async def read_content(
self, path: Union[str, Path], get_content: bool, file_format: Optional[str] = None
) -> Content:
if isinstance(path, str):
path = Path(path)
content: Optional[Union[str, Dict, List[Dict]]] = None
if get_content:
async with self.file_lock(path):
if isinstance(path, str):
path = Path(path)
content: Optional[Union[str, Dict, List[Dict]]] = None
if get_content:
if path.is_dir():
content = [
(await self.read_content(subpath, get_content=False)).model_dump()
for subpath in path.iterdir()
if not subpath.name.startswith(".")
]
elif path.is_file() or path.is_symlink():
try:
async with await open_file(path, mode="rb") as f:
content_bytes = await f.read()
if file_format == "base64":
content = base64.b64encode(content_bytes).decode("ascii")
elif file_format == "json":
content = json.loads(content_bytes)
else:
content = content_bytes.decode()
except Exception:
raise HTTPException(status_code=404, detail="Item not found")
format: Optional[str] = None
if path.is_dir():
content = [
(await self.read_content(subpath, get_content=False)).model_dump()
for subpath in path.iterdir()
if not subpath.name.startswith(".")
]
elif path.is_file() or path.is_symlink():
try:
async with await open_file(path, mode="rb") as f:
content_bytes = await f.read()
if file_format == "base64":
content = base64.b64encode(content_bytes).decode("ascii")
elif file_format == "json":
content = json.loads(content_bytes)
else:
content = content_bytes.decode()
except Exception:
raise HTTPException(status_code=404, detail="Item not found")
format: Optional[str] = None
if path.is_dir():
size = None
type = "directory"
format = "json"
mimetype = None
elif path.is_file() or path.is_symlink():
size = get_file_size(path)
if path.suffix == ".ipynb":
type = "notebook"
format = None
                mimetype = None
if content is not None:
nb: dict
if file_format == "json":
content = cast(Dict, content)
nb = content
else:
content = cast(str, content)
nb = json.loads(content)
for cell in nb["cells"]:
if "metadata" not in cell:
cell["metadata"] = {}
cell["metadata"].update({"trusted": False})
if cell["cell_type"] == "code":
cell_source = cell["source"]
if not isinstance(cell_source, str):
cell["source"] = "".join(cell_source)
if file_format != "json":
content = json.dumps(nb)
elif path.suffix == ".json":
type = "json"
format = "text"
mimetype = "application/json"
elif path.is_file() or path.is_symlink():
size = get_file_size(path)
if path.suffix == ".ipynb":
type = "notebook"
format = None
mimetype = None
if content is not None:
nb: dict
if file_format == "json":
content = cast(Dict, content)
nb = content
else:
content = cast(str, content)
nb = json.loads(content)
for cell in nb["cells"]:
if "metadata" not in cell:
cell["metadata"] = {}
cell["metadata"].update({"trusted": False})
if cell["cell_type"] == "code":
cell_source = cell["source"]
if not isinstance(cell_source, str):
cell["source"] = "".join(cell_source)
if file_format != "json":
content = json.dumps(nb)
elif path.suffix == ".json":
type = "json"
format = "text"
mimetype = "application/json"
else:
type = "file"
format = None
mimetype = "text/plain"
else:
type = "file"
format = None
mimetype = "text/plain"
else:
raise HTTPException(status_code=404, detail="Item not found")
raise HTTPException(status_code=404, detail="Item not found")

return Content(
**{
"name": path.name,
"path": path.as_posix(),
"last_modified": get_file_modification_time(path),
"created": get_file_creation_time(path),
"content": content,
"format": format,
"mimetype": mimetype,
"size": size,
"writable": is_file_writable(path),
"type": type,
}
)
return Content(
**{
"name": path.name,
"path": path.as_posix(),
"last_modified": get_file_modification_time(path),
"created": get_file_creation_time(path),
"content": content,
"format": format,
"mimetype": mimetype,
"size": size,
"writable": is_file_writable(path),
"type": type,
}
)

async def write_content(self, content: Union[SaveContent, Dict]) -> None:
if not isinstance(content, SaveContent):
content = SaveContent(**content)
if content.format == "base64":
async with await open_file(content.path, "wb") as f:
content.content = cast(str, content.content)
content_bytes = content.content.encode("ascii")
await f.write(content_bytes)
else:
async with await open_file(content.path, "wt") as f:
if content.format == "json":
dict_content = cast(Dict, content.content)
if content.type == "notebook":
# see https://github.com/jupyterlab/jupyterlab/issues/11005
if (
"metadata" in dict_content
and "orig_nbformat" in dict_content["metadata"]
):
del dict_content["metadata"]["orig_nbformat"]
await f.write(json.dumps(dict_content, indent=2))
                else:
                    content.content = cast(str, content.content)
                    await f.write(content.content)
        with CancelScope(shield=True):
            # writing can never be cancelled, otherwise it would corrupt the file
            async with create_task_group() as tg:
                tg.start_soon(self._write_content, content)

    async def _write_content(self, content: Union[SaveContent, Dict]) -> None:
        if not isinstance(content, SaveContent):
            content = SaveContent(**content)
        async with self.file_lock(content.path):
if content.format == "base64":
async with await open_file(content.path, "wb") as f:
content.content = cast(str, content.content)
content_bytes = content.content.encode("ascii")
await f.write(content_bytes)
else:
async with await open_file(content.path, "wt") as f:
if content.format == "json":
dict_content = cast(Dict, content.content)
if content.type == "notebook":
# see https://github.com/jupyterlab/jupyterlab/issues/11005
if (
"metadata" in dict_content
and "orig_nbformat" in dict_content["metadata"]
):
del dict_content["metadata"]["orig_nbformat"]
await f.write(json.dumps(dict_content, indent=2))
else:
content.content = cast(str, content.content)
await f.write(content.content)

@property
def file_id_manager(self):
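
The shielded scope in write_content above is what keeps a save from being torn by cancellation. A small self-contained sketch of that anyio pattern, assuming nothing beyond anyio itself (the file name, payload and timeout are arbitrary):

import anyio
from anyio import CancelScope

async def save(path: str, data: str) -> None:
    async with await anyio.open_file(path, "w") as f:
        await f.write(data)

async def main() -> None:
    with anyio.move_on_after(0.001):  # cancels the surrounding scope almost immediately
        with CancelScope(shield=True):
            # the shield defers delivery of the cancellation, so the write
            # runs to completion and the file is never left half-written
            await save("example.txt", "x" * 1_000_000)

anyio.run(main)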
21 changes: 7 additions & 14 deletions plugins/yjs/fps_yjs/routes.py
@@ -6,7 +6,7 @@
from typing import Dict
from uuid import uuid4

from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group, sleep
from anyio import TASK_STATUS_IGNORED, Event, create_task_group, sleep
from anyio.abc import TaskGroup, TaskStatus
from fastapi import (
HTTPException,
@@ -166,7 +166,6 @@ class RoomManager:
cleaners: Dict[YRoom, Task]
last_modified: Dict[str, datetime]
websocket_server: JupyterWebsocketServer
lock: Lock
_task_group: TaskGroup

def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup):
@@ -179,7 +178,6 @@ def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup
self.cleaners = {} # a dictionary of room:task
self.last_modified = {} # a dictionary of file_id:last_modification_date
self.websocket_server = JupyterWebsocketServer(rooms_ready=False, auto_clean_rooms=False)
self.lock = Lock()

async def on_shutdown(self):
await self.lifespan.shutdown_request.wait()
@@ -203,8 +201,7 @@ async def serve(self, websocket: YWebsocket, permissions) -> None:
document = YDOCS.get(file_type, YFILE)(room.ydoc)
document.file_id = file_id
self.documents[websocket.path] = document
async with self.lock:
model = await self.contents.read_content(file_path, True, file_format)
model = await self.contents.read_content(file_path, True, file_format)
assert model.last_modified is not None
self.last_modified[file_id] = to_datetime(model.last_modified)
if not room.ready:
@@ -294,14 +291,12 @@ async def watch_file(self, file_format: str, file_id: str, document: YBaseDoc) -
await self.maybe_load_file(file_format, file_path, file_id)

async def maybe_load_file(self, file_format: str, file_path: str, file_id: str) -> None:
async with self.lock:
model = await self.contents.read_content(file_path, False)
model = await self.contents.read_content(file_path, False)
# do nothing if the file was saved by us
assert model.last_modified is not None
if self.last_modified[file_id] < to_datetime(model.last_modified):
# the file was not saved by us, update the shared document(s)
async with self.lock:
model = await self.contents.read_content(file_path, True, file_format)
model = await self.contents.read_content(file_path, True, file_format)
assert model.last_modified is not None
documents = [v for k, v in self.documents.items() if k.split(":", 2)[2] == file_id]
for document in documents:
@@ -339,8 +334,7 @@ async def maybe_save_document(
except Exception:
return
assert file_path is not None
async with self.lock:
model = await self.contents.read_content(file_path, True, file_format)
model = await self.contents.read_content(file_path, True, file_format)
assert model.last_modified is not None
if self.last_modified[file_id] < to_datetime(model.last_modified):
# file changed on disk, let's revert
Expand All @@ -357,9 +351,8 @@ async def maybe_save_document(
"path": file_path,
"type": file_type,
}
async with self.lock:
await self.contents.write_content(content)
model = await self.contents.read_content(file_path, False)
await self.contents.write_content(content)
model = await self.contents.read_content(file_path, False)
assert model.last_modified is not None
self.last_modified[file_id] = to_datetime(model.last_modified)
document.dirty = False
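
With the per-file lock now living in Contents, the RoomManager no longer needs its own anyio.Lock around reads and writes; what remains is the last-modified comparison that tells the server's own saves apart from external changes. A stripped-down sketch of that check (last_modified, to_datetime and changed_on_disk here are reconstructed stand-ins, not the plugin's exact code):

from datetime import datetime

# modification time recorded right after each of our own saves, per file id
last_modified: dict[str, datetime] = {}

def to_datetime(value: str) -> datetime:
    # stand-in for the plugin's own to_datetime helper
    return datetime.fromisoformat(value)

def changed_on_disk(file_id: str, model_last_modified: str) -> bool:
    # anything strictly newer than our own last save must come from
    # another writer, so the shared document should be updated or reverted
    return last_modified[file_id] < to_datetime(model_last_modified)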
