diff --git a/src/app/api/crud.py b/src/app/api/crud.py index be5514ce..a143fced 100644 --- a/src/app/api/crud.py +++ b/src/app/api/crud.py @@ -15,18 +15,18 @@ async def get(entry_id: int, table: Table) -> Dict[str, Any]: return await database.fetch_one(query=query) -async def fetch_all(table: Table, query_filters: Optional[List[Tuple[str, Any]]] = None) -> List[Dict[str, Any]]: +async def fetch_all(table: Table, query_filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: query = table.select() - if isinstance(query_filters, list): - for query_filter in query_filters: - query = query.where(getattr(table.c, query_filter[0]) == query_filter[1]) + if isinstance(query_filters, dict): + for query_filter_key, query_filter_value in query_filters.items(): + query = query.where(getattr(table.c, query_filter_key) == query_filter_value) return await database.fetch_all(query=query) -async def fetch_one(table: Table, query_filters: List[Tuple[str, Any]]) -> Dict[str, Any]: +async def fetch_one(table: Table, query_filters: Dict[str, Any]) -> Dict[str, Any]: query = table.select() - for query_filter in query_filters: - query = query.where(getattr(table.c, query_filter[0]) == query_filter[1]) + for query_filter_key, query_filter_value in query_filters.items(): + query = query.where(getattr(table.c, query_filter_key) == query_filter_value) return await database.fetch_one(query=query) diff --git a/src/app/api/deps.py b/src/app/api/deps.py index 6666e2af..f69e31c0 100644 --- a/src/app/api/deps.py +++ b/src/app/api/deps.py @@ -29,7 +29,7 @@ def unauthorized_exception(detail: str, authenticate_value: str) -> HTTPExceptio async def get_current_access(security_scopes: SecurityScopes, token: str = Depends(reusable_oauth2)) -> AccessRead: - """Dependency to use as fastapi.security.Security with scopes. + """ Dependency to use as fastapi.security.Security with scopes. >>> @app.get("/users/me") >>> async def read_users_me(current_user: User = Security(get_current_access, scopes=["me"])): @@ -63,8 +63,7 @@ async def get_current_access(security_scopes: SecurityScopes, token: str = Depen async def get_current_user(access: AccessRead = Depends(get_current_access)) -> UserRead: - user = await crud.fetch_one(users, [('access_id', access.id)]) - + user = await crud.fetch_one(users, {'access_id': access.id}) if user is None: raise HTTPException(status_code=400, detail="Permission denied") @@ -72,7 +71,7 @@ async def get_current_user(access: AccessRead = Depends(get_current_access)) -> async def get_current_device(access: AccessRead = Depends(get_current_access)) -> DeviceOut: - device = await crud.fetch_one(devices, [('access_id', access.id)]) + device = await crud.fetch_one(devices, {'access_id': access.id}) if device is None: raise HTTPException(status_code=400, detail="Permission denied") diff --git a/src/app/api/routes/accesses.py b/src/app/api/routes/accesses.py index d554a51d..cc463e6d 100644 --- a/src/app/api/routes/accesses.py +++ b/src/app/api/routes/accesses.py @@ -10,7 +10,7 @@ async def post_access(login: str, password: str, scopes: str) -> AccessRead: # Check that the login does not already exist - if await crud.fetch_one(accesses, [('login', login)]) is not None: + if await crud.fetch_one(accesses, {'login': login}) is not None: raise HTTPException( status_code=400, detail=f"An entry with login='{login}' already exists.", diff --git a/src/app/api/routes/devices.py b/src/app/api/routes/devices.py index 79519891..24da4724 100644 --- a/src/app/api/routes/devices.py +++ b/src/app/api/routes/devices.py @@ -40,7 +40,7 @@ async def delete_device(device_id: int = Path(..., gt=0), _=Security(get_current @router.get("/my-devices", response_model=List[DeviceOut]) async def fetch_my_devices(me: UserRead = Security(get_current_user, scopes=["admin", "me"])): - return await crud.fetch_all(devices, [("owner_id", me.id)]) + return await crud.fetch_all(devices, {"owner_id": me.id}) @router.put("/heartbeat", response_model=DeviceOut) @@ -57,7 +57,7 @@ async def update_device_location( user: UserRead = Security(get_current_user, scopes=["admin", "me"]) ): # Check that device is accessible to this user - entry = await crud.fetch_one(devices, [("id", device_id), ("owner_id", user.id)]) + entry = await crud.fetch_one(devices, {"id": device_id, "owner_id": user.id}) if entry is None: raise HTTPException( status_code=400, diff --git a/src/app/api/routes/login.py b/src/app/api/routes/login.py index cd0d170c..4a150a61 100644 --- a/src/app/api/routes/login.py +++ b/src/app/api/routes/login.py @@ -13,7 +13,7 @@ @router.post("/access-token", response_model=Token) async def create_access_token(form_data: OAuth2PasswordRequestForm = Depends()): # Verify credentials - entry = await crud.fetch_one(accesses, [('login', form_data.username)]) + entry = await crud.fetch_one(accesses, {'login': form_data.username}) if entry is None or not await security.verify_password(form_data.password, entry['hashed_password']): raise HTTPException(status_code=400, detail="Invalid credentials") # create access token using user user_id/user_scopes diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 62e21cf9..e96b77d1 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -28,14 +28,14 @@ async def mock_fetch_all(table, query_filters=None): return table response = [] for entry in table: - if all(entry[k] == v for k, v in query_filters): + if all(entry[k] == v for k, v in query_filters.items()): response.append(entry) return response async def mock_fetch_one(table, query_filters=None): for entry in table: - if all(entry[k] == v for k, v in query_filters): + if all(entry[k] == v for k, v in query_filters.items()): return entry return None diff --git a/src/tests/test_deps.py b/src/tests/test_deps.py new file mode 100644 index 00000000..3ffd86a4 --- /dev/null +++ b/src/tests/test_deps.py @@ -0,0 +1,50 @@ +import pytest + +from app.api import crud, deps +from app.api.schemas import AccessRead, UserRead, DeviceOut +from copy import deepcopy + +USER_TABLE = [ + {"id": 1, "login": "first_user", "access_id": 1, "created_at": "2020-10-13T08:18:45.447773"}, + {"id": 99, "login": "connected_user", "access_id": 2, "created_at": "2020-11-13T08:18:45.447773"}, +] + +DEVICE_TABLE = [ + {"id": 1, "login": "first_device", "owner_id": 1, "access_id": 1, "specs": "v0.1", "elevation": None, "lat": None, + "lon": None, "yaw": None, "pitch": None, "last_ping": None, "created_at": "2020-10-13T08:18:45.447773"}, + {"id": 2, "login": "second_device", "owner_id": 99, "access_id": 2, "specs": "v0.1", "elevation": None, "lat": None, + "lon": None, "yaw": None, "pitch": None, "last_ping": None, "created_at": "2020-10-13T08:18:45.447773"}, + {"id": 99, "login": "connected_device", "owner_id": 1, "access_id": 3, "specs": "raspberry", "elevation": None, + "lat": None, "lon": None, "yaw": None, "pitch": None, "last_ping": None, + "created_at": "2020-10-13T08:18:45.447773"}, +] + + +def _patch_session(monkeypatch, mock_user_table=None, mock_device_table=None): + # DB patching + if mock_user_table is not None: + monkeypatch.setattr(deps, "users", mock_user_table) + if mock_device_table is not None: + monkeypatch.setattr(deps, "devices", mock_device_table) + # Sterilize all DB interactions through CRUD override + monkeypatch.setattr(crud, "fetch_one", pytest.mock_fetch_one) + + +@pytest.mark.asyncio +async def test_get_current_user(test_app, monkeypatch): + + mock_user_table = deepcopy(USER_TABLE) + _patch_session(monkeypatch, mock_user_table, None) + + response = await deps.get_current_user(AccessRead(id=1, login="JohnDoe", scopes="me")) + assert response == UserRead(**mock_user_table[0]) + + +@pytest.mark.asyncio +async def test_get_current_device(test_app, monkeypatch): + + mock_device_table = deepcopy(DEVICE_TABLE) + _patch_session(monkeypatch, None, mock_device_table) + + response = await deps.get_current_device(AccessRead(id=1, login="JohnDoe", scopes="me")) + assert response == DeviceOut(**mock_device_table[0])