diff --git a/marimo/_server/api/endpoints/assets.py b/marimo/_server/api/endpoints/assets.py index 2da2f24b7b1..d5d1a65f599 100644 --- a/marimo/_server/api/endpoints/assets.py +++ b/marimo/_server/api/endpoints/assets.py @@ -2,7 +2,6 @@ from __future__ import annotations import mimetypes -import os import re from pathlib import Path from typing import TYPE_CHECKING @@ -33,7 +32,7 @@ router = APIRouter() # Root directory for static assets -root = os.path.realpath(str(import_files("marimo").joinpath("_static"))) +root = Path(import_files("marimo").joinpath("_static")).resolve() config = ( get_default_config_manager(current_path=None) @@ -41,12 +40,27 @@ .get("server", {}) ) +assets_dir = root / "assets" +follow_symlinks = config.get("follow_symlink", False) + +if not follow_symlinks and assets_dir.is_symlink(): + LOGGER.error( + "Assets directory is a symlink but follow_symlink=false.\n" + "To fix this:\n" + "1. Run 'marimo config show' to see your current config\n" + "2. Add 'follow_symlink = true' under the [server] section in your config\n" + "3. Restart marimo\n\n" + "Example config:\n" + "[server]\n" + "follow_symlink = true" + ) + try: router.mount( "/assets", app=StaticFiles( - directory=os.path.join(root, "assets"), - follow_symlink=config.get("follow_symlink", False), + directory=assets_dir, + follow_symlink=follow_symlinks, ), name="assets", ) @@ -60,15 +74,14 @@ @requires("read", redirect="auth:login_page") async def index(request: Request) -> HTMLResponse: app_state = AppState(request) - index_html = os.path.join(root, "index.html") + index_html = root / "index.html" file_key = ( app_state.query_params(FILE_QUERY_PARAM_KEY) or app_state.session_manager.file_router.get_unique_file_key() ) - with open(index_html, "r") as f: # noqa: ASYNC101 ASYNC230 - html = f.read() + html = index_html.read_text() if not file_key: # We don't know which file to use, so we need to render a homepage @@ -215,7 +228,7 @@ async def public_files_service_worker(request: Request) -> Response: async def serve_public_file(request: Request) -> Response: """Serve files from the notebook's directory under /public/""" app_state = AppState(request) - filepath = request.path_params["filepath"] + filepath = str(request.path_params["filepath"]) # Get notebook ID from header notebook_id = request.headers.get("X-Notebook-Id") if notebook_id: @@ -233,7 +246,7 @@ async def serve_public_file(request: Request) -> Response: except ValueError: return Response(status_code=403, content="Access denied") - if file_path.is_file() and not os.path.islink(str(file_path)): + if file_path.is_file() and not file_path.is_symlink(): return FileResponse(file_path) raise HTTPException(status_code=404, detail="File not found") @@ -242,8 +255,8 @@ async def serve_public_file(request: Request) -> Response: # Catch all for serving static files @router.get("/{path:path}") async def serve_static(request: Request) -> FileResponse: - path = request.path_params["path"] + path = str(request.path_params["path"]) if any(re.match(pattern, path) for pattern in STATIC_FILES): - return FileResponse(os.path.join(root, path)) + return FileResponse(root / path) raise HTTPException(status_code=404, detail="Not Found")