From 0f23e0b9ce8c63697a285fcc84fd2ec36e8334ec Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 12 Aug 2022 11:40:35 +0200 Subject: [PATCH 1/4] Implement /api/me, store permissions in database --- plugins/auth/fps_auth/db.py | 1 + plugins/auth/fps_auth/models.py | 10 ++++++++- plugins/auth/fps_auth/routes.py | 39 ++++++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/plugins/auth/fps_auth/db.py b/plugins/auth/fps_auth/db.py index b3f724eb..6e41565f 100644 --- a/plugins/auth/fps_auth/db.py +++ b/plugins/auth/fps_auth/db.py @@ -54,6 +54,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): avatar = Column(String(length=32), nullable=True) workspace = Column(Text(), default="{}", nullable=False) settings = Column(Text(), default="{}", nullable=False) + permissions = Column(Text(), default="{}", nullable=False) oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined") diff --git a/plugins/auth/fps_auth/models.py b/plugins/auth/fps_auth/models.py index 79c40e50..c078ac9d 100644 --- a/plugins/auth/fps_auth/models.py +++ b/plugins/auth/fps_auth/models.py @@ -1,5 +1,5 @@ import uuid -from typing import Optional +from typing import Dict, List, Optional from fastapi_users import schemas from pydantic import BaseModel @@ -13,6 +13,14 @@ class JupyterUser(BaseModel): avatar: Optional[str] = None workspace: str = "{}" settings: str = "{}" + permissions: str = "{}" + + +class Permissions(BaseModel): + __root__: Dict[str, List[str]] + + def items(self): + return self.__root__.items() class UserRead(schemas.BaseUser[uuid.UUID], JupyterUser): diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index 48e2f340..223b0b42 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -1,6 +1,9 @@ +import json +from typing import Dict, List from uuid import uuid4 from fastapi import APIRouter, Depends +from fastapi.exceptions import HTTPException from fps.config import get_config # type: ignore from fps.hooks import register_router # type: ignore from fps.logging import get_configured_logger # type: ignore @@ -16,7 +19,7 @@ ) from .config import get_auth_config from .db import Session, User, UserDb, create_db_and_tables, secret -from .models import UserCreate, UserRead, UserUpdate +from .models import Permissions, UserCreate, UserRead, UserUpdate logger = get_configured_logger("auth") @@ -50,6 +53,7 @@ async def startup(): is_verified=True, workspace="{}", settings="{}", + permissions="{}", ) await user_db.create(global_user) @@ -70,6 +74,39 @@ async def get_users(user: UserRead = Depends(current_user)): return [usr.User for usr in users if usr.User.is_active] +@router.get("/api/me") +async def get_api_me( + permissions, + user: UserRead = Depends(current_user), +): + try: + permissions_to_check = Permissions.parse_obj(json.loads(permissions)) + except BaseException: + raise HTTPException( + 400, + detail='permissions should be a JSON dict of {{"resource": ["action",]}}, ' + f"got {permissions}", + ) + + user_permissions = json.loads(user.permissions) + checked_permissions: Dict[str, List[str]] = {} + for resource, actions in permissions_to_check.items(): + user_resource_permissions = user_permissions.get(resource) + if user_resource_permissions is None: + continue + allowed = checked_permissions[resource] = [] + for action in actions: + if action in user_resource_permissions: + allowed.append(action) + + keys = ["email", "name", "avatar", "anonymous", "username", "color"] + identity = {k: getattr(user, k) for k in keys} + return { + "identity": identity, + "permissions": checked_permissions, + } + + # redefine GET /me because we want our current_user dependency # it is first defined in users_router and so it wins over the one in fapi_users.get_users_router users_router = APIRouter() From 5a39ddceb12f984f3a44b376ff027658cadadf9a Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 17 Aug 2022 11:05:54 +0200 Subject: [PATCH 2/4] Add authorization --- .github/workflows/main.yml | 13 +-- README.md | 12 +- binder/jupyter_notebook_config.py | 8 +- plugins/auth/fps_auth/backends.py | 118 +++++++++++++------- plugins/auth/fps_auth/config.py | 11 +- plugins/auth/fps_auth/db.py | 25 +++-- plugins/auth/fps_auth/fixtures.py | 66 ++++++++++- plugins/auth/fps_auth/models.py | 27 ++--- plugins/auth/fps_auth/routes.py | 74 +++++++----- plugins/contents/fps_contents/routes.py | 16 +-- plugins/jupyterlab/fps_jupyterlab/routes.py | 8 +- plugins/kernels/fps_kernels/routes.py | 66 ++++------- plugins/lab/fps_lab/routes.py | 16 +-- plugins/nbconvert/fps_nbconvert/routes.py | 2 +- plugins/retrolab/fps_retrolab/routes.py | 10 +- plugins/terminals/fps_terminals/routes.py | 37 ++---- plugins/yjs/fps_yjs/routes.py | 30 +---- pytest.ini | 3 + setup.cfg | 1 + tests/conftest.py | 50 +-------- tests/test_auth.py | 22 +++- 21 files changed, 327 insertions(+), 288 deletions(-) create mode 100644 pytest.ini diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6c75c2a0..a3830b17 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,19 +38,16 @@ jobs: - name: Install jupyverse run: | - pip install fps[uvicorn] - pip install . --no-deps + pip install ./plugins/jupyterlab + pip install ./plugins/login pip install ./plugins/auth pip install ./plugins/contents pip install ./plugins/kernels + pip install ./plugins/terminals + pip install ./plugins/lab pip install ./plugins/nbconvert pip install ./plugins/yjs - pip install ./plugins/lab - pip install ./plugins/jupyterlab - pip install "jupyter_ydoc >=0.1.16,<0.2.0" # FIXME: remove with next JupyterLab release - pip install "y-py >=0.5.4" - - pip install mypy pytest pytest-asyncio requests ipykernel + pip install .[test] - name: Check types run: | diff --git a/README.md b/README.md index 42f51c0f..0b745a26 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,6 @@ When switching e.g. from the JupyterLab to the RetroLab front-end, you need to Clone this repository and install the needed plugins: ```bash -pip install fps[uvicorn] -pip install -e . --no-deps pip install -e plugins/jupyterlab pip install -e plugins/login pip install -e plugins/auth @@ -51,9 +49,9 @@ pip install -e plugins/terminals pip install -e plugins/lab pip install -e plugins/nbconvert pip install -e plugins/yjs +pip install -e .[test] # if you want RetroLab instead of JupyterLab: -# pip install -e . --no-deps # pip install -e plugins/retrolab # ... ``` @@ -63,7 +61,7 @@ pip install -e plugins/yjs ## Without authentication ```bash -jupyverse --open-browser --authenticator.mode=noauth +jupyverse --open-browser --auth.mode=noauth ``` This will open a browser at 127.0.0.1:8000 by default, and load the JupyterLab front-end. @@ -72,7 +70,7 @@ You have full access to the API, without restriction. ## With token authentication ```bash -jupyverse --open-browser --authenticator.mode=token +jupyverse --open-browser --auth.mode=token ``` This is the default mode, and it corresponds to @@ -81,7 +79,7 @@ This is the default mode, and it corresponds to ## With user authentication ```bash -jupyverse --open-browser --authenticator.mode=user +jupyverse --open-browser --auth.mode=user ``` We provide a JupyterLab extension for authentication, that you can install with: @@ -96,7 +94,7 @@ You can currently authenticate as an anonymous user, or ## With collaborative editing ```bash -jupyverse --open-browser --authenticator.collaborative +jupyverse --open-browser --auth.collaborative ``` This is especially interesting if you are "user-authenticated", since your will appear as the diff --git a/binder/jupyter_notebook_config.py b/binder/jupyter_notebook_config.py index f2f916bc..626f1ec1 100644 --- a/binder/jupyter_notebook_config.py +++ b/binder/jupyter_notebook_config.py @@ -2,8 +2,8 @@ [ "jupyverse", "--no-open-browser", - "--authenticator.mode=noauth", - "--authenticator.collaborative", + "--auth.mode=noauth", + "--auth.collaborative", "--RetroLab.enabled=false", "--Lab.base_url={base_url}jupyverse-jlab/", "--port={port}", @@ -16,8 +16,8 @@ [ "jupyverse", "--no-open-browser", - "--authenticator.mode=noauth", - "--authenticator.collaborative", + "--auth.mode=noauth", + "--auth.collaborative", "--JupyterLab.enabled=false", "--Lab.base_url={base_url}jupyverse-rlab/", "--port={port}", diff --git a/plugins/auth/fps_auth/backends.py b/plugins/auth/fps_auth/backends.py index 7b2bc24c..b8c4d6ca 100644 --- a/plugins/auth/fps_auth/backends.py +++ b/plugins/auth/fps_auth/backends.py @@ -2,7 +2,7 @@ from typing import Generic, Optional import httpx -from fastapi import Depends, HTTPException, Response, status +from fastapi import Depends, HTTPException, Response, WebSocket, status from fastapi_users import ( # type: ignore BaseUserManager, FastAPIUsers, @@ -95,7 +95,7 @@ async def on_after_register(self, user: User, request: Optional[Request] = None) anonymous=False, username=r["login"], color=None, - avatar=r["avatar_url"], + avatar_url=r["avatar_url"], is_active=True, ), ) @@ -135,39 +135,81 @@ async def create_guest(user_db, auth_config): return await user_db.create(guest) -async def current_user( - response: Response, - token: Optional[str] = None, - user: User = Depends( - fapi_users.current_user(optional=True, get_enabled_backends=get_enabled_backends) - ), - user_db=Depends(get_user_db), - auth_config=Depends(get_auth_config), -): - active_user = user - - if auth_config.collaborative: - if not active_user and auth_config.mode == "noauth": - active_user = await create_guest(user_db, auth_config) - await cookie_authentication.login(get_jwt_strategy(), active_user, response) - - elif not active_user and auth_config.mode == "token": - global_user = await user_db.get_by_email(auth_config.global_email) - if global_user and global_user.hashed_password == token: - active_user = await create_guest(user_db, auth_config) - await cookie_authentication.login(get_jwt_strategy(), active_user, response) - else: - if auth_config.mode == "token": - global_user = await user_db.get_by_email(auth_config.global_email) - if global_user and global_user.hashed_password == token: - active_user = global_user - await cookie_authentication.login(get_jwt_strategy(), active_user, response) - - if active_user: - return active_user - - elif auth_config.login_url: - raise RedirectException(auth_config.login_url) - - else: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) +def current_user(resource: Optional[str] = None): + async def _( + request: Request, + response: Response, + token: Optional[str] = None, + user: Optional[User] = Depends( + fapi_users.current_user(optional=True, get_enabled_backends=get_enabled_backends) + ), + user_db=Depends(get_user_db), + auth_config=Depends(get_auth_config), + ): + if auth_config.mode == "user": + # "user" authentication: check authorization + if user and resource: + # check if allowed to access the resource + permissions = user.permissions.get(resource, []) + if request.method in ("GET", "HEAD"): + if "read" not in permissions: + user = None + elif request.method in ("POST", "PUT", "PATCH", "DELETE"): + if "write" not in permissions: + user = None + else: + # "noauth" or "token" authentication + if auth_config.collaborative: + if not user and auth_config.mode == "noauth": + user = await create_guest(user_db, auth_config) + await cookie_authentication.login(get_jwt_strategy(), user, response) + + elif not user and auth_config.mode == "token": + global_user = await user_db.get_by_email(auth_config.global_email) + if global_user and global_user.hashed_password == token: + user = await create_guest(user_db, auth_config) + await cookie_authentication.login(get_jwt_strategy(), user, response) + else: + if auth_config.mode == "token": + global_user = await user_db.get_by_email(auth_config.global_email) + if global_user and global_user.hashed_password == token: + user = global_user + await cookie_authentication.login(get_jwt_strategy(), user, response) + + if user: + return user + + elif auth_config.login_url: + raise RedirectException(auth_config.login_url) + + else: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + return _ + + +def websocket_for_current_user(resource: str): + async def _( + websocket: WebSocket, + auth_config=Depends(get_auth_config), + user_manager: UserManager = Depends(get_user_manager), + ) -> Optional[WebSocket]: + accept_websocket = False + if auth_config.mode == "noauth": + accept_websocket = True + elif "fastapiusersauth" in websocket._cookies: + token = websocket._cookies["fastapiusersauth"] + user = await get_jwt_strategy().read_token(token, user_manager) + if user: + if auth_config.mode == "user": + if "execute" in user.permissions.get(resource, []): + accept_websocket = True + else: + accept_websocket = True + if accept_websocket: + return websocket + else: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + return None + + return _ diff --git a/plugins/auth/fps_auth/config.py b/plugins/auth/fps_auth/config.py index fefd4fcc..1e42eedb 100644 --- a/plugins/auth/fps_auth/config.py +++ b/plugins/auth/fps_auth/config.py @@ -2,11 +2,11 @@ from uuid import uuid4 from fps.config import PluginModel, get_config # type: ignore -from fps.hooks import register_config, register_plugin_name # type: ignore -from pydantic import SecretStr +from fps.hooks import register_config # type: ignore +from pydantic import BaseSettings, SecretStr -class AuthConfig(PluginModel): +class AuthConfig(PluginModel, BaseSettings): client_id: str = "" client_secret: SecretStr = SecretStr("") redirect_uri: str = "" @@ -17,12 +17,15 @@ class AuthConfig(PluginModel): global_email: str = "guest@jupyter.com" cookie_secure: bool = False # FIXME: should default to True, and set to False for tests clear_users: bool = False + test: bool = False login_url: Optional[str] = None + class Config(PluginModel.Config): + env_prefix = "fps_auth_" + def get_auth_config(): return get_config(AuthConfig) c = register_config(AuthConfig) -n = register_plugin_name("authenticator") diff --git a/plugins/auth/fps_auth/db.py b/plugins/auth/fps_auth/db.py index 6e41565f..8e38aa90 100644 --- a/plugins/auth/fps_auth/db.py +++ b/plugins/auth/fps_auth/db.py @@ -9,7 +9,7 @@ SQLAlchemyUserDatabase, ) from fps.config import get_config # type: ignore -from sqlalchemy import Boolean, Column, String, Text # type: ignore +from sqlalchemy import JSON, Boolean, Column, String, Text # type: ignore from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine # type: ignore from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore from sqlalchemy.orm import relationship, sessionmaker # type: ignore @@ -20,8 +20,11 @@ jupyter_dir = Path.home() / ".local" / "share" / "jupyter" jupyter_dir.mkdir(parents=True, exist_ok=True) -secret_path = jupyter_dir / "jupyverse_secret" -userdb_path = jupyter_dir / "jupyverse_users.db" +name = "jupyverse" +if auth_config.test: + name += "_test" +secret_path = jupyter_dir / f"{name}_secret" +userdb_path = jupyter_dir / f"{name}_users.db" if auth_config.clear_users: if userdb_path.is_file(): @@ -30,11 +33,9 @@ secret_path.unlink() if not secret_path.is_file(): - with open(secret_path, "w") as f: - f.write(secrets.token_hex(32)) + secret_path.write_text(secrets.token_hex(32)) -with open(secret_path) as f: - secret = f.read() +secret = secret_path.read_text() DATABASE_URL = f"sqlite+aiosqlite:///{userdb_path}" @@ -48,13 +49,15 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base): anonymous = Column(Boolean, default=True, nullable=False) email = Column(String(length=32), nullable=False, unique=True) - username = Column(String(length=32), nullable=True, unique=True) - name = Column(String(length=32), nullable=True) + username = Column(String(length=32), nullable=False, unique=True) + name = Column(String(length=32), default="") + display_name = Column(String(length=32), default="") + initials = Column(String(length=8), nullable=True) color = Column(String(length=32), nullable=True) - avatar = Column(String(length=32), nullable=True) + avatar_url = Column(String(length=32), nullable=True) workspace = Column(Text(), default="{}", nullable=False) settings = Column(Text(), default="{}", nullable=False) - permissions = Column(Text(), default="{}", nullable=False) + permissions = Column(JSON, default={}, nullable=False) oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined") diff --git a/plugins/auth/fps_auth/fixtures.py b/plugins/auth/fps_auth/fixtures.py index 279d2aa0..d86733f3 100644 --- a/plugins/auth/fps_auth/fixtures.py +++ b/plugins/auth/fps_auth/fixtures.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + import pytest # type: ignore from fps_auth.config import AuthConfig, get_auth_config @@ -9,7 +11,7 @@ def auth_mode(): @pytest.fixture def auth_config(auth_mode): - yield AuthConfig.parse_obj({"mode": auth_mode}) + yield AuthConfig.parse_obj({"mode": auth_mode, "test": True}) @pytest.fixture @@ -18,3 +20,65 @@ async def override_get_config(): return auth_config app.dependency_overrides[get_auth_config] = override_get_config + + +@pytest.fixture() +def permissions(): + return {} + + +@pytest.fixture() +def authenticated_client(client, permissions): + # create a new user + username = uuid4().hex + # if logged in, log out + first_time = True + while True: + response = client.get("/api/me") + if response.status_code == 403: + break + assert first_time + response = client.post("/auth/logout") + assert response.status_code == 200 + first_time = False + + # register user + register_body = { + "email": username + "@example.com", + "password": username, + "username": username, + "permissions": permissions, + } + response = client.post("/auth/register", json=register_body) + # check that we cannot register if not logged in + assert response.status_code == 403 + # log in as admin + login_body = {"username": "admin@jupyter.com", "password": "jupyverse"} + response = client.post("/auth/login", data=login_body) + assert response.status_code == 200 + # register user + response = client.post("/auth/register", json=register_body) + assert response.status_code == 201 + + # FIXME: + # # log out + # response = client.post("/auth/logout") + # assert response.status_code == 200 + # # check that we can't get our identity, since we're not logged in + # response = client.get("/api/me") + # assert response.status_code == 403 + + # log in with registered user + login_body = {"username": username + "@example.com", "password": username} + response = client.post("/auth/login", data=login_body) + assert response.status_code == 200 + # we should now have a cookie + assert "fastapiusersauth" in client.cookies + # check our identity, since we're logged in + response = client.get("/api/me", json={"permissions": permissions}) + assert response.status_code == 200 + me = response.json() + assert me["identity"]["username"] == username + # check our permissions + assert me["permissions"] == permissions + yield client diff --git a/plugins/auth/fps_auth/models.py b/plugins/auth/fps_auth/models.py index c078ac9d..4354d008 100644 --- a/plugins/auth/fps_auth/models.py +++ b/plugins/auth/fps_auth/models.py @@ -5,33 +5,28 @@ from pydantic import BaseModel -class JupyterUser(BaseModel): +class Permissions(BaseModel): + permissions: Dict[str, List[str]] + + +class JupyterUser(Permissions): anonymous: bool = True username: str = "" - name: Optional[str] = None + name: str = "" + display_name: str = "" + initials: Optional[str] = None color: Optional[str] = None - avatar: Optional[str] = None + avatar_url: Optional[str] = None workspace: str = "{}" settings: str = "{}" - permissions: str = "{}" - - -class Permissions(BaseModel): - __root__: Dict[str, List[str]] - - def items(self): - return self.__root__.items() class UserRead(schemas.BaseUser[uuid.UUID], JupyterUser): pass -class UserCreate(schemas.BaseUserCreate): - anonymous: bool = True - username: Optional[str] = None - name: Optional[str] = None - color: Optional[str] = None +class UserCreate(schemas.BaseUserCreate, JupyterUser): + pass class UserUpdate(schemas.BaseUserUpdate, JupyterUser): diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index 223b0b42..04c52799 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -1,9 +1,8 @@ -import json -from typing import Dict, List +from typing import Dict, List, Optional from uuid import uuid4 from fastapi import APIRouter, Depends -from fastapi.exceptions import HTTPException +from fastapi_users.password import PasswordHelper from fps.config import get_config # type: ignore from fps.hooks import register_router # type: ignore from fps.logging import get_configured_logger # type: ignore @@ -36,11 +35,27 @@ async def startup(): async with UserDb() as user_db: - global_user = await user_db.get_by_email(auth_config.global_email) + if auth_config.test: + admin_user = await user_db.get_by_email("admin@jupyter.com") + if not admin_user: + admin_user = dict( + id=uuid4(), + anonymous=False, + email="admin@jupyter.com", + username="admin@jupyter.com", + hashed_password=PasswordHelper().hash("jupyverse"), + is_superuser=True, + is_active=True, + is_verified=True, + workspace="{}", + settings="{}", + permissions={"admin": ["read", "write"]}, + ) + await user_db.create(admin_user) + global_user = await user_db.get_by_email(auth_config.global_email) if global_user: await user_db.update(global_user, {"hashed_password": auth_config.token}) - else: global_user = dict( id=uuid4(), @@ -53,7 +68,7 @@ async def startup(): is_verified=True, workspace="{}", settings="{}", - permissions="{}", + permissions={}, ) await user_db.create(global_user) @@ -67,7 +82,7 @@ async def startup(): @router.get("/auth/users") -async def get_users(user: UserRead = Depends(current_user)): +async def get_users(user: UserRead = Depends(current_user("admin"))): async with Session() as session: statement = select(User) users = (await session.execute(statement)).unique().all() @@ -76,30 +91,23 @@ async def get_users(user: UserRead = Depends(current_user)): @router.get("/api/me") async def get_api_me( - permissions, - user: UserRead = Depends(current_user), + permissions_to_check: Optional[Permissions] = None, + user: UserRead = Depends(current_user()), ): - try: - permissions_to_check = Permissions.parse_obj(json.loads(permissions)) - except BaseException: - raise HTTPException( - 400, - detail='permissions should be a JSON dict of {{"resource": ["action",]}}, ' - f"got {permissions}", - ) - - user_permissions = json.loads(user.permissions) checked_permissions: Dict[str, List[str]] = {} - for resource, actions in permissions_to_check.items(): - user_resource_permissions = user_permissions.get(resource) - if user_resource_permissions is None: - continue - allowed = checked_permissions[resource] = [] - for action in actions: - if action in user_resource_permissions: - allowed.append(action) - - keys = ["email", "name", "avatar", "anonymous", "username", "color"] + if permissions_to_check is not None: + permissions = permissions_to_check.permissions + user_permissions = user.permissions + for resource, actions in permissions.items(): + user_resource_permissions = user_permissions.get(resource) + if user_resource_permissions is None: + continue + allowed = checked_permissions[resource] = [] + for action in actions: + if action in user_resource_permissions: + allowed.append(action) + + keys = ["username", "name", "display_name", "initials", "avatar_url", "color"] identity = {k: getattr(user, k) for k in keys} return { "identity": identity, @@ -113,7 +121,7 @@ async def get_api_me( @users_router.get("/me") -async def get_me(user: UserRead = Depends(current_user)): +async def get_me(user: UserRead = Depends(current_user("admin"))): return user @@ -121,7 +129,11 @@ async def get_me(user: UserRead = Depends(current_user)): # Cookie based auth login and logout r_cookie_auth = register_router(fapi_users.get_auth_router(cookie_authentication), prefix="/auth") -r_register = register_router(fapi_users.get_register_router(UserRead, UserCreate), prefix="/auth") +r_register = register_router( + fapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + dependencies=[Depends(current_user("admin"))], +) r_user = register_router(users_router, prefix="/auth/user") # GitHub OAuth register router diff --git a/plugins/contents/fps_contents/routes.py b/plugins/contents/fps_contents/routes.py index 2f9e6077..5b111129 100644 --- a/plugins/contents/fps_contents/routes.py +++ b/plugins/contents/fps_contents/routes.py @@ -22,7 +22,7 @@ "/api/contents/{path:path}/checkpoints", status_code=201, ) -async def create_checkpoint(path, user: UserRead = Depends(current_user)): +async def create_checkpoint(path, user: UserRead = Depends(current_user("contents"))): src_path = Path(path) dst_path = Path(".ipynb_checkpoints") / f"{src_path.stem}-checkpoint{src_path.suffix}" try: @@ -42,7 +42,7 @@ async def create_checkpoint(path, user: UserRead = Depends(current_user)): async def create_content( path: Optional[str], request: Request, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("contents")), ): create_content = CreateContent(**(await request.json())) content_path = Path(create_content.path) @@ -75,13 +75,13 @@ async def create_content( @router.get("/api/contents") async def get_root_content( content: int, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("contents")), ): return await read_content("", bool(content)) @router.get("/api/contents/{path:path}/checkpoints") -async def get_checkpoint(path, user: UserRead = Depends(current_user)): +async def get_checkpoint(path, user: UserRead = Depends(current_user("contents"))): src_path = Path(path) dst_path = Path(".ipynb_checkpoints") / f"{src_path.stem}-checkpoint{src_path.suffix}" if not dst_path.exists(): @@ -94,7 +94,7 @@ async def get_checkpoint(path, user: UserRead = Depends(current_user)): async def get_content( path: str, content: int = 0, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("contents")), ): return await read_content(path, bool(content)) @@ -104,7 +104,7 @@ async def save_content( path, request: Request, response: Response, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("contents")), ): content = SaveContent(**(await request.json())) try: @@ -120,7 +120,7 @@ async def save_content( ) async def delete_content( path, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("contents")), ): p = Path(path) if p.exists(): @@ -135,7 +135,7 @@ async def delete_content( async def rename_content( path, request: Request, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("contents")), ): rename_content = RenameContent(**(await request.json())) Path(path).rename(rename_content.path) diff --git a/plugins/jupyterlab/fps_jupyterlab/routes.py b/plugins/jupyterlab/fps_jupyterlab/routes.py index e1816b8f..a1aac188 100644 --- a/plugins/jupyterlab/fps_jupyterlab/routes.py +++ b/plugins/jupyterlab/fps_jupyterlab/routes.py @@ -37,7 +37,7 @@ @router.get("/lab") async def get_lab( - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): @@ -56,7 +56,7 @@ async def load_workspace( @router.get("/lab/api/workspaces/{name}") -async def get_workspace_data(user: UserRead = Depends(current_user)): +async def get_workspace_data(user: UserRead = Depends(current_user())): if user: return json.loads(user.workspace) return {} @@ -68,7 +68,7 @@ async def get_workspace_data(user: UserRead = Depends(current_user)): ) async def set_workspace( request: Request, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), user_db=Depends(get_user_db), ): await user_db.update(user, {"workspace": await request.body()}) @@ -78,7 +78,7 @@ async def set_workspace( @router.get("/lab/workspaces/{name}", response_class=HTMLResponse) async def get_workspace( name, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index 58c1f952..c4b23190 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -4,16 +4,11 @@ import uuid from http import HTTPStatus -from fastapi import APIRouter, Depends, Response, WebSocket, status +from fastapi import APIRouter, Depends, Response from fastapi.responses import FileResponse from fps.hooks import register_router # type: ignore -from fps_auth.backends import ( # type: ignore - UserManager, - current_user, - get_jwt_strategy, - get_user_manager, -) -from fps_auth.config import get_auth_config # type: ignore +from fps_auth.backends import current_user # type: ignore +from fps_auth.backends import websocket_for_current_user # type: ignore from fps_auth.models import UserRead # type: ignore from fps_lab.config import get_lab_config # type: ignore from fps_yjs.routes import YDocWebSocketHandler # type: ignore @@ -42,7 +37,7 @@ async def stop_kernels(): @router.get("/api/kernelspecs") async def get_kernelspecs( - lab_config=Depends(get_lab_config), user: UserRead = Depends(current_user) + lab_config=Depends(get_lab_config), user: UserRead = Depends(current_user("kernelspecs")) ): for path in (prefix_dir / "share" / "jupyter" / "kernels").glob("*/kernel.json"): with open(path) as f: @@ -61,13 +56,13 @@ async def get_kernelspecs( async def get_kernelspec( kernel_name, file_name, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), ): return FileResponse(prefix_dir / "share" / "jupyter" / "kernels" / kernel_name / file_name) @router.get("/api/kernels") -async def get_kernels(user: UserRead = Depends(current_user)): +async def get_kernels(user: UserRead = Depends(current_user("kernels"))): results = [] for kernel_id, kernel in kernels.items(): results.append( @@ -85,7 +80,7 @@ async def get_kernels(user: UserRead = Depends(current_user)): @router.delete("/api/sessions/{session_id}", status_code=204) async def delete_session( session_id: str, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("sessions")), ): kernel_id = sessions[session_id]["kernel"]["id"] kernel_server = kernels[kernel_id]["server"] @@ -98,7 +93,7 @@ async def delete_session( @router.patch("/api/sessions/{session_id}") async def rename_session( request: Request, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("sessions")), ): rename_session = await request.json() session_id = rename_session.pop("id") @@ -108,7 +103,7 @@ async def rename_session( @router.get("/api/sessions") -async def get_sessions(user: UserRead = Depends(current_user)): +async def get_sessions(user: UserRead = Depends(current_user("sessions"))): for session in sessions.values(): kernel_id = session["kernel"]["id"] kernel_server = kernels[kernel_id]["server"] @@ -124,7 +119,7 @@ async def get_sessions(user: UserRead = Depends(current_user)): ) async def create_session( request: Request, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("sessions")), ): create_session = await request.json() kernel_name = create_session["kernel"]["name"] @@ -158,7 +153,7 @@ async def create_session( @router.post("/api/kernels/{kernel_id}/restart") async def restart_kernel( kernel_id, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("kernels")), ): if kernel_id in kernels: kernel = kernels[kernel_id] @@ -177,7 +172,7 @@ async def restart_kernel( async def execute_cell( request: Request, kernel_id, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("kernels")), ): r = await request.json() execution = Execution(**r) @@ -206,7 +201,7 @@ async def execute_cell( @router.get("/api/kernels/{kernel_id}") async def get_kernel( kernel_id, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("kernels")), ): if kernel_id in kernels: kernel = kernels[kernel_id] @@ -222,33 +217,20 @@ async def get_kernel( @router.websocket("/api/kernels/{kernel_id}/channels") async def kernel_channels( - websocket: WebSocket, kernel_id, session_id, - auth_config=Depends(get_auth_config), - user_manager: UserManager = Depends(get_user_manager), + websocket=Depends(websocket_for_current_user("kernels")), ): - accept_websocket = False - if auth_config.mode == "noauth": - accept_websocket = True - elif "fastapiusersauth" in websocket._cookies: - token = websocket._cookies["fastapiusersauth"] - user = await get_jwt_strategy().read_token(token, user_manager) - if user: - accept_websocket = True - if accept_websocket: - subprotocol = ( - "v1.kernel.websocket.jupyter.org" - if "v1.kernel.websocket.jupyter.org" in websocket["subprotocols"] - else None - ) - await websocket.accept(subprotocol=subprotocol) - accepted_websocket = AcceptedWebSocket(websocket, subprotocol) - if kernel_id in kernels: - kernel_server = kernels[kernel_id]["server"] - await kernel_server.serve(accepted_websocket, session_id) - else: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + subprotocol = ( + "v1.kernel.websocket.jupyter.org" + if "v1.kernel.websocket.jupyter.org" in websocket["subprotocols"] + else None + ) + await websocket.accept(subprotocol=subprotocol) + accepted_websocket = AcceptedWebSocket(websocket, subprotocol) + if kernel_id in kernels: + kernel_server = kernels[kernel_id]["server"] + await kernel_server.serve(accepted_websocket, session_id) r = register_router(router) diff --git a/plugins/lab/fps_lab/routes.py b/plugins/lab/fps_lab/routes.py index 369a1f7a..4f0cec43 100644 --- a/plugins/lab/fps_lab/routes.py +++ b/plugins/lab/fps_lab/routes.py @@ -60,7 +60,7 @@ def init_router(router, redirect_after_root): async def get_root( response: Response, lab_config=Depends(get_lab_config), - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), ): # auto redirect response.status_code = status.HTTP_302_FOUND @@ -77,7 +77,7 @@ async def get_mathjax(rest_of_path): ) @router.get("/lab/api/listings/@jupyterlab/extensionmanager-extension/listings.json") - async def get_listings(user: UserRead = Depends(current_user)): + async def get_listings(user: UserRead = Depends(current_user())): return { "blocked_extensions_uris": [], "allowed_extensions_uris": [], @@ -88,12 +88,12 @@ async def get_listings(user: UserRead = Depends(current_user)): @router.get("/lab/api/translations/") async def get_translations_( lab_config=Depends(get_lab_config), - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), ): return RedirectResponse(f"{lab_config.base_url}lab/api/translations") @router.get("/lab/api/translations") - async def get_translations(user: UserRead = Depends(current_user)): + async def get_translations(user: UserRead = Depends(current_user())): locale = Locale.parse("en") data = { "en": { @@ -112,7 +112,7 @@ async def get_translations(user: UserRead = Depends(current_user)): @router.get("/lab/api/translations/{language}") async def get_translation( language, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), ): global LOCALE if language == "en": @@ -138,7 +138,7 @@ async def get_setting( name0, name1, name2, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), ): with open(jlab_dir / "static" / "package.json") as f: package = json.load(f) @@ -173,7 +173,7 @@ async def change_setting( request: Request, name0, name1, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), user_db=Depends(get_user_db), ): settings = json.loads(user.settings) @@ -182,7 +182,7 @@ async def change_setting( return Response(status_code=HTTPStatus.NO_CONTENT.value) @router.get("/lab/api/settings") - async def get_settings(user: UserRead = Depends(current_user)): + async def get_settings(user: UserRead = Depends(current_user())): with open(jlab_dir / "static" / "package.json") as f: package = json.load(f) if user: diff --git a/plugins/nbconvert/fps_nbconvert/routes.py b/plugins/nbconvert/fps_nbconvert/routes.py index 1bb787a4..cb640e64 100644 --- a/plugins/nbconvert/fps_nbconvert/routes.py +++ b/plugins/nbconvert/fps_nbconvert/routes.py @@ -24,7 +24,7 @@ async def get_nbconvert_document( format: str, path: str, download: bool, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("nbconvert")), ): exporter = nbconvert.exporters.get_exporter(format) if download: diff --git a/plugins/retrolab/fps_retrolab/routes.py b/plugins/retrolab/fps_retrolab/routes.py index 552f953d..06ec57d8 100644 --- a/plugins/retrolab/fps_retrolab/routes.py +++ b/plugins/retrolab/fps_retrolab/routes.py @@ -44,7 +44,7 @@ @router.get("/retro/tree", response_class=HTMLResponse) async def get_tree( - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): @@ -54,7 +54,7 @@ async def get_tree( @router.get("/retro/notebooks/{path:path}", response_class=HTMLResponse) async def get_notebook( path, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): @@ -64,7 +64,7 @@ async def get_notebook( @router.get("/retro/edit/{path:path}", response_class=HTMLResponse) async def edit_file( path, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): @@ -74,7 +74,7 @@ async def edit_file( @router.get("/retro/consoles/{path:path}", response_class=HTMLResponse) async def get_console( path, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): @@ -84,7 +84,7 @@ async def get_console( @router.get("/retro/terminals/{name}", response_class=HTMLResponse) async def get_terminal( name: str, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user()), lab_config=Depends(get_lab_config), auth_config=Depends(get_auth_config), ): diff --git a/plugins/terminals/fps_terminals/routes.py b/plugins/terminals/fps_terminals/routes.py index 2f8322e5..845305f8 100644 --- a/plugins/terminals/fps_terminals/routes.py +++ b/plugins/terminals/fps_terminals/routes.py @@ -3,11 +3,9 @@ from http import HTTPStatus from typing import Any, Dict -from fastapi import APIRouter, Depends, Response, WebSocket, status +from fastapi import APIRouter, Depends, Response from fps.hooks import register_router # type: ignore -from fps_auth.backends import current_user, get_jwt_strategy # type: ignore -from fps_auth.config import get_auth_config # type: ignore -from fps_auth.db import get_user_db # type: ignore +from fps_auth.backends import current_user, websocket_for_current_user # type: ignore from fps_auth.models import UserRead # type: ignore from .models import Terminal @@ -29,7 +27,7 @@ async def get_terminals(): @router.post("/api/terminals") async def create_terminal( - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("terminals")), ): name = str(len(TERMINALS) + 1) terminal = Terminal( @@ -46,7 +44,7 @@ async def create_terminal( @router.delete("/api/terminals/{name}", status_code=204) async def delete_terminal( name: str, - user: UserRead = Depends(current_user), + user: UserRead = Depends(current_user("terminals")), ): for websocket in TERMINALS[name]["server"].websockets: TERMINALS[name]["server"].quit(websocket) @@ -56,28 +54,15 @@ async def delete_terminal( @router.websocket("/terminals/websocket/{name}") async def terminal_websocket( - websocket: WebSocket, name, - auth_config=Depends(get_auth_config), - user_db=Depends(get_user_db), + websocket=Depends(websocket_for_current_user("terminals")), ): - accept_websocket = False - if auth_config.mode == "noauth": - accept_websocket = True - else: - cookie = websocket._cookies["fastapiusersauth"] - user = await get_jwt_strategy().read_token(cookie, user_db) - if user: - accept_websocket = True - if accept_websocket: - await websocket.accept() - await TERMINALS[name]["server"].serve(websocket) - if name in TERMINALS: - TERMINALS[name]["server"].quit(websocket) - if not TERMINALS[name]["server"].websockets: - del TERMINALS[name] - else: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + await websocket.accept() + await TERMINALS[name]["server"].serve(websocket) + if name in TERMINALS: + TERMINALS[name]["server"].quit(websocket) + if not TERMINALS[name]["server"].websockets: + del TERMINALS[name] r = register_router(router) diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 06137aeb..6fcb2de9 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -5,14 +5,9 @@ from typing import Optional, Set, Tuple import fastapi -from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status +from fastapi import APIRouter, Depends, WebSocketDisconnect from fps.hooks import register_router # type: ignore -from fps_auth.backends import ( # type: ignore - UserManager, - get_jwt_strategy, - get_user_manager, -) -from fps_auth.config import get_auth_config # type: ignore +from fps_auth.backends import websocket_for_current_user # type: ignore from fps_contents.routes import read_content, write_content # type: ignore from jupyter_ydoc import ydocs as YDOCS # type: ignore from ypy_websocket.websocket_server import WebsocketServer, YRoom # type: ignore @@ -44,25 +39,12 @@ def to_datetime(iso_date: str) -> datetime: @router.websocket("/api/yjs/{path:path}") async def websocket_endpoint( - websocket: WebSocket, path, - auth_config=Depends(get_auth_config), - user_manager: UserManager = Depends(get_user_manager), + websocket=Depends(websocket_for_current_user("yjs")), ): - accept_websocket = False - if auth_config.mode == "noauth": - accept_websocket = True - elif "fastapiusersauth" in websocket._cookies: - token = websocket._cookies["fastapiusersauth"] - user = await get_jwt_strategy().read_token(token, user_manager) - if user: - accept_websocket = True - if accept_websocket: - await websocket.accept() - socket = YDocWebSocketHandler(WebsocketAdapter(websocket, path), path) - await socket.serve() - else: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + await websocket.accept() + socket = YDocWebSocketHandler(WebsocketAdapter(websocket, path), path) + await socket.serve() class WebsocketAdapter: diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..8218a9e9 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +env = + FPS_AUTH_TEST=True diff --git a/setup.cfg b/setup.cfg index 76e99bd3..7dc6953f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ test = mypy pytest pytest-asyncio + pytest-env requests websockets ipykernel diff --git a/tests/conftest.py b/tests/conftest.py index cf61f0e3..ceb8de3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,6 @@ import subprocess import time from pathlib import Path -from uuid import uuid4 import pytest @@ -18,51 +17,6 @@ def cwd(): return Path(__file__).parent.parent -@pytest.fixture() -def authenticated_client(client): - # create a new user - username = uuid4().hex - # if logged in, log out - first_time = True - while True: - response = client.get("/auth/user/me") - if response.status_code == 401: - break - assert first_time - response = client.post("/auth/logout") - first_time = False - - # register user - register_body = { - "email": username + "@example.com", - "password": username, - "username": username, - } - response = client.post("/auth/register", json=register_body) - assert response.status_code == 201 - # check that we can't list users yet, since we're not logged in - response = client.get("/auth/users") - assert response.status_code == 401 - # log in with registered user - login_body = {"username": username + "@example.com", "password": username} - assert "fastapiusersauth" not in client.cookies - response = client.post("/auth/login", data=login_body) - assert response.status_code == 200 - # we should now have a cookie - assert "fastapiusersauth" in client.cookies - # check that we can list users now, since we are logged in - response = client.get("/auth/users") - assert response.status_code == 200 - users = response.json() - assert username in [user["username"] for user in users] - # who am I? - response = client.get("/auth/user/me") - assert response.status_code != 401 - me = response.json() - assert me["username"] == username - yield client - - def get_open_port(): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) @@ -79,8 +33,8 @@ def start_jupyverse(auth_mode, clear_users, cwd, capfd): command_list = [ "jupyverse", "--no-open-browser", - f"--authenticator.mode={auth_mode}", - "--authenticator.clear_users=" + str(clear_users).lower(), + f"--auth.mode={auth_mode}", + "--auth.clear_users=" + str(clear_users).lower(), f"--port={port}", ] print(" ".join(command_list)) diff --git a/tests/test_auth.py b/tests/test_auth.py index da0b9626..e1af73a9 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,7 @@ def test_root_auth(auth_mode, client): expected = 200 content_type = "text/html; charset=utf-8" elif auth_mode in ["token", "user"]: - expected = 401 + expected = 403 content_type = "application/json" assert response.status_code == expected @@ -42,8 +42,26 @@ def test_no_auth(client): def test_token_auth(client): # no token provided, should not work response = client.get("/") - assert response.status_code == 401 + assert response.status_code == 403 # token provided, should work auth_config = get_auth_config() response = client.get(f"/?token={auth_config.token}") assert response.status_code == 200 + + +@pytest.mark.parametrize("auth_mode", ("user",)) +@pytest.mark.parametrize( + "permissions", + ( + {}, + {"admin": ["read"], "foo": ["bar", "baz"]}, + ), +) +def test_permissions(authenticated_client, permissions): + response = authenticated_client.get("/auth/user/me") + if "admin" in permissions.keys(): + # we have the permissions + assert response.status_code == 200 + else: + # we don't have the permissions + assert response.status_code == 403 From db3df02d013045e000879be9b6748ea33e5d5f01 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Wed, 17 Aug 2022 19:00:31 +0200 Subject: [PATCH 3/4] Cleaned up programmatic user creation --- plugins/auth/fps_auth/backends.py | 2 +- plugins/auth/fps_auth/db.py | 6 +- plugins/auth/fps_auth/fixtures.py | 13 ++-- plugins/auth/fps_auth/routes.py | 100 ++++++++++++++++++------------ 4 files changed, 69 insertions(+), 52 deletions(-) diff --git a/plugins/auth/fps_auth/backends.py b/plugins/auth/fps_auth/backends.py index b8c4d6ca..fdfa07b8 100644 --- a/plugins/auth/fps_auth/backends.py +++ b/plugins/auth/fps_auth/backends.py @@ -101,7 +101,7 @@ async def on_after_register(self, user: User, request: Optional[Request] = None) ) -def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): +async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): yield UserManager(user_db) diff --git a/plugins/auth/fps_auth/db.py b/plugins/auth/fps_auth/db.py index 8e38aa90..1ab57851 100644 --- a/plugins/auth/fps_auth/db.py +++ b/plugins/auth/fps_auth/db.py @@ -62,7 +62,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): engine = create_async_engine(DATABASE_URL) -Session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) +async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) async def create_db_and_tables(): @@ -71,7 +71,7 @@ async def create_db_and_tables(): async def get_async_session() -> AsyncGenerator[AsyncSession, None]: - async with Session() as session: + async with async_session_maker() as session: yield session @@ -81,7 +81,7 @@ async def get_user_db(session: AsyncSession = Depends(get_async_session)): class UserDb: async def __aenter__(self): - self.session = Session() + self.session = async_session_maker() session = await self.session.__aenter__() return SQLAlchemyUserDatabase(session, User, OAuthAccount) diff --git a/plugins/auth/fps_auth/fixtures.py b/plugins/auth/fps_auth/fixtures.py index d86733f3..9ad982f9 100644 --- a/plugins/auth/fps_auth/fixtures.py +++ b/plugins/auth/fps_auth/fixtures.py @@ -60,13 +60,12 @@ def authenticated_client(client, permissions): response = client.post("/auth/register", json=register_body) assert response.status_code == 201 - # FIXME: - # # log out - # response = client.post("/auth/logout") - # assert response.status_code == 200 - # # check that we can't get our identity, since we're not logged in - # response = client.get("/api/me") - # assert response.status_code == 403 + # log out + response = client.post("/auth/logout") + assert response.status_code == 200 + # check that we can't get our identity, since we're not logged in + response = client.get("/api/me") + assert response.status_code == 403 # log in with registered user login_body = {"username": username + "@example.com", "password": username} diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index 04c52799..6a67e376 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -1,8 +1,8 @@ +import contextlib from typing import Dict, List, Optional -from uuid import uuid4 from fastapi import APIRouter, Depends -from fastapi_users.password import PasswordHelper +from fastapi_users.exceptions import UserAlreadyExists from fps.config import get_config # type: ignore from fps.hooks import register_router # type: ignore from fps.logging import get_configured_logger # type: ignore @@ -13,11 +13,20 @@ cookie_authentication, current_user, fapi_users, + get_user_manager, github_authentication, github_cookie_authentication, ) from .config import get_auth_config -from .db import Session, User, UserDb, create_db_and_tables, secret +from .db import ( + User, + UserDb, + async_session_maker, + create_db_and_tables, + get_async_session, + get_user_db, + secret, +) from .models import Permissions, UserCreate, UserRead, UserUpdate logger = get_configured_logger("auth") @@ -27,50 +36,59 @@ router = APIRouter() +get_async_session_context = contextlib.asynccontextmanager(get_async_session) +get_user_db_context = contextlib.asynccontextmanager(get_user_db) +get_user_manager_context = contextlib.asynccontextmanager(get_user_manager) + + +async def create_user( + username: str, + email: str, + password: str, + is_superuser: bool = False, + permissions: Dict[str, List[str]] = {}, +): + async with get_async_session_context() as session: + async with get_user_db_context(session) as user_db: + async with get_user_manager_context(user_db) as user_manager: + await user_manager.create( + UserCreate( + username=username, + email=email, + password=password, + is_superuser=is_superuser, + permissions=permissions, + ) + ) + + @router.on_event("startup") async def startup(): await create_db_and_tables() auth_config = get_auth_config() - async with UserDb() as user_db: - - if auth_config.test: - admin_user = await user_db.get_by_email("admin@jupyter.com") - if not admin_user: - admin_user = dict( - id=uuid4(), - anonymous=False, - email="admin@jupyter.com", - username="admin@jupyter.com", - hashed_password=PasswordHelper().hash("jupyverse"), - is_superuser=True, - is_active=True, - is_verified=True, - workspace="{}", - settings="{}", - permissions={"admin": ["read", "write"]}, - ) - await user_db.create(admin_user) - - global_user = await user_db.get_by_email(auth_config.global_email) - if global_user: - await user_db.update(global_user, {"hashed_password": auth_config.token}) - else: - global_user = dict( - id=uuid4(), - anonymous=True, - email=auth_config.global_email, - username=auth_config.global_email, - hashed_password=auth_config.token, - is_superuser=False, - is_active=False, - is_verified=True, - workspace="{}", - settings="{}", - permissions={}, + if auth_config.test: + try: + await create_user( + username="admin@jupyter.com", + email="admin@jupyter.com", + password="jupyverse", + permissions={"admin": ["read", "write"]}, ) - await user_db.create(global_user) + except UserAlreadyExists: + pass + + try: + await create_user( + username=auth_config.global_email, + email=auth_config.global_email, + password=auth_config.token, + ) + except UserAlreadyExists: + async with UserDb() as user_db: + global_user = await user_db.get_by_email(auth_config.global_email) + await user_db.update(global_user, {"hashed_password": auth_config.token}) if auth_config.mode == "token": logger.info("") @@ -83,7 +101,7 @@ async def startup(): @router.get("/auth/users") async def get_users(user: UserRead = Depends(current_user("admin"))): - async with Session() as session: + async with async_session_maker() as session: statement = select(User) users = (await session.execute(statement)).unique().all() return [usr.User for usr in users if usr.User.is_active] From 4d5c2cfab6847b4b9a3312b3c6850d19d946e277 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 18 Aug 2022 10:53:18 +0200 Subject: [PATCH 4/4] Store token in username --- plugins/auth/fps_auth/backends.py | 22 +++++++------- plugins/auth/fps_auth/db.py | 10 ------- plugins/auth/fps_auth/routes.py | 49 +++++++++++++++++-------------- 3 files changed, 38 insertions(+), 43 deletions(-) diff --git a/plugins/auth/fps_auth/backends.py b/plugins/auth/fps_auth/backends.py index fdfa07b8..a7f51d17 100644 --- a/plugins/auth/fps_auth/backends.py +++ b/plugins/auth/fps_auth/backends.py @@ -24,6 +24,7 @@ from .config import get_auth_config from .db import User, get_user_db, secret +from .models import UserCreate logger = get_configured_logger("auth") @@ -118,21 +119,20 @@ async def get_enabled_backends(auth_config=Depends(get_auth_config)): ) -async def create_guest(user_db, auth_config): +async def create_guest(user_manager, auth_config): # workspace and settings are copied from global user # but this is a new user - global_user = await user_db.get_by_email(auth_config.global_email) + global_user = await user_manager.get_by_email(auth_config.global_email) user_id = str(uuid.uuid4()) guest = dict( - id=user_id, anonymous=True, email=f"{user_id}@jupyter.com", username=f"{user_id}@jupyter.com", - hashed_password="", + password="", workspace=global_user.workspace, settings=global_user.settings, ) - return await user_db.create(guest) + return await user_manager.create(UserCreate(**guest)) def current_user(resource: Optional[str] = None): @@ -143,7 +143,7 @@ async def _( user: Optional[User] = Depends( fapi_users.current_user(optional=True, get_enabled_backends=get_enabled_backends) ), - user_db=Depends(get_user_db), + user_manager: UserManager = Depends(get_user_manager), auth_config=Depends(get_auth_config), ): if auth_config.mode == "user": @@ -161,18 +161,18 @@ async def _( # "noauth" or "token" authentication if auth_config.collaborative: if not user and auth_config.mode == "noauth": - user = await create_guest(user_db, auth_config) + user = await create_guest(user_manager, auth_config) await cookie_authentication.login(get_jwt_strategy(), user, response) elif not user and auth_config.mode == "token": - global_user = await user_db.get_by_email(auth_config.global_email) + global_user = await user_manager.get_by_email(auth_config.global_email) if global_user and global_user.hashed_password == token: - user = await create_guest(user_db, auth_config) + user = await create_guest(user_manager, auth_config) await cookie_authentication.login(get_jwt_strategy(), user, response) else: if auth_config.mode == "token": - global_user = await user_db.get_by_email(auth_config.global_email) - if global_user and global_user.hashed_password == token: + global_user = await user_manager.get_by_email(auth_config.global_email) + if global_user and global_user.username == token: user = global_user await cookie_authentication.login(get_jwt_strategy(), user, response) diff --git a/plugins/auth/fps_auth/db.py b/plugins/auth/fps_auth/db.py index 1ab57851..bd7b5865 100644 --- a/plugins/auth/fps_auth/db.py +++ b/plugins/auth/fps_auth/db.py @@ -77,13 +77,3 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_user_db(session: AsyncSession = Depends(get_async_session)): yield SQLAlchemyUserDatabase(session, User, OAuthAccount) - - -class UserDb: - async def __aenter__(self): - self.session = async_session_maker() - session = await self.session.__aenter__() - return SQLAlchemyUserDatabase(session, User, OAuthAccount) - - async def __aexit__(self, exc_type, exc_value, exc_tb): - return await self.session.__aexit__(exc_type, exc_value, exc_tb) diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index 6a67e376..f3b1d49a 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -20,7 +20,6 @@ from .config import get_auth_config from .db import ( User, - UserDb, async_session_maker, create_db_and_tables, get_async_session, @@ -41,25 +40,27 @@ get_user_manager_context = contextlib.asynccontextmanager(get_user_manager) -async def create_user( - username: str, - email: str, - password: str, - is_superuser: bool = False, - permissions: Dict[str, List[str]] = {}, -): +@contextlib.asynccontextmanager +async def _get_user_manager(): async with get_async_session_context() as session: async with get_user_db_context(session) as user_db: async with get_user_manager_context(user_db) as user_manager: - await user_manager.create( - UserCreate( - username=username, - email=email, - password=password, - is_superuser=is_superuser, - permissions=permissions, - ) - ) + yield user_manager + + +async def create_user(**kwargs): + async with _get_user_manager() as user_manager: + await user_manager.create(UserCreate(**kwargs)) + + +async def update_user(user, **kwargs): + async with _get_user_manager() as user_manager: + await user_manager.update(UserUpdate(**kwargs), user) + + +async def get_user_by_email(user_email): + async with _get_user_manager() as user_manager: + return await user_manager.get_by_email(user_email) @router.on_event("startup") @@ -81,14 +82,18 @@ async def startup(): try: await create_user( - username=auth_config.global_email, + username=auth_config.token, email=auth_config.global_email, - password=auth_config.token, + password="", + permissions={}, ) except UserAlreadyExists: - async with UserDb() as user_db: - global_user = await user_db.get_by_email(auth_config.global_email) - await user_db.update(global_user, {"hashed_password": auth_config.token}) + global_user = await get_user_by_email(auth_config.global_email) + await update_user( + global_user, + username=auth_config.token, + permissions={}, + ) if auth_config.mode == "token": logger.info("")