diff --git a/docs/developer/getting-started.rst b/docs/developer/getting-started.rst index 3c495dfd..c9bfad30 100644 --- a/docs/developer/getting-started.rst +++ b/docs/developer/getting-started.rst @@ -16,7 +16,7 @@ Development dependencies are specified in ``requirements/dev.txt`` and can be in Additionally, building the documentation requires `pandoc `_ which is not on PyPI and needs to be installed through other means. (E.g. with your OS package manager.) -If you want to run tests against a real backend or SSH server, you also need ``docker-compose``. +If you want to run tests against a real backend or SFTP server, you also need ``docker-compose``. See `Testing <./testing.rst>`_ for what this is good for and why. Install the package @@ -56,7 +56,7 @@ Running tests python -m pytest -n - Or to run tests against a real backend and SSH server (see setup above) + Or to run tests against a real backend and SFTP server (see setup above) .. code-block:: sh diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9bf94b52..6012dceb 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -30,7 +30,6 @@ File transfer transfer.link.LinkFileTransfer transfer.sftp.SFTPFileTransfer - transfer.ssh.SSHFileTransfer Auxiliary classes ~~~~~~~~~~~~~~~~~ diff --git a/docs/release-notes.rst b/docs/release-notes.rst index f1d1b242..aba3c878 100644 --- a/docs/release-notes.rst +++ b/docs/release-notes.rst @@ -48,6 +48,7 @@ Breaking changes ~~~~~~~~~~~~~~~~ * **Dropped support for Pydantic v1.** +* Removed deprecated ``transfer.ssh.SSHFileTransfer`` in favor of :class:`transfer.sftp.SFTPFileTransfer`. Bugfixes ~~~~~~~~ @@ -250,7 +251,7 @@ Breaking changes * A number of attributes of Dataset are now read only. * ``Dataset.new`` was removed, use the regular ``__init__`` method instead. * ``File.provide_locally`` was removed in favor of :meth:`Client.download_files`. -* ``ESSTestFileTransfer`` was renamed to :class:`transfer.ssh.SSHFileTransfer`. 
+* ``ESSTestFileTransfer`` was renamed to ``transfer.ssh.SSHFileTransfer``. Bugfixes ~~~~~~~~ diff --git a/docs/user-guide/installation.rst b/docs/user-guide/installation.rst index 180a113f..eed26c4f 100644 --- a/docs/user-guide/installation.rst +++ b/docs/user-guide/installation.rst @@ -7,9 +7,9 @@ Installation .. code-block:: sh - pip install scitacean[ssh] + pip install scitacean[sftp] - If :class:`scitacean.transfer.ssh.SSHFileTransfer` is not required, the ``ssh`` extra can be omitted. + If :class:`scitacean.transfer.sftp.SFTPFileTransfer` is not required, the ``sftp`` extra can be omitted. .. tab-item:: conda diff --git a/docs/user-guide/uploading.ipynb b/docs/user-guide/uploading.ipynb index 4df67aac..f336baeb 100644 --- a/docs/user-guide/uploading.ipynb +++ b/docs/user-guide/uploading.ipynb @@ -280,7 +280,7 @@ "\n", "All files associated with a dataset will be uploaded to the same folder.\n", "This folder may be at the path we specify when making the dataset, i.e. `dset.source_folder`.\n", - "However, the folder is ultimately determined by the file transfer (in this case `SSHFileTransfer`) and it may choose to override the `source_folder` that we set.\n", + "However, the folder is ultimately determined by the file transfer (in this case `SFTPFileTransfer`) and it may choose to override the `source_folder` that we set.\n", "In this example, since we don't tell the file transfer otherwise, it respects `dset.source_folder` and uploads the files to that location.\n", "See the [File transfer](../reference/index.rst#file-transfer) reference for information how to control this behavior.\n", "The reason for this is that facilities may have a specific structure on their file server and Scitacean's file transfers can be used to enforce that.\n", diff --git a/pyproject.toml b/pyproject.toml index 05ee53c3..27a65d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ dynamic = ["version"] "Source" = "https://github.com/SciCatProject/scitacean" 
[project.optional-dependencies] -ssh = ["fabric"] sftp = ["paramiko"] test = ["filelock", "hypothesis", "pyyaml"] diff --git a/src/scitacean/testing/ssh/__init__.py b/src/scitacean/testing/ssh/__init__.py deleted file mode 100644 index 7624220d..00000000 --- a/src/scitacean/testing/ssh/__init__.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2022 Scitacean contributors (https://github.com/SciCatProject/scitacean) -"""Helpers for running tests with an SSH server. - -This subpackage is primarily meant for testing -:class:`scitacean.testing.ssh.SSHFileTransfer`. -But it can also be used to test downstream code that uses the SSH file transfer. - -The `pytest `_ fixtures in this package manage an SSH server -running in a docker container on the local machine. -They, therefore, require docker to be installed and running. - -Use the :func:`scitacean.testing.ssh.fixtures.ssh_fileserver` -fixture to manage the server and use -:func:`scitacean.testing.ssh.fixtures.ssh_access` -to get all required access parameters. -See below for examples. - -Attention ---------- -The fixtures support `pytest-xdist `_ -but only if all workers run on the local machine (the default). - -It may still happen that tests fail due to the complexity of synchronizing start up -and shut down of the SSH server between workers. - -See Also --------- -`Testing <../../user-guide/testing.ipynb>`_ user guide. - -Examples --------- -In order to test the SSH file transfer directly, use the provided fixtures to -open a connection manually. -Here, requesting the ``require_ssh_fileserver`` fixture ensures that the server -is running during the test, or that the test gets skipped if SSH tests are disabled. -Passing the ``connect`` argument as shown ensures that the file transfer -connects to the test server with the correct parameters. - -.. 
code-block:: python - - from scitacean.transfer.ssh import SSHFileTransfer - - def test_ssh_upload( - ssh_access, - ssh_connect_with_username_password, - require_ssh_fileserver, - ssh_data_dir, - ): - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - ds = Dataset(...) - with ssh.connect_for_upload( - dataset=ds, - connect=ssh_connect_with_username_password - ) as connection: - # do upload - # assert that the file has been copied to ssh_data_dir - -Testing the SSH transfer together with a client requires some additional setup. -See ``test_client_with_ssh`` in -`ssh_test.py `_ -for an example. - -Implementation notes --------------------- -When the server fixture is first used, it initializes the server using these steps: - -1. Create a temporary directory with contents:: - - tmpdir - ├ docker-compose.yaml - ├ .env (specifies paths for docker volumes) - ├ counter (number of workers currently using the server) - ├ counter.lock (file lock) - └ data (storage of files) - └ seed (populated from scitacean/testing/ssh/ssh_server_seed) - -2. Start docker. -3. Make data writable by the user in docker. - This changes the ownership of data on the host to root (on some machines). - -The docker container and its volumes are removed at the end of the tests. -The fixture also tries to remove the temporary directory. -This can fail as the owner of its contents (in particular data) -may have been changed to root. -So cleanup can fail and leave the directory behind. - -Use the seed directory (``ssh_data_dir/"seed"``) to test downloads. -Corresponds to ``/data/seed`` on the server. - -Use the base data directory (``ssh_data_dir``) to test uploads. -Corresponds to ``/data`` on the server. - -The counter and counter.lock files are used to synchronize starting and stopping -of the docker container between processes. -This is required when ``pytest-xdist`` is used. -Otherwise, those files will not be present. 
-""" - -from ._pytest_helpers import add_pytest_option, skip_if_not_ssh, ssh_enabled -from ._ssh import ( - IgnorePolicy, - SSHAccess, - SSHUser, - configure, - local_access, - wait_until_ssh_server_is_live, -) - -__all__ = [ - "add_pytest_option", - "configure", - "local_access", - "ssh_enabled", - "skip_if_not_ssh", - "wait_until_ssh_server_is_live", - "IgnorePolicy", - "SSHAccess", - "SSHUser", -] diff --git a/src/scitacean/testing/ssh/_pytest_helpers.py b/src/scitacean/testing/ssh/_pytest_helpers.py deleted file mode 100644 index 8922e51e..00000000 --- a/src/scitacean/testing/ssh/_pytest_helpers.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) - -from typing import Optional - -import pytest - -_COMMAND_LINE_OPTION: Optional[str] = None - - -def add_pytest_option(parser: pytest.Parser, option: str = "--ssh-tests") -> None: - """Add a command-line option to pytest to toggle SSH tests. - - Parameters - ---------- - parser: - Pytest's command-line argument parser. - option: - Name of the command-line option. 
- """ - parser.addoption( - option, - action="store_true", - default=False, - help="Select whether to run tests with an SSH fileserver", - ) - global _COMMAND_LINE_OPTION - _COMMAND_LINE_OPTION = option - - -def skip_if_not_ssh(request: pytest.FixtureRequest) -> None: - """Mark the current test to be skipped if SSH tests are disabled.""" - if not ssh_enabled(request): - pytest.skip( - "Tests against an SSH file server are disabled, " - f"use {_COMMAND_LINE_OPTION} to enable them" - ) - - -def ssh_enabled(request: pytest.FixtureRequest) -> bool: - """Return True if SSH tests are enabled.""" - return _COMMAND_LINE_OPTION is not None and bool( - request.config.getoption(_COMMAND_LINE_OPTION) - ) diff --git a/src/scitacean/testing/ssh/_ssh.py b/src/scitacean/testing/ssh/_ssh.py deleted file mode 100644 index 3e6aa580..00000000 --- a/src/scitacean/testing/ssh/_ssh.py +++ /dev/null @@ -1,154 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) -import importlib.resources -import time -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Iterable, Tuple, Union - -import paramiko -import yaml - - -@dataclass -class SSHUser: - username: str - password: str - - -@dataclass -class SSHAccess: - host: str - port: int - user: SSHUser - - -def _read_yaml(filename: str) -> Any: - if hasattr(importlib.resources, "files"): - # Use new API added in Python 3.9 - return yaml.safe_load( - importlib.resources.files("scitacean.testing.ssh") - .joinpath(filename) - .read_text() - ) - # Old API, deprecated as of Python 3.11 - return yaml.safe_load( - importlib.resources.read_text("scitacean.testing.ssh", filename) - ) - - -@lru_cache(maxsize=1) -def _docker_compose_file() -> Dict[str, Any]: - return _read_yaml("docker-compose-ssh-server.yaml") # type: ignore[no-any-return] - - -def _seed_files() -> Iterable[Tuple[str, str]]: - if 
hasattr(importlib.resources, "files"): - # Use new API added in Python 3.9 - yield from ( - (file.name, file.read_text()) - for file in importlib.resources.files("scitacean.testing.ssh") - .joinpath("ssh_server_seed") - .iterdir() - ) - else: - # Old API, deprecated as of Python 3.11 - with importlib.resources.path( - "scitacean.testing.ssh", "ssh_server_seed" - ) as seed_dir: - for path in seed_dir.iterdir(): - yield path.name, path.read_text() - - -def local_access() -> SSHAccess: - config = _docker_compose_file() - service = config["services"]["scitacean-test-ssh-server"] - env = {k: v for k, v in map(lambda s: s.split("="), service["environment"])} - return SSHAccess( - host="localhost", - port=service["ports"][0].split(":")[0], - user=SSHUser( - username=env["USER_NAME"], - password=env["USER_PASSWORD"], - ), - ) - - -def _copy_seed(target_seed_dir: Path) -> None: - for name, content in _seed_files(): - target_seed_dir.joinpath(name).write_text(content) - - -def configure(target_dir: Union[Path, str]) -> Path: - """Generate a config file for docker compose and copy seed data.""" - target_dir = Path(target_dir) - target_seed_dir = target_dir / "data" / "seed" - target_seed_dir.mkdir(parents=True) - _copy_seed(target_seed_dir) - - config_target = target_dir / "docker-compose.yaml" - config_target.write_text(yaml.dump(_docker_compose_file())) - - target_dir.joinpath(".env").write_text( - f"""DATA_DIR={target_dir / 'data'} -SEED_DIR={target_seed_dir}""" - ) - - return config_target - - -def _can_connect(ssh_access: SSHAccess) -> bool: - try: - _make_client(ssh_access) - except paramiko.SSHException: - return False - return True - - -def wait_until_ssh_server_is_live( - ssh_access: SSHAccess, max_time: float, n_tries: int -) -> None: - # The container takes a while to be fully live. 
- for _ in range(n_tries): - if _can_connect(ssh_access): - return - time.sleep(max_time / n_tries) - if not _can_connect(ssh_access): - raise RuntimeError("Cannot connect to SSH server") - - -def cleanup_data_dir( - ssh_access: SSHAccess, ssh_connect_with_username_password: Any -) -> None: - # Delete all directories created by tests. - # These are owned by root on the host and cannot be deleted by Python's tempfile. - connection = ssh_connect_with_username_password( - host=ssh_access.host, port=ssh_access.port - ) - connection.run( - "find /data -not -path '/data' -not -path '/data/seed' | xargs rm -rf", - hide=True, - in_stream=False, - ) - - -def _make_client(ssh_access: SSHAccess) -> paramiko.SSHClient: - client = paramiko.SSHClient() - client.set_missing_host_key_policy(IgnorePolicy()) - client.connect( - hostname=ssh_access.host, - port=ssh_access.port, - username=ssh_access.user.username, - password=ssh_access.user.password, - allow_agent=False, - look_for_keys=False, - ) - return client - - -# Every time we create a container, it gets a new host key. -# So simply accept any host keys. 
-class IgnorePolicy(paramiko.MissingHostKeyPolicy): - def missing_host_key(self, client: Any, hostname: Any, key: Any) -> None: - return diff --git a/src/scitacean/testing/ssh/docker-compose-ssh-server.yaml b/src/scitacean/testing/ssh/docker-compose-ssh-server.yaml deleted file mode 100644 index 0ae83ba2..00000000 --- a/src/scitacean/testing/ssh/docker-compose-ssh-server.yaml +++ /dev/null @@ -1,24 +0,0 @@ -version: "2.1" -services: - scitacean-test-ssh-server: - image: linuxserver/openssh-server:latest - container_name: scitacean-test-ssh - hostname: scitacean-test-ssh-server - environment: - - PUID=1000 - - PGID=1000 - - TZ=CET # Not UTC on purpose to test timezone detection -# - PUBLIC_KEY=yourpublickey #optional -# - PUBLIC_KEY_FILE=/path/to/file #optional -# - PUBLIC_KEY_DIR=/path/to/directory/containing/_only_/pubkeys #optional -# - PUBLIC_KEY_URL=https://github.com/username.keys #optional - - SUDO_ACCESS=false - - PASSWORD_ACCESS=true - - USER_NAME=the-scitacean - - USER_PASSWORD=sup3r-str0ng - volumes: # configured in Python - - ${DATA_DIR}:/data - - ${SEED_DIR}:/data/seed - ports: - - "2222:2222" - restart: unless-stopped diff --git a/src/scitacean/testing/ssh/fixtures.py b/src/scitacean/testing/ssh/fixtures.py deleted file mode 100644 index b46cd1d1..00000000 --- a/src/scitacean/testing/ssh/fixtures.py +++ /dev/null @@ -1,226 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) -# mypy: disable-error-code="no-untyped-def" -"""Pytest fixtures to manage and access a local SSH server.""" - -import logging -from pathlib import Path -from typing import Callable, Generator, Optional - -import fabric -import fabric.config -import pytest - -from ..._internal import docker -from .._pytest_helpers import init_work_dir, root_tmp_dir -from ._pytest_helpers import skip_if_not_ssh, ssh_enabled -from ._ssh import ( - IgnorePolicy, - SSHAccess, - configure, - local_access, - 
wait_until_ssh_server_is_live, -) - - -@pytest.fixture(scope="session") -def ssh_access(request: pytest.FixtureRequest) -> SSHAccess: - """Fixture that returns SSH access parameters. - - Returns - ------- - : - A URL and user to connect to the testing SSH server. - The user has access to all initial files registered in the - database and permissions to create new files. - """ - skip_if_not_ssh(request) - return local_access() - - -@pytest.fixture(scope="session") -def ssh_base_dir( - request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory -) -> Optional[Path]: - """Fixture that returns the base working directory for the SSH server setup. - - Returns - ------- - : - A path to a directory on the host machine. - The directory gets populated by the - :func:`scitacean.testing.ssh.fixtures.ssh_fileserver` fixture. - It contains the docker configuration and volumes. - - Returns ``None`` if SSH tests are disabled - """ - if not ssh_enabled(request): - return None - return root_tmp_dir(request, tmp_path_factory) / "scitacean-ssh" - - -@pytest.fixture(scope="session") -def ssh_data_dir(ssh_base_dir: Optional[Path]) -> Optional[Path]: - """Fixture that returns the data directory for the SSH server setup. - - Returns - ------- - : - A path to a directory on the host machine. - The directory is mounted as ``/data`` on the server. - - Returns ``None`` if SSH tests are disabled - """ - if ssh_base_dir is None: - return None - return ssh_base_dir / "data" - - -@pytest.fixture() -def require_ssh_fileserver(request, ssh_fileserver) -> None: - """Fixture to declare that a test needs a local SSH server. - - Like :func:`scitacean.testing.ssh.ssh_fileserver` - but this skips the test if SSH tests are disabled. 
- """ - skip_if_not_ssh(request) - - -@pytest.fixture(scope="session") -def ssh_fileserver( - request: pytest.FixtureRequest, - ssh_access: SSHAccess, - ssh_base_dir: Optional[Path], - ssh_data_dir: Optional[Path], - ssh_connect_with_username_password, -) -> Generator[bool, None, None]: - """Fixture to declare that a test needs a local SSH server. - - If SSH tests are enabled, this fixture configures and starts an SSH server - in a docker container the first time a test requests it. - The server and container will be stopped and removed at the end of the test session. - - Does nothing if the SSH tests are disabled. - - Returns - ------- - : - True if SSH tests are enabled and False otherwise. - """ - if ssh_base_dir is None: - yield False - return - - target_dir, counter = init_work_dir(request, ssh_base_dir, name=None) - - try: - with counter.increment() as count: - if count == 1: - _ssh_docker_up(target_dir, ssh_access) - elif not _ssh_server_is_running(): - raise RuntimeError("Expected SSH server to be running") - yield True - finally: - with counter.decrement() as count: - if count == 0: - _ssh_docker_down(target_dir) - - -@pytest.fixture(scope="session") -def ssh_connection_config() -> fabric.config.Config: - """Fixture that returns the configuration for fabric.Connection for tests. - - Can be used to open SSH connections if ``SSHFileTransfer`` is not enough. - """ - config = fabric.config.Config() - config["load_ssh_configs"] = False - config["connect_kwargs"] = { - "allow_agent": False, - "look_for_keys": False, - } - return config - - -@pytest.fixture(scope="session") -def ssh_connect_with_username_password( - ssh_access: SSHAccess, ssh_connection_config: fabric.config.Config -) -> Callable[..., fabric.Connection]: - """Fixture that returns a function to create a connection to the testing SSH server. - - Uses username+password and rejects any other authentication attempt. 
- - Returns - ------- - : - A function to pass as the ``connect`` argument to - :meth:`scitacean.transfer.ssh.SSHFileTransfer.connect_for_download` - or :meth:`scitacean.transfer.ssh.SSHFileTransfer.connect_for_upload`. - - Examples - -------- - Explicitly connect to the - - .. code-block:: python - - def test_ssh(ssh_access, ssh_connect_with_username_password, ssh_fileserver): - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_download( - connect=ssh_connect_with_username_password - ) as connection: - # use connection - """ - - def connect(host: str, port: int, **kwargs): - if kwargs: - raise ValueError( - "ssh_connect_with_username_password must only be" - f" used without extra arguments. Got {kwargs=}" - ) - connection = fabric.Connection( - host=host, - port=port, - user=ssh_access.user.username, - config=ssh_connection_config, - connect_kwargs={ - "password": ssh_access.user.password, - **ssh_connection_config.connect_kwargs, - }, - ) - connection.client.set_missing_host_key_policy(IgnorePolicy()) - return connection - - return connect - - -def _ssh_docker_up(target_dir: Path, ssh_access: SSHAccess) -> None: - if _ssh_server_is_running(): - raise RuntimeError("SSH docker container is already running") - docker_compose_file = target_dir / "docker-compose.yaml" - log = logging.getLogger("scitacean.testing") - log.info("Starting docker container with SSH server from %s", docker_compose_file) - configure(target_dir) - docker.docker_compose_up(docker_compose_file) - log.info("Waiting for SSH docker to become accessible") - wait_until_ssh_server_is_live(ssh_access=ssh_access, max_time=20, n_tries=20) - log.info("Successfully connected to SSH server") - # Give the user write access. 
- docker.docker_compose_run( - docker_compose_file, "scitacean-test-ssh-server", "chown", "1000:1000", "/data" - ) - - -def _ssh_docker_down(target_dir: Path) -> None: - # Check if container is running because the fixture can call this function - # if there was an exception in _ssh_docker_up. - # In that case, there is nothing to tear down. - if _ssh_server_is_running(): - docker_compose_file = target_dir / "docker-compose.yaml" - log = logging.getLogger("scitacean.testing") - log.info( - "Stopping docker container with SSH server from %s", docker_compose_file - ) - docker.docker_compose_down(docker_compose_file) - - -def _ssh_server_is_running() -> bool: - return docker.container_is_running("scitacean-test-ssh") diff --git a/src/scitacean/testing/ssh/ssh_server_seed/table.csv b/src/scitacean/testing/ssh/ssh_server_seed/table.csv deleted file mode 100644 index 76891270..00000000 --- a/src/scitacean/testing/ssh/ssh_server_seed/table.csv +++ /dev/null @@ -1,2 +0,0 @@ -7,2 -5,2 diff --git a/src/scitacean/testing/ssh/ssh_server_seed/text.txt b/src/scitacean/testing/ssh/ssh_server_seed/text.txt deleted file mode 100644 index 499ea0ea..00000000 --- a/src/scitacean/testing/ssh/ssh_server_seed/text.txt +++ /dev/null @@ -1 +0,0 @@ -This is some text for testing. diff --git a/src/scitacean/transfer/ssh.py b/src/scitacean/transfer/ssh.py deleted file mode 100644 index 90701b0b..00000000 --- a/src/scitacean/transfer/ssh.py +++ /dev/null @@ -1,503 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean) -# ruff: noqa: D100, D101, D102, D103 - -import os -import warnings -from contextlib import contextmanager -from datetime import datetime, timedelta -from getpass import getpass -from pathlib import Path -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union - -from dateutil.tz import tzoffset - -# Note that invoke and paramiko are dependencies of fabric. 
-from fabric import Connection -from invoke.exceptions import UnexpectedExit -from paramiko import SFTPClient -from paramiko.ssh_exception import AuthenticationException, PasswordRequiredException - -from ..dataset import Dataset -from ..error import FileUploadError -from ..file import File -from ..filesystem import RemotePath -from ..logging import get_logger -from ..warning import VisibleDeprecationWarning -from .util import source_folder_for - -warnings.warn( - "SSHFileTransfer is deprecated and scheduled for removal in Scitacean v23.11.0." - "Use SFTPFileTransfer instead.", - VisibleDeprecationWarning, - stacklevel=0, -) - - -class SSHDownloadConnection: - def __init__(self, *, connection: Connection) -> None: - self._connection = connection - - def download_files(self, *, remote: List[RemotePath], local: List[Path]) -> None: - """Download files from the given remote path.""" - for r, l in zip(remote, local): - self.download_file(remote=r, local=l) - - def download_file(self, *, remote: RemotePath, local: Path) -> None: - get_logger().info( - "Downloading file %s from host %s to %s", - remote, - self._connection.host, - local, - ) - self._connection.get(remote=remote.posix, local=os.fspath(local)) - - -class SSHUploadConnection: - def __init__(self, *, connection: Connection, source_folder: RemotePath) -> None: - self._connection = connection - self._source_folder = source_folder - self._remote_timezone = self._get_remote_timezone() - - @property - def _sftp(self) -> SFTPClient: - return self._connection.sftp() # type: ignore[no-any-return] - - @property - def source_folder(self) -> RemotePath: - return self._source_folder - - def remote_path(self, filename: Union[str, RemotePath]) -> RemotePath: - return self.source_folder / filename - - def _make_source_folder(self) -> None: - try: - self._connection.run( - f"mkdir -p {self.source_folder.posix}", hide=True, in_stream=False - ) - except OSError as exc: - raise FileUploadError( - f"Failed to create source 
folder {self.source_folder}: {exc.args}" - ) from None - - def upload_files(self, *files: File) -> List[File]: - """Upload files to the remote folder.""" - self._make_source_folder() - uploaded = [] - try: - for file in files: - up, exc = self._upload_file(file) - uploaded.append(up) # need to add this file in order to revert it - if exc is not None: - raise exc - except Exception: - self.revert_upload(*uploaded) - raise - return uploaded - - def _upload_file(self, file: File) -> Tuple[File, Optional[Exception]]: - if file.local_path is None: - raise ValueError( - f"Cannot upload file to {file.remote_path}, " - "the file has no local path" - ) - remote_path = self.remote_path(file.remote_path) - get_logger().info( - "Uploading file %s to %s on host %s", - file.local_path, - remote_path, - self._connection.host, - ) - st = self._sftp.put( - remotepath=remote_path.posix, localpath=os.fspath(file.local_path) - ) - if (exc := self._validate_upload(file)) is not None: - return file, exc - creation_time = ( - datetime.fromtimestamp(st.st_mtime, tz=self._remote_timezone) - if st.st_mtime - else None - ) - return ( - file.uploaded( - remote_gid=str(st.st_gid), - remote_uid=str(st.st_uid), - remote_creation_time=creation_time, - remote_perm=str(st.st_mode), - remote_size=st.st_size, - ), - None, - ) - - def _validate_upload(self, file: File) -> Optional[Exception]: - if (checksum := self._compute_checksum(file)) is None: - return None - if checksum != file.checksum(): - return FileUploadError( - f"Upload of file {file.remote_path} failed: " - f"Checksum of uploaded file ({checksum}) does not " - f"match checksum of local file ({file.checksum()}) " - f"using algorithm {file.checksum_algorithm}" - ) - return None - - def _compute_checksum(self, file: File) -> Optional[str]: - if (hash_exe := _coreutils_checksum_for(file)) is None: - return None - try: - res = self._connection.run( - f"{hash_exe} {self.remote_path(file.remote_path).posix}", - hide=True, - in_stream=False, - ) 
- except UnexpectedExit as exc: - if exc.result.return_code == 127: - get_logger().warning( - "Cannot validate checksum of uploaded file %s because checksum " - "algorithm '%s' is not implemented on the server.", - file.remote_path, - file.checksum_algorithm, - ) - return None - raise - return res.stdout.split(" ", 1)[0] # type: ignore[no-any-return] - - def _get_remote_timezone(self) -> tzoffset: - cmd = 'date +"%Z|%::z"' - try: - tz_str = self._connection.run( - cmd, hide=True, in_stream=False - ).stdout.strip() - except UnexpectedExit as exc: - raise FileUploadError( - f"Failed to get timezone of fileserver: {exc.args}" - ) from None - tz = _parse_remote_timezone(tz_str) - get_logger().info("Detected timezone of fileserver: %s", tz) - return tz - - def revert_upload(self, *files: File) -> None: - """Remove uploaded files from the remote folder.""" - for file in files: - self._revert_upload_single(remote=file.remote_path, local=file.local_path) - - if _folder_is_empty(self._connection, self.source_folder): - try: - get_logger().info( - "Removing empty remote directory %s on host %s", - self.source_folder, - self._connection.host, - ) - self._sftp.rmdir(self.source_folder.posix) - except UnexpectedExit as exc: - get_logger().warning( - "Failed to remove empty remote directory %s on host:\n%s", - self.source_folder, - self._connection.host, - exc.result, - ) - - def _revert_upload_single( - self, *, remote: RemotePath, local: Optional[Path] - ) -> None: - remote_path = self.remote_path(remote) - get_logger().info( - "Reverting upload of file %s to %s on host %s", - local, - remote_path, - self._connection.host, - ) - - try: - self._sftp.remove(remote_path.posix) - except UnexpectedExit as exc: - get_logger().warning( - "Error reverting file %s:\n%s", remote_path, exc.result - ) - return - - -class SSHFileTransfer: - """Upload / download files using SSH. 
- - Configuration & Authentication - ------------------------------ - The file transfer connects to the server at the address given - as the ``host`` constructor argument. - This may be - - - a full url such as ``some.fileserver.edu``, - - an IP address like ``127.0.0.1``, - - or a host defined in the user's openSSH config file. - - The file transfer can authenticate using username+password. - It will ask for those on the command line. - However, it is **highly recommended to set up a key and use an SSH agent!** - This increases security as Scitacean no longer has to handle credentials itself. - And it is required for automated programs where a user cannot enter credentials - on a command line. - - Upload folder - ------------- - The file transfer can take an optional ``source_folder`` as a constructor argument. - If it is given, ``SSHFileTransfer`` uploads all files to it and ignores the - source folder set in the dataset. - If it is not given, ``SSHFileTransfer`` uses the dataset's source folder. - - The source folder argument to ``SSHFileTransfer`` may be a Python format string. - In that case, all format fields are replaced by the corresponding fields - of the dataset. - All non-ASCII characters and most special ASCII characters are replaced. - This should avoid broken paths from essentially random contents in datasets. - - Examples - -------- - Given - - .. code-block:: python - - dset = Dataset( - type="raw", - name="my-dataset", - source_folder="/dataset/source", - ) - - This uploads to ``/dataset/source``: - - .. code-block:: python - - file_transfer = SSHFileTransfer(host="fileserver") - - This uploads to ``/transfer/folder``: - - .. code-block:: python - - file_transfer = SSHFileTransfer(host="fileserver", - source_folder="transfer/folder") - - This uploads to ``/transfer/my-dataset``: - (Note that ``{name}`` is replaced by ``dset.name``.) - - .. 
code-block:: python - - file_transfer = SSHFileTransfer(host="fileserver", - source_folder="transfer/{name}") - - A useful approach is to include a unique ID in the source folder, for example - ``"/some/base/folder/{uid}"``, to avoid clashes between different datasets. - Scitacean will fill in the ``"{uid}"`` placeholder with a new UUID4. - - .. deprecated:: 23.08.0 - """ - - def __init__( - self, - *, - host: str, - port: Optional[int] = None, - source_folder: Optional[Union[str, RemotePath]] = None, - ) -> None: - """Construct a new SSH file transfer. - - Parameters - ---------- - host: - URL or name of the server to connect to. - port: - Port of the server. - source_folder: - Upload files to this folder if set. - Otherwise, upload to the dataset's source_folder. - Ignored when downloading files. - """ - self._host = host - self._port = port - self._source_folder_pattern = ( - RemotePath(source_folder) - if isinstance(source_folder, str) - else source_folder - ) - - def source_folder_for(self, dataset: Dataset) -> RemotePath: - """Return the source folder used for the given dataset.""" - return source_folder_for(dataset, self._source_folder_pattern) - - @contextmanager - def connect_for_download( - self, connect: Optional[Callable[..., Connection]] = None - ) -> Iterator[SSHDownloadConnection]: - """Create a connection for downloads, use as a context manager. - - Parameters - ---------- - connect: - A function that creates and returns a :class:`fabric.connection.Connection` - object. - Will first be called with only ``host`` and ``port``. - If this fails (by raising - :class:`paramiko.ssh_exception.AuthenticationException`), the function is - called with ``host``, ``port``, and, optionally, ``user`` and - ``connection_kwargs`` depending on the authentication method. - Raising :class:`paramiko.ssh_exception.AuthenticationException` in the 2nd - call or any other exception in the 1st signals failure of - ``connect_for_download``. 
- """ - con = _connect(self._host, self._port, connect=connect) - try: - yield SSHDownloadConnection(connection=con) - finally: - con.close() - - @contextmanager - def connect_for_upload( - self, dataset: Dataset, connect: Optional[Callable[..., Connection]] = None - ) -> Iterator[SSHUploadConnection]: - """Create a connection for uploads, use as a context manager. - - Parameters - ---------- - dataset: - The connection will be used to upload files of this dataset. - Used to determine the target folder. - connect: - A function that creates and returns a :class:`fabric.connection.Connection` - object. - Will first be called with only ``host`` and ``port``. - If this fails (by raising - :class:`paramiko.ssh_exception.AuthenticationException`), the function is - called with ``host``, ``port``, and, optionally, ``user`` and - ``connection_kwargs`` depending on the authentication method. - Raising :class:`paramiko.ssh_exception.AuthenticationException` in the 2nd - call or any other exception in the 1st signals failure of - ``connect_for_upload``. 
- """ - source_folder = self.source_folder_for(dataset) - con = _connect(self._host, self._port, connect=connect) - try: - yield SSHUploadConnection( - connection=con, - source_folder=source_folder, - ) - finally: - con.close() - - -def _ask_for_key_passphrase() -> str: - return getpass("The private key is encrypted, enter passphrase: ") - - -def _ask_for_credentials(host: str) -> Tuple[str, str]: - print(f"You need to authenticate to access {host}") # noqa: T201 - username = input("Username: ") - password = getpass("Password: ") - return username, password - - -def _generic_connect( - host: str, - port: Optional[int], - connect: Optional[Callable[..., Connection]], - **kwargs: Any, -) -> Connection: - if connect is None: - con = Connection(host=host, port=port, **kwargs) - else: - con = connect(host=host, port=port, **kwargs) - con.open() - return con - - -def _unauthenticated_connect( - host: str, port: Optional[int], connect: Optional[Callable[..., Connection]] -) -> Connection: - return _generic_connect(host=host, port=port, connect=connect) - - -def _authenticated_connect( - host: str, - port: Optional[int], - connect: Optional[Callable[..., Connection]], - exc: AuthenticationException, -) -> Connection: - # TODO fail fast if output going to file - if isinstance(exc, PasswordRequiredException) and "encrypted" in exc.args[0]: - # TODO does not work anymore, exception is always AuthenticationException - return _generic_connect( - host=host, - port=port, - connect=connect, - connect_kwargs={"passphrase": _ask_for_key_passphrase()}, - ) - else: - username, password = _ask_for_credentials(host) - return _generic_connect( - host=host, - port=port, - connect=connect, - user=username, - connect_kwargs={"password": password}, - ) - - -def _connect( - host: str, port: Optional[int], connect: Optional[Callable[..., Connection]] -) -> Connection: - try: - try: - return _unauthenticated_connect(host, port, connect) - except AuthenticationException as exc: - return 
_authenticated_connect(host, port, connect, exc) - except Exception as exc: - # We pass secrets as arguments to functions called in this block, and those - # can be leaked through exception handlers. So catch all exceptions - # and strip the backtrace up to this point to hide those secrets. - raise type(exc)(exc.args) from None - except BaseException as exc: - raise type(exc)(exc.args) from None - - -def _folder_is_empty(con: Connection, path: RemotePath) -> bool: - try: - ls: str = con.run(f"ls {path.posix}", hide=True, in_stream=False).stdout - return ls == "" - except UnexpectedExit: - return False # no further processing is needed in this case - - -def _coreutils_checksum_for(file: File) -> Optional[str]: - # blake2s is not supported because `b2sum -l 256` produces a different digest - # and I don't know why. - supported = { - "md5": "md5sum -b", - "sha256": "sha256sum -b", - "sha384": "sha384sum -b", - "sha512": "sha512sum -b", - "blake2b": "b2sum -l 512 -b", - } - algorithm = file.checksum_algorithm - if algorithm == "blake2s" or algorithm not in supported: - get_logger().warning( - "Cannot validate checksum of uploaded file %s because checksum algorithm " - "'%s' is not supported by scitacean for remote files.", - file.remote_path, - file.checksum_algorithm, - ) - return None - return supported[algorithm] - - -# Using `date +"%Z"` returns a timezone abbreviation like CET or EST. -# dateutil.tz.gettz can parse this abbreviation and return a tzinfo object. -# However, on Windows, it returns `None` if the string refers to the local timezone. -# This is indistinguishable from an unrecognised timezone, -# where gettz also returns `None`. -# To avoid this, use an explicit offset obtained from `date +"%::z"`. -# The timezone name is only used for debugging and not interpreted by -# dateutil or datetime. 
-def _parse_remote_timezone(tz_str: str) -> tzoffset: - # tz_str is expected to be of the form - # |:: - # as produced by `date +"%Z|%::z"` - name, offset = tz_str.split("|") - hours, minutes, seconds = map(int, offset.split(":")) - return tzoffset(name, timedelta(hours=hours, minutes=minutes, seconds=seconds)) diff --git a/tests/conftest.py b/tests/conftest.py index 31c6e91b..31493bb1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,12 +6,10 @@ from scitacean.testing.backend import add_pytest_option as add_backend_option from scitacean.testing.sftp import add_pytest_option as add_sftp_option -from scitacean.testing.ssh import add_pytest_option as add_ssh_option pytest_plugins = ( "scitacean.testing.backend.fixtures", "scitacean.testing.sftp.fixtures", - "scitacean.testing.ssh.fixtures", ) @@ -30,4 +28,3 @@ def pytest_addoption(parser: pytest.Parser) -> None: add_backend_option(parser) add_sftp_option(parser) - add_ssh_option(parser) diff --git a/tests/transfer/ssh_test.py b/tests/transfer/ssh_test.py deleted file mode 100644 index 07b00d33..00000000 --- a/tests/transfer/ssh_test.py +++ /dev/null @@ -1,402 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2022 Scitacean contributors (https://github.com/SciCatProject/scitacean) -# mypy: disable-error-code="no-untyped-def, return-value, arg-type, union-attr" - -import dataclasses -import tempfile -from contextlib import contextmanager -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import Callable, Iterator, Optional - -import fabric -import paramiko -import pytest -from fabric import Connection - -from scitacean import Dataset, File, FileUploadError, RemotePath -from scitacean.testing.client import FakeClient -from scitacean.testing.ssh import IgnorePolicy, skip_if_not_ssh -from scitacean.transfer.ssh import ( - SSHDownloadConnection, - SSHFileTransfer, - SSHUploadConnection, -) - - -@pytest.fixture(scope="session", autouse=True) -def 
server(request, ssh_fileserver): - skip_if_not_ssh(request) - - -def test_download_one_file(ssh_access, ssh_connect_with_username_password, tmp_path): - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_download(connect=ssh_connect_with_username_password) as con: - con.download_files( - remote=[RemotePath("/data/seed/text.txt")], local=[tmp_path / "text.txt"] - ) - with open(tmp_path / "text.txt", "r") as f: - assert f.read() == "This is some text for testing.\n" - - -def test_download_two_files(ssh_access, ssh_connect_with_username_password, tmp_path): - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_download(connect=ssh_connect_with_username_password) as con: - con.download_files( - remote=[ - RemotePath("/data/seed/table.csv"), - RemotePath("/data/seed/text.txt"), - ], - local=[tmp_path / "local-table.csv", tmp_path / "text.txt"], - ) - with open(tmp_path / "local-table.csv", "r") as f: - assert f.read() == "7,2\n5,2\n" - with open(tmp_path / "text.txt", "r") as f: - assert f.read() == "This is some text for testing.\n" - - -def test_upload_one_file_source_folder_in_dataset( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/upload")) - with open(tmp_path / "file0.txt", "w") as f: - f.write("File to test upload123") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload( - dataset=ds, connect=ssh_connect_with_username_password - ) as con: - assert con.source_folder == RemotePath("/data/upload") - con.upload_files( - File.from_local(path=tmp_path / "file0.txt", remote_path="upload_0.txt") - ) - - with open(ssh_data_dir / "upload" / "upload_0.txt", "r") as f: - assert f.read() == "File to test upload123" - - -def test_upload_one_file_source_folder_in_transfer( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = 
Dataset(type="raw", owner="librarian") - with open(tmp_path / "file1.txt", "w") as f: - f.write("File no. 2") - - ssh = SSHFileTransfer( - host=ssh_access.host, - port=ssh_access.port, - source_folder="/data/upload/{owner}", - ) - with ssh.connect_for_upload( - dataset=ds, connect=ssh_connect_with_username_password - ) as con: - assert con.source_folder == RemotePath("/data/upload/librarian") - con.upload_files( - File.from_local( - path=tmp_path / "file1.txt", remote_path=RemotePath("upload_1.txt") - ) - ) - - with open(ssh_data_dir / "upload" / "librarian" / "upload_1.txt", "r") as f: - assert f.read() == "File no. 2" - - -def test_upload_two_files( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/upload2")) - with open(tmp_path / "file2.1.md", "w") as f: - f.write("First part of file 2") - with open(tmp_path / "file2.2.md", "w") as f: - f.write("Second part of file 2") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload( - dataset=ds, connect=ssh_connect_with_username_password - ) as con: - assert con.source_folder == RemotePath("/data/upload2") - con.upload_files( - File.from_local(path=tmp_path / "file2.1.md", base_path=tmp_path), - File.from_local(path=tmp_path / "file2.2.md", base_path=tmp_path), - ) - - with open(ssh_data_dir / "upload2" / "file2.1.md", "r") as f: - assert f.read() == "First part of file 2" - with open(ssh_data_dir / "upload2" / "file2.2.md", "r") as f: - assert f.read() == "Second part of file 2" - - -def test_revert_all_uploaded_files_single( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/revert-all-test-1")) - with open(tmp_path / "file3.txt", "w") as f: - f.write("File that should get reverted") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload( - dataset=ds, 
connect=ssh_connect_with_username_password - ) as con: - file = File.from_local(path=tmp_path / "file3.txt", base_path=tmp_path) - con.upload_files(file) - con.revert_upload(file) - - assert "revert-all-test-1" not in (p.name for p in ssh_data_dir.iterdir()) - - -def test_revert_all_uploaded_files_two( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/revert-all-test-2")) - with open(tmp_path / "file3.1.txt", "w") as f: - f.write("File that should get reverted 1") - with open(tmp_path / "file3.2.txt", "w") as f: - f.write("File that should get reverted 2") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload( - dataset=ds, connect=ssh_connect_with_username_password - ) as con: - file1 = File.from_local(path=tmp_path / "file3.1.txt", base_path=tmp_path) - file2 = File.from_local(path=tmp_path / "file3.2.txt", base_path=tmp_path) - con.upload_files(file1, file2) - con.revert_upload(file1, file2) - - assert "revert-all-test-2" not in (p.name for p in ssh_data_dir.iterdir()) - - -def test_revert_one_uploaded_file( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/revert-test")) - with open(tmp_path / "file4.txt", "w") as f: - f.write("File that should get reverted") - with open(tmp_path / "file5.txt", "w") as f: - f.write("File that should be kept") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload( - dataset=ds, connect=ssh_connect_with_username_password - ) as con: - file4 = File.from_local(path=tmp_path / "file4.txt", base_path=tmp_path) - file5 = File.from_local(path=tmp_path / "file5.txt", base_path=tmp_path) - con.upload_files(file4, file5) - con.revert_upload(file4) - - assert "file4.txt" not in (p.name for p in (ssh_data_dir / "revert-test").iterdir()) - with open(ssh_data_dir / "revert-test" / 
"file5.txt", "r") as f: - assert f.read() == "File that should be kept" - - -def test_stat_uploaded_file( - ssh_access, ssh_connect_with_username_password, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/upload6")) - with open(tmp_path / "file6.txt", "w") as f: - f.write("File to test upload no 6") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload( - dataset=ds, connect=ssh_connect_with_username_password - ) as con: - [uploaded] = con.upload_files( - File.from_local(path=tmp_path / "file6.txt", remote_path="upload_6.txt") - ) - - st = (ssh_data_dir / "upload6" / "upload_6.txt").stat() - assert uploaded.size == st.st_size - - # Set in docker-compose - assert uploaded.remote_uid == "1000" - assert uploaded.remote_gid == "1000" - - uploaded = dataclasses.replace(uploaded, local_path=None) - assert datetime.now(tz=timezone.utc) - uploaded.creation_time < timedelta(seconds=5) - - -class CorruptingSFTP(paramiko.SFTPClient): - """Appends bytes to uploaded files to simulate a broken transfer.""" - - def put(self, localpath, remotepath, callback=None, confirm=True): - with open(localpath, "r") as f: - content = f.read() - with tempfile.TemporaryDirectory() as tempdir: - corrupted_path = Path(tempdir) / "corrupted" - with open(corrupted_path, "w") as f: - f.write(content + "\nevil bytes") - super().put(str(corrupted_path), remotepath, callback, confirm) - - -class CorruptingTransfer(paramiko.Transport): - """Uses CorruptingSFTP to simulate a broken connection.""" - - def open_sftp_client(self) -> paramiko.SFTPClient: - return CorruptingSFTP.from_transport(self) - - -@pytest.fixture() -def ssh_corrupting_connect(ssh_access, ssh_connection_config): - def connect(host: str, port: int, **kwargs): - if kwargs: - raise ValueError( - "connect_with_username_password must only be" - f" used without extra arguments. 
Got {kwargs=}" - ) - connection = fabric.Connection( - host=host, - port=port, - user=ssh_access.user.username, - config=ssh_connection_config, - connect_kwargs={ - "password": ssh_access.user.password, - "transport_factory": CorruptingTransfer, - **ssh_connection_config.connect_kwargs, - }, - ) - connection.client.set_missing_host_key_policy(IgnorePolicy()) - return connection - - return connect - - -def test_upload_file_detects_checksum_mismatch( - ssh_access, ssh_corrupting_connect, tmp_path, ssh_data_dir -): - ds = Dataset( - type="raw", - source_folder=RemotePath("/data/upload7"), - checksum_algorithm="blake2b", - ) - with open(tmp_path / "file7.txt", "w") as f: - f.write("File to test upload no 7") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload(dataset=ds, connect=ssh_corrupting_connect) as con: - with pytest.raises(FileUploadError): - con.upload_files( - dataclasses.replace( - File.from_local( - path=tmp_path / "file7.txt", - remote_path=RemotePath("upload_7.txt"), - ), - checksum_algorithm="blake2b", - ) - ) - - assert "upload7" not in (p.name for p in ssh_data_dir.iterdir()) - - -class RaisingSFTP(paramiko.SFTPClient): - def put(self, localpath, remotepath, callback=None, confirm=True): - raise RuntimeError("Upload disabled") - - -class RaisingTransfer(paramiko.Transport): - def open_sftp_client(self) -> paramiko.SFTPClient: - return RaisingSFTP.from_transport(self) - - -@pytest.fixture() -def ssh_raising_connect(ssh_access, ssh_connection_config): - def connect(host: str, port: int, **kwargs): - if kwargs: - raise ValueError( - "connect_with_username_password must only be" - f" used without extra arguments. 
Got {kwargs=}" - ) - connection = fabric.Connection( - host=host, - port=port, - user=ssh_access.user.username, - config=ssh_connection_config, - connect_kwargs={ - "password": ssh_access.user.password, - "transport_factory": RaisingTransfer, - **ssh_connection_config.connect_kwargs, - }, - ) - connection.client.set_missing_host_key_policy(IgnorePolicy()) - return connection - - return connect - - -def test_upload_file_reverts_if_upload_fails( - ssh_access, ssh_raising_connect, tmp_path, ssh_data_dir -): - ds = Dataset(type="raw", source_folder=RemotePath("/data/upload8")) - with open(tmp_path / "file8.txt", "w") as f: - f.write("File to test upload no 8") - - ssh = SSHFileTransfer(host=ssh_access.host, port=ssh_access.port) - with ssh.connect_for_upload(dataset=ds, connect=ssh_raising_connect) as con: - with pytest.raises(RuntimeError): - con.upload_files( - File.from_local( - path=tmp_path / "file8.txt", - remote_path=RemotePath("upload_8.txt"), - ) - ) - - assert "upload8" not in (p.name for p in ssh_data_dir.iterdir()) - - -class SSHTestFileTransfer(SSHFileTransfer): - def __init__(self, connect, **kwargs): - super().__init__(**kwargs) - self.connect = connect - - @contextmanager - def connect_for_download( - self, connect: Optional[Callable[..., Connection]] = None - ) -> Iterator[SSHDownloadConnection]: - connect = connect if connect is not None else self.connect - with super().connect_for_download(connect=connect) as connection: - yield connection - - @contextmanager - def connect_for_upload( - self, dataset: Dataset, connect: Optional[Callable[..., Connection]] = None - ) -> Iterator[SSHUploadConnection]: - connect = connect if connect is not None else self.connect - with super().connect_for_upload(dataset=dataset, connect=connect) as connection: - yield connection - - -# This test is referenced in the docs. 
-def test_client_with_ssh( - require_ssh_fileserver, - ssh_access, - ssh_connect_with_username_password, - ssh_data_dir, - tmp_path, -): - tmp_path.joinpath("file1.txt").write_text("File contents") - - client = FakeClient.without_login( - url="", - file_transfer=SSHTestFileTransfer( - connect=ssh_connect_with_username_password, - host=ssh_access.host, - port=ssh_access.port, - ), - ) - ds = Dataset( - access_groups=["group1"], - contact_email="p.stibbons@uu.am", - creation_location="UU", - creation_time=datetime(2023, 6, 23, 10, 0, 0), - owner="PonderStibbons", - owner_group="uu", - principal_investigator="MustrumRidcully", - source_folder="/data", - type="raw", - ) - ds.add_local_files(tmp_path / "file1.txt", base_path=tmp_path) - finalized = client.upload_new_dataset_now(ds) - - downloaded = client.get_dataset(finalized.pid) - downloaded = client.download_files(downloaded, target=tmp_path / "download") - - assert ssh_data_dir.joinpath("file1.txt").read_text() == "File contents" - assert downloaded.files[0].local_path.read_text() == "File contents"