Commit 481466b

Add contents' file lock

davidbrochart committed Apr 17, 2024
1 parent 8c5136e commit 481466b
Showing 4 changed files with 216 additions and 108 deletions.
49 changes: 48 additions & 1 deletion jupyverse_api/jupyverse_api/contents/__init__.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, cast
 
+from anyio import Event
+
 from fastapi import APIRouter, Depends, Request, Response
 
 from jupyverse_api import Router
@@ -36,9 +39,13 @@ def unwatch(self, path: str, watcher):
 
 
 class Contents(Router, ABC):
+    file_lock: FileLock
+
     def __init__(self, app: App, auth: Auth):
         super().__init__(app=app)
 
+        self.file_lock = FileLock()
+
         router = APIRouter()
 
         @router.post(
@@ -194,3 +201,43 @@ async def rename_content(
         user: User,
     ) -> Content:
         ...
+
+
+class FileLock:
+    """FileLock ensures that no file operation is done concurrently on the same file,
+    in order to prevent reading while writing or writing while reading (the same file).
+    Currently, this also prevents concurrent reading of the same file or directory,
+    which could be allowed.
+    """
+
+    _locks: dict[Path, Event]
+
+    def __init__(self):
+        self._locks = {}
+
+    def __call__(self, path: Path | str):
+        if isinstance(path, str):
+            path = Path(path)
+        path = cast(Path, path)
+        return _FileLock(path, self._locks)
+
+
+class _FileLock:
+    _path: Path
+    _locks: dict[Path, Event]
+    _lock: Event
+
+    def __init__(self, path: Path, locks: dict[Path, Event]):
+        self._path = path
+        self._locks = locks
+
+    async def __aenter__(self):
+        while True:
+            if self._path in self._locks:
+                await self._locks[self._path].wait()
+            else:
+                break
+        self._locks[self._path] = self._lock = Event()
+
+    async def __aexit__(self, exc_type, exc_value, exc_tb):
+        self._lock.set()
+        del self._locks[self._path]
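
The added FileLock is an async context-manager factory: each path maps to an anyio.Event, a waiter blocks until the current holder's event is set, and the winner then installs its own event, which it sets on exit. A minimal usage sketch (the path and task names here are illustrative, not from the commit):

    import anyio
    from jupyverse_api.contents import FileLock

    async def main() -> None:
        file_lock = FileLock()

        async def touch(name: str, path: str) -> None:
            # Waits until no other task holds a lock on this path.
            async with file_lock(path):
                print(name, "start")
                await anyio.sleep(0.1)  # simulate a slow read or write
                print(name, "done")

        async with anyio.create_task_group() as tg:
            tg.start_soon(touch, "write", "nb/demo.ipynb")
            # Same path, so this acquires only after "write" releases.
            tg.start_soon(touch, "read", "nb/demo.ipynb")

    anyio.run(main)

Keying the lock on an anyio.Event keeps the implementation backend-agnostic: it behaves the same under asyncio and trio.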
60 changes: 60 additions & 0 deletions jupyverse_api/tests/test_contents.py
@@ -0,0 +1,60 @@
+import pytest
+from anyio import create_task_group, sleep
+from jupyverse_api.contents import FileLock
+
+pytestmark = pytest.mark.anyio
+
+
+async def do_op(operation, file_lock, operations):
+    op, path = operation
+    async with file_lock(path):
+        operations.append(operation + ["start"])
+        await sleep(0.1)
+        operations.append(operation + ["done"])
+
+
+async def test_file_lock():
+    file_lock = FileLock()
+
+    # test concurrent accesses to the same file
+    path = "path/to/file"
+    operations = []
+    async with create_task_group() as tg:
+        tg.start_soon(do_op, ["write0", path], file_lock, operations)
+        await sleep(0.01)
+        tg.start_soon(do_op, ["write1", path], file_lock, operations)
+        await sleep(0.01)
+        tg.start_soon(do_op, ["read0", path], file_lock, operations)
+
+    assert operations == [
+        ["write0", path, "start"],
+        ["write0", path, "done"],
+        ["write1", path, "start"],
+        ["write1", path, "done"],
+        ["read0", path, "start"],
+        ["read0", path, "done"],
+    ]
+
+    # test concurrent accesses to different files
+    path0 = "path/to/file0"
+    path1 = "path/to/file1"
+    operations = []
+    async with create_task_group() as tg:
+        tg.start_soon(do_op, ["write0", path0], file_lock, operations)
+        await sleep(0.01)
+        tg.start_soon(do_op, ["write1", path1], file_lock, operations)
+        await sleep(0.01)
+        tg.start_soon(do_op, ["read0", path0], file_lock, operations)
+        await sleep(0.01)
+        tg.start_soon(do_op, ["read1", path1], file_lock, operations)
+
+    assert operations == [
+        ["write0", path0, "start"],
+        ["write1", path1, "start"],
+        ["write0", path0, "done"],
+        ["read0", path0, "start"],
+        ["write1", path1, "done"],
+        ["read1", path1, "start"],
+        ["read0", path0, "done"],
+        ["read1", path1, "done"],
+    ]
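
These tests run under the anyio pytest plugin (pytestmark = pytest.mark.anyio), which executes each test on the backend returned by the anyio_backend fixture; the plugin defaults to asyncio. A minimal conftest.py sketch that pins the backend explicitly, given as an assumption since the repository's actual test configuration is not part of this diff:

    # conftest.py -- a sketch; the project's real test setup may differ
    import pytest

    @pytest.fixture
    def anyio_backend():
        # Run all pytest.mark.anyio tests on the asyncio backend.
        return "asyncio"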
194 changes: 101 additions & 93 deletions plugins/contents/fps_contents/routes.py
@@ -9,7 +9,7 @@
 from pathlib import Path
 from typing import Dict, List, Optional, Union, cast
 
-from anyio import open_file
+from anyio import CancelScope, create_task_group, open_file
 from fastapi import HTTPException, Response
 from starlette.requests import Request
 
@@ -145,107 +145,115 @@ async def rename_content(
     async def read_content(
         self, path: Union[str, Path], get_content: bool, file_format: Optional[str] = None
     ) -> Content:
-        if isinstance(path, str):
-            path = Path(path)
-        content: Optional[Union[str, Dict, List[Dict]]] = None
-        if get_content:
-            if path.is_dir():
-                content = [
-                    (await self.read_content(subpath, get_content=False)).model_dump()
-                    for subpath in path.iterdir()
-                    if not subpath.name.startswith(".")
-                ]
-            elif path.is_file() or path.is_symlink():
-                try:
-                    async with await open_file(path, mode="rb") as f:
-                        content_bytes = await f.read()
-                    if file_format == "base64":
-                        content = base64.b64encode(content_bytes).decode("ascii")
-                    elif file_format == "json":
-                        content = json.loads(content_bytes)
-                    else:
-                        content = content_bytes.decode()
-                except Exception:
-                    raise HTTPException(status_code=404, detail="Item not found")
-        format: Optional[str] = None
-        if path.is_dir():
-            size = None
-            type = "directory"
-            format = "json"
-            mimetype = None
-        elif path.is_file() or path.is_symlink():
-            size = get_file_size(path)
-            if path.suffix == ".ipynb":
-                type = "notebook"
-                format = None
-                mimetype = None
-                if content is not None:
-                    nb: dict
-                    if file_format == "json":
-                        content = cast(Dict, content)
-                        nb = content
-                    else:
-                        content = cast(str, content)
-                        nb = json.loads(content)
-                    for cell in nb["cells"]:
-                        if "metadata" not in cell:
-                            cell["metadata"] = {}
-                        cell["metadata"].update({"trusted": False})
-                        if cell["cell_type"] == "code":
-                            cell_source = cell["source"]
-                            if not isinstance(cell_source, str):
-                                cell["source"] = "".join(cell_source)
-                    if file_format != "json":
-                        content = json.dumps(nb)
-            elif path.suffix == ".json":
-                type = "json"
-                format = "text"
-                mimetype = "application/json"
-            else:
-                type = "file"
-                format = None
-                mimetype = "text/plain"
-        else:
-            raise HTTPException(status_code=404, detail="Item not found")
-
-        return Content(
-            **{
-                "name": path.name,
-                "path": path.as_posix(),
-                "last_modified": get_file_modification_time(path),
-                "created": get_file_creation_time(path),
-                "content": content,
-                "format": format,
-                "mimetype": mimetype,
-                "size": size,
-                "writable": is_file_writable(path),
-                "type": type,
-            }
-        )
+        async with self.file_lock(path):
+            if isinstance(path, str):
+                path = Path(path)
+            content: Optional[Union[str, Dict, List[Dict]]] = None
+            if get_content:
+                if path.is_dir():
+                    content = [
+                        (await self.read_content(subpath, get_content=False)).model_dump()
+                        for subpath in path.iterdir()
+                        if not subpath.name.startswith(".")
+                    ]
+                elif path.is_file() or path.is_symlink():
+                    try:
+                        async with await open_file(path, mode="rb") as f:
+                            content_bytes = await f.read()
+                        if file_format == "base64":
+                            content = base64.b64encode(content_bytes).decode("ascii")
+                        elif file_format == "json":
+                            content = json.loads(content_bytes)
+                        else:
+                            content = content_bytes.decode()
+                    except Exception:
+                        raise HTTPException(status_code=404, detail="Item not found")
+            format: Optional[str] = None
+            if path.is_dir():
+                size = None
+                type = "directory"
+                format = "json"
+                mimetype = None
+            elif path.is_file() or path.is_symlink():
+                size = get_file_size(path)
+                if path.suffix == ".ipynb":
+                    type = "notebook"
+                    format = None
+                    mimetype = None
+                    if content is not None:
+                        nb: dict
+                        if file_format == "json":
+                            content = cast(Dict, content)
+                            nb = content
+                        else:
+                            content = cast(str, content)
+                            nb = json.loads(content)
+                        for cell in nb["cells"]:
+                            if "metadata" not in cell:
+                                cell["metadata"] = {}
+                            cell["metadata"].update({"trusted": False})
+                            if cell["cell_type"] == "code":
+                                cell_source = cell["source"]
+                                if not isinstance(cell_source, str):
+                                    cell["source"] = "".join(cell_source)
+                        if file_format != "json":
+                            content = json.dumps(nb)
+                elif path.suffix == ".json":
+                    type = "json"
+                    format = "text"
+                    mimetype = "application/json"
+                else:
+                    type = "file"
+                    format = None
+                    mimetype = "text/plain"
+            else:
+                raise HTTPException(status_code=404, detail="Item not found")
+
+            return Content(
+                **{
+                    "name": path.name,
+                    "path": path.as_posix(),
+                    "last_modified": get_file_modification_time(path),
+                    "created": get_file_creation_time(path),
+                    "content": content,
+                    "format": format,
+                    "mimetype": mimetype,
+                    "size": size,
+                    "writable": is_file_writable(path),
+                    "type": type,
+                }
+            )
 
     async def write_content(self, content: Union[SaveContent, Dict]) -> None:
+        async with create_task_group() as tg:
+            with CancelScope(shield=True):
+                # writing can never be cancelled, otherwise it would corrupt the file
+                tg.start_soon(self._write_content, content)
+
+    async def _write_content(self, content: Union[SaveContent, Dict]) -> None:
         if not isinstance(content, SaveContent):
             content = SaveContent(**content)
-        if content.format == "base64":
-            async with await open_file(content.path, "wb") as f:
-                content.content = cast(str, content.content)
-                content_bytes = content.content.encode("ascii")
-                await f.write(content_bytes)
-        else:
-            async with await open_file(content.path, "wt") as f:
-                if content.format == "json":
-                    dict_content = cast(Dict, content.content)
-                    if content.type == "notebook":
-                        # see https://github.com/jupyterlab/jupyterlab/issues/11005
-                        if (
-                            "metadata" in dict_content
-                            and "orig_nbformat" in dict_content["metadata"]
-                        ):
-                            del dict_content["metadata"]["orig_nbformat"]
-                    await f.write(json.dumps(dict_content, indent=2))
-                else:
-                    content.content = cast(str, content.content)
-                    await f.write(content.content)
+        async with self.file_lock(content.path):
+            if content.format == "base64":
+                async with await open_file(content.path, "wb") as f:
+                    content.content = cast(str, content.content)
+                    content_bytes = content.content.encode("ascii")
+                    await f.write(content_bytes)
+            else:
+                async with await open_file(content.path, "wt") as f:
+                    if content.format == "json":
+                        dict_content = cast(Dict, content.content)
+                        if content.type == "notebook":
+                            # see https://github.com/jupyterlab/jupyterlab/issues/11005
+                            if (
+                                "metadata" in dict_content
+                                and "orig_nbformat" in dict_content["metadata"]
+                            ):
+                                del dict_content["metadata"]["orig_nbformat"]
+                        await f.write(json.dumps(dict_content, indent=2))
+                    else:
+                        content.content = cast(str, content.content)
+                        await f.write(content.content)

     @property
     def file_id_manager(self):
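
write_content now hands the actual write to _write_content behind a shielded cancel scope, so a cancelled request cannot abort a write halfway through and leave a corrupt file. A standalone sketch of the shielding idea, assuming only anyio; this is not the commit's exact task-group arrangement, and demo.txt is an illustrative file name:

    import anyio

    async def save(path: str, data: str) -> None:
        # shield=True: code in this scope keeps running even if an enclosing
        # scope (e.g. a disconnected HTTP request) has been cancelled.
        with anyio.CancelScope(shield=True):
            async with await anyio.open_file(path, "w") as f:
                await f.write(data)

    async def main() -> None:
        with anyio.move_on_after(0):  # cancels this scope at the first checkpoint
            await save("demo.txt", "written in full despite the cancellation")
        async with await anyio.open_file("demo.txt") as f:
            print(await f.read())  # the complete text, not a truncated file

    anyio.run(main)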
