diff --git a/CHANGELOG.md b/CHANGELOG.md index 61d1599498..b74a8cc93a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ * Now `pip install` uses `--no-cache` (\#1980). * API * Deprecate the `verbose` query parameter in `GET /api/v2/task/collect/{state_id}/` (\#1980). + * Add `project_dir` attribute to `UserSettings` (\#1990). + * Set a default for `DatasetV2.zarr_dir` (\#1990). * Combine the `args_schema_parallel` and `args_schema_non_parallel` query parameters in `GET /api/v2/task/` into a single parameter `args_schema` (\#1998). # 2.7.1 diff --git a/fractal_server/app/models/user_settings.py b/fractal_server/app/models/user_settings.py index acff62a8ca..c779d94df6 100644 --- a/fractal_server/app/models/user_settings.py +++ b/fractal_server/app/models/user_settings.py @@ -36,3 +36,4 @@ class UserSettings(SQLModel, table=True): ssh_jobs_dir: Optional[str] = None slurm_user: Optional[str] = None cache_dir: Optional[str] = None + project_dir: Optional[str] = None diff --git a/fractal_server/app/routes/api/v2/dataset.py b/fractal_server/app/routes/api/v2/dataset.py index fb10699842..d47d69bd90 100644 --- a/fractal_server/app/routes/api/v2/dataset.py +++ b/fractal_server/app/routes/api/v2/dataset.py @@ -22,6 +22,8 @@ from ._aux_functions import _get_submitted_jobs_statement from fractal_server.app.models import UserOAuth from fractal_server.app.routes.auth import current_active_user +from fractal_server.string_tools import sanitize_string +from fractal_server.urls import normalize_url router = APIRouter() @@ -40,14 +42,45 @@ async def create_dataset( """ Add new dataset to current project """ - await _get_project_check_owner( + project = await _get_project_check_owner( project_id=project_id, user_id=user.id, db=db ) - db_dataset = DatasetV2(project_id=project_id, **dataset.dict()) - db.add(db_dataset) - await db.commit() - await db.refresh(db_dataset) - await db.close() + + if dataset.zarr_dir is None: + + if user.settings.project_dir is None: + raise 
HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + "Both 'dataset.zarr_dir' and 'user.settings.project_dir' " + "are null" + ), + ) + + db_dataset = DatasetV2( + project_id=project_id, + zarr_dir="__PLACEHOLDER__", + **dataset.dict(exclude={"zarr_dir"}), + ) + db.add(db_dataset) + await db.commit() + await db.refresh(db_dataset) + path = ( + f"{user.settings.project_dir}/fractal/" + f"{project_id}_{sanitize_string(project.name)}/" + f"{db_dataset.id}_{sanitize_string(db_dataset.name)}" + ) + normalized_path = normalize_url(path) + db_dataset.zarr_dir = normalized_path + + db.add(db_dataset) + await db.commit() + await db.refresh(db_dataset) + else: + db_dataset = DatasetV2(project_id=project_id, **dataset.dict()) + db.add(db_dataset) + await db.commit() + await db.refresh(db_dataset) return db_dataset diff --git a/fractal_server/app/schemas/user_settings.py b/fractal_server/app/schemas/user_settings.py index edd3686bb2..0158cb7d1c 100644 --- a/fractal_server/app/schemas/user_settings.py +++ b/fractal_server/app/schemas/user_settings.py @@ -19,6 +19,10 @@ class UserSettingsRead(BaseModel): + """ + Schema reserved for superusers + """ + id: int ssh_host: Optional[str] = None ssh_username: Optional[str] = None @@ -28,6 +32,7 @@ class UserSettingsRead(BaseModel): slurm_user: Optional[str] = None slurm_accounts: list[str] cache_dir: Optional[str] = None + project_dir: Optional[str] = None class UserSettingsReadStrict(BaseModel): @@ -35,9 +40,14 @@ class UserSettingsReadStrict(BaseModel): slurm_accounts: list[str] cache_dir: Optional[str] = None ssh_username: Optional[str] = None + project_dir: Optional[str] = None class UserSettingsUpdate(BaseModel, extra=Extra.forbid): + """ + Schema reserved for superusers + """ + ssh_host: Optional[str] = None ssh_username: Optional[str] = None ssh_private_key_path: Optional[str] = None @@ -46,6 +56,7 @@ class UserSettingsUpdate(BaseModel, extra=Extra.forbid): slurm_user: Optional[str] = None 
slurm_accounts: Optional[list[StrictStr]] = None cache_dir: Optional[str] = None + project_dir: Optional[str] = None _ssh_host = validator("ssh_host", allow_reuse=True)( valstr("ssh_host", accept_none=True) @@ -83,6 +94,13 @@ def cache_dir_validator(cls, value): validate_cmd(value) return val_absolute_path("cache_dir")(value) + @validator("project_dir") + def project_dir_validator(cls, value): + if value is None: + return None + validate_cmd(value) + return val_absolute_path("project_dir")(value) + class UserSettingsUpdateStrict(BaseModel, extra=Extra.forbid): slurm_accounts: Optional[list[StrictStr]] = None diff --git a/fractal_server/app/schemas/v2/dataset.py b/fractal_server/app/schemas/v2/dataset.py index 3bb318446b..41678e8336 100644 --- a/fractal_server/app/schemas/v2/dataset.py +++ b/fractal_server/app/schemas/v2/dataset.py @@ -33,14 +33,16 @@ class DatasetCreateV2(BaseModel, extra=Extra.forbid): name: str - zarr_dir: str + zarr_dir: Optional[str] = None filters: Filters = Field(default_factory=Filters) # Validators @validator("zarr_dir") def normalize_zarr_dir(cls, v: str) -> str: - return normalize_url(v) + if v is not None: + return normalize_url(v) + return v _name = validator("name", allow_reuse=True)(valstr("name")) @@ -95,7 +97,7 @@ class DatasetImportV2(BaseModel, extra=Extra.forbid): name: str zarr_dir: str - images: list[SingleImage] = Field(default_factory=[]) + images: list[SingleImage] = Field(default_factory=list) filters: Filters = Field(default_factory=Filters) # Validators diff --git a/fractal_server/migrations/versions/19eca0dd47a9_user_settings_project_dir.py b/fractal_server/migrations/versions/19eca0dd47a9_user_settings_project_dir.py new file mode 100644 index 0000000000..652039a45f --- /dev/null +++ b/fractal_server/migrations/versions/19eca0dd47a9_user_settings_project_dir.py @@ -0,0 +1,39 @@ +"""user settings project dir + +Revision ID: 19eca0dd47a9 +Revises: 8e8f227a3e36 +Create Date: 2024-10-30 14:34:28.219355 + +""" +import 
sqlalchemy as sa +import sqlmodel +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "19eca0dd47a9" +down_revision = "8e8f227a3e36" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("user_settings", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "project_dir", + sqlmodel.sql.sqltypes.AutoString(), + nullable=True, + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("user_settings", schema=None) as batch_op: + batch_op.drop_column("project_dir") + + # ### end Alembic commands ### diff --git a/tests/no_version/test_api_user_groups.py b/tests/no_version/test_api_user_groups.py index b891d2d43f..8a8fb864a7 100644 --- a/tests/no_version/test_api_user_groups.py +++ b/tests/no_version/test_api_user_groups.py @@ -331,6 +331,7 @@ async def test_patch_user_settings_bulk( slurm_user="test01", slurm_accounts=[], cache_dir=None, + project_dir=None, ) == user.settings.dict(exclude={"id"}) # remove user4 from default user group @@ -351,6 +352,7 @@ async def test_patch_user_settings_bulk( # missing `slurm_user` slurm_accounts=["foo", "bar"], cache_dir="/tmp/cache", + project_dir="/foo", ) res = await registered_superuser_client.patch( f"{PREFIX}/group/{default_user_group.id}/user-settings/", json=patch @@ -373,4 +375,28 @@ async def test_patch_user_settings_bulk( slurm_user="test01", slurm_accounts=[], cache_dir=None, + project_dir=None, ) == user4.settings.dict(exclude={"id"}) + + res = await registered_superuser_client.patch( + f"{PREFIX}/group/{default_user_group.id}/user-settings/", + json=dict(project_dir="not/an/absolute/path"), + ) + assert res.status_code == 422 + + # `None` is a valid `project_dir` + res = await registered_superuser_client.patch( + f"{PREFIX}/group/{default_user_group.id}/user-settings/", 
+ json=dict(project_dir="/fancy/dir"), + ) + assert res.status_code == 200 + for user in [user1, user2, user3]: + await db.refresh(user) + assert user.settings.project_dir == "/fancy/dir" + res = await registered_superuser_client.patch( + f"{PREFIX}/group/{default_user_group.id}/user-settings/", + json=dict(project_dir=None), + ) + for user in [user1, user2, user3]: + await db.refresh(user) + assert user.settings.project_dir is None diff --git a/tests/no_version/test_db.py b/tests/no_version/test_db.py index 0382b832de..7262a9bb9d 100644 --- a/tests/no_version/test_db.py +++ b/tests/no_version/test_db.py @@ -1,5 +1,6 @@ import pytest from devtools import debug +from sqlmodel import delete from sqlmodel import select from fractal_server.app.db import DB @@ -114,3 +115,55 @@ def test_DB_class_sync(): assert DB._engine_sync assert DB._sync_session_maker + + +async def test_reusing_id(db): + """ + Tests different database behaviors with incremental IDs. + + https://github.com/fractal-analytics-platform/fractal-server/issues/1991 + """ + + num_users = 10 + + # Create `num_users` new users + user_list = [ + UserOAuth(email=f"{x}@y.z", hashed_password="xxx") + for x in range(num_users) + ] + + for user in user_list: + db.add(user) + await db.commit() + for user in user_list: + await db.refresh(user) + + # Extract list of IDs + user_id_list = [user.id for user in user_list] + + # Remove all users + await db.execute(delete(UserOAuth)) + res = await db.execute(select(UserOAuth)) + users = res.scalars().unique().all() + assert len(users) == 0 + + # Create `num_users + 10` new users + new_user_list = [ + UserOAuth(email=f"{x}@y.z", hashed_password="xxx") + for x in range(num_users + 10) + ] + for user in new_user_list: + db.add(user) + await db.commit() + for user in new_user_list: + await db.refresh(user) + + # Extract list of IDs + new_user_id_list = [user.id for user in new_user_list] + + if DB_ENGINE == "sqlite": + # Assert IDs lists are overlapping + assert not 
set(new_user_id_list).isdisjoint(set(user_id_list)) + else: + # Assert IDs lists are disjoint + assert set(new_user_id_list).isdisjoint(set(user_id_list)) diff --git a/tests/v2/01_schemas/test_schemas_dataset.py b/tests/v2/01_schemas/test_schemas_dataset.py index 379ff5c378..87f32df041 100644 --- a/tests/v2/01_schemas/test_schemas_dataset.py +++ b/tests/v2/01_schemas/test_schemas_dataset.py @@ -27,6 +27,8 @@ async def test_schemas_dataset_v2(): DatasetCreateV2( name="name", zarr_dir="/zarr", filters={"types": {"a": "b"}} ) + # Test zarr_dir=None is valid + DatasetCreateV2(name="name", zarr_dir=None) dataset_create = DatasetCreateV2( name="name", @@ -35,6 +37,9 @@ async def test_schemas_dataset_v2(): ) assert dataset_create.zarr_dir == normalize_url(dataset_create.zarr_dir) + with pytest.raises(ValidationError): + DatasetImportV2(name="name", zarr_dir=None) + dataset_import = DatasetImportV2( name="name", filters={"attributes": {"x": 10}}, diff --git a/tests/v2/03_api/test_api_dataset.py b/tests/v2/03_api/test_api_dataset.py index c49b418640..2e135391af 100644 --- a/tests/v2/03_api/test_api_dataset.py +++ b/tests/v2/03_api/test_api_dataset.py @@ -8,6 +8,8 @@ ) from fractal_server.app.schemas.v2 import JobStatusTypeV2 from fractal_server.images import SingleImage +from fractal_server.string_tools import sanitize_string +from fractal_server.urls import normalize_url PREFIX = "api/v2" @@ -194,33 +196,26 @@ assert len(ds["history"]) == 0 -async def test_post_dataset(client, MockCurrentUser): - async with MockCurrentUser(): - # CREATE A PROJECT - res = await client.post( - f"{PREFIX}/project/", - json=dict(name="test project"), - ) - assert res.status_code == 201 - project = res.json() - project_id = project["id"] +async def test_post_dataset(client, MockCurrentUser, project_factory_v2): + async with MockCurrentUser() as user: + prj = await project_factory_v2(user) # ADD DATASET payload = dict(name="new dataset", zarr_dir="/tmp/zarr") 
res = await client.post( - f"{PREFIX}/project/{project_id}/dataset/", + f"{PREFIX}/project/{prj.id}/dataset/", json=payload, ) debug(res.json()) assert res.status_code == 201 dataset = res.json() assert dataset["name"] == payload["name"] - assert dataset["project_id"] == project_id + assert dataset["project_id"] == prj.id # EDIT DATASET payload1 = dict(name="new dataset name") res = await client.patch( - f"{PREFIX}/project/{project_id}/dataset/{dataset['id']}/", + f"{PREFIX}/project/{prj.id}/dataset/{dataset['id']}/", json=payload1, ) patched_dataset = res.json() @@ -229,6 +224,28 @@ async def test_post_dataset(client, MockCurrentUser): for k, v in payload1.items(): assert patched_dataset[k] == v + # Test POST dataset without zarr_dir + async with MockCurrentUser( + user_settings_dict={"project_dir": "/some/dir"} + ) as user: + prj = await project_factory_v2(user) + res = await client.post( + f"{PREFIX}/project/{prj.id}/dataset/", json=dict(name="DSName") + ) + assert res.json()["zarr_dir"] == normalize_url( + f"{user.settings.project_dir}/fractal/" + f"{prj.id}_{sanitize_string(prj.name)}/" + f"{res.json()['id']}_{sanitize_string(res.json()['name'])}" + ) + assert res.status_code == 201 + + async with MockCurrentUser() as user: + prj = await project_factory_v2(user) + res = await client.post( + f"{PREFIX}/project/{prj.id}/dataset/", json=dict(name="DSName2") + ) + assert res.status_code == 422 + async def test_delete_dataset( client, MockCurrentUser, project_factory_v2, dataset_factory_v2