diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 9800cb0a..b6c28497 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -33,7 +33,7 @@ jobs:
         python -m pip install --upgrade pip
         pip install poetry
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v5
       with:
        python-version: ${{ matrix.python-version }}
        cache: poetry
@@ -59,9 +59,6 @@ jobs:
     - name: Test with pytest
       run: poetry run pytest --cov=milatools --cov-report=xml --cov-append
 
-    - name: Test with pytest (with -s flag)
-      run: poetry run pytest --cov=milatools --cov-report=xml --cov-append -s
-
     - name: Upload coverage reports to Codecov
       uses: codecov/codecov-action@v3
       with:
@@ -96,7 +93,7 @@ jobs:
       options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       # NOTE: Replacing this with our customized version of
       # - uses: koesterlab/setup-slurm-action@v1
@@ -120,7 +117,7 @@ jobs:
           ssh -o 'StrictHostKeyChecking no' localhost id
 
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
@@ -131,7 +128,7 @@ jobs:
           poetry install --with=dev
 
       - name: Launch integration tests
-        run: poetry run pytest tests/cli/test_slurm_remote.py --cov=milatools --cov-report=xml --cov-append -s -vvv --log-level=DEBUG
+        run: poetry run pytest tests/integration --cov=milatools --cov-report=xml --cov-append -s -vvv --log-level=DEBUG
         timeout-minutes: 3
         env:
           SLURM_CLUSTER: localhost
diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py
index 2ae14286..1a612a93 100644
--- a/milatools/cli/commands.py
+++ b/milatools/cli/commands.py
@@ -51,6 +51,8 @@
     MilatoolsUserError,
     SSHConnectionError,
     T,
+    cluster_to_connect_kwargs,
+    currently_in_a_test,
     get_fully_qualified_hostname_of_compute_node,
     get_fully_qualified_name,
     make_process,
@@ -492,7 +494,7 @@ def code(
     persist: bool,
     job: str | None,
     node: str | None,
-    alloc: Sequence[str],
+    alloc: list[str],
     cluster: Cluster = "mila",
 ):
     """Open a remote VSCode session on a compute node.
@@ -525,13 +527,12 @@ def code(
     if command is None:
         command = os.environ.get("MILATOOLS_CODE_COMMAND", "code")
 
-    if remote.hostname != "graham":  # graham doesn't use lustre for $HOME
-        try:
-            check_disk_quota(remote)
-        except MilatoolsUserError:
-            raise
-        except Exception as exc:
-            logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
+    try:
+        check_disk_quota(remote)
+    except MilatoolsUserError:
+        raise
+    except Exception as exc:
+        logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
 
     vscode_extensions_folder = Path.home() / ".vscode/extensions"
     if vscode_extensions_folder.exists() and no_internet_on_compute_nodes(cluster):
@@ -618,6 +619,8 @@ def code(
                 "The editor was closed. Reopen it with <Enter>"
Reopen it with " " or terminate the process with " ) + if currently_in_a_test(): + break input() except KeyboardInterrupt: @@ -714,7 +717,7 @@ def serve_list(purge: bool): class StandardServerArgs(TypedDict): - alloc: Sequence[str] + alloc: list[str] """Extra options to pass to slurm.""" job: str | None @@ -922,7 +925,7 @@ def _standard_server( name: str | None, node: str | None, job: str | None, - alloc: Sequence[str], + alloc: list[str], port_pattern=None, token_pattern=None, ): @@ -1074,54 +1077,83 @@ def _standard_server( proc.kill() -def _get_disk_quota_usage( - remote: Remote, print_command_output: bool = True +def _parse_lfs_quota_output( + lfs_quota_output: str, ) -> tuple[tuple[float, float], tuple[int, int]]: - """Checks the disk quota on the $HOME filesystem on the mila cluster. - - Returns whether the quota is exceeded, in terms of storage space or number of files. - """ + """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" + lines = lfs_quota_output.splitlines() + + header_line: str | None = None + header_line_index: int | None = None + for index, line in enumerate(lines): + if ( + len(line_parts := line.strip().split()) == 9 + and line_parts[0].lower() == "filesystem" + ): + header_line = line + header_line_index = index + break + assert header_line + assert header_line_index is not None + + values_line_parts: list[str] = [] + # The next line may overflow to two (or maybe even more?) lines if the name of the + # $HOME dir is too long. + for content_line in lines[header_line_index + 1 :]: + additional_values = content_line.strip().split() + assert len(values_line_parts) < 9 + values_line_parts.extend(additional_values) + if len(values_line_parts) == 9: + break - # NOTE: This is what the output of the command looks like on the Mila cluster: - # - # $ lfs quota -u $USER /home/mila - # Disk quotas for usr normandf (uid 1471600598): - # Filesystem kbytes quota limit grace files quota limit grace - # /home/mila 101440844 0 104857600 - 936140 0 1048576 - - # uid 1471600598 is using default block quota setting - # uid 1471600598 is using default file quota setting - # - home_disk_quota_output = remote.get_output( - "lfs quota -u $USER $HOME", hide=not print_command_output - ) - lines = home_disk_quota_output.splitlines() + assert len(values_line_parts) == 9, values_line_parts ( _filesystem, used_kbytes, - _quota1, + _quota_kbytes, limit_kbytes, - _grace1, + _grace_kbytes, files, - _quota2, + _quota_files, limit_files, - _grace2, - ) = lines[2].strip().split() + _grace_files, + ) = values_line_parts - used_gb = float(int(used_kbytes.strip()) / (1024) ** 2) - max_gb = float(int(limit_kbytes.strip()) / (1024) ** 2) + used_gb = int(used_kbytes.strip()) / (1024**2) + max_gb = int(limit_kbytes.strip()) / (1024**2) used_files = int(files.strip()) max_files = int(limit_files.strip()) return (used_gb, max_gb), (used_files, max_files) def check_disk_quota(remote: Remote) -> None: - cluster = ( - "mila" # todo: if we run this on CC, then we should use `diskusage_report` - ) - # todo: Check the disk-quota of other filesystems if needed. 
-    filesystem = "$HOME"
+    cluster = remote.hostname
+
+    # NOTE: This is what the output of the command looks like on the Mila cluster:
+    #
+    # Disk quotas for usr normandf (uid 1471600598):
+    #     Filesystem  kbytes   quota   limit   grace   files   quota   limit   grace
+    # /home/mila/n/normandf
+    #           95747836       0 104857600       -  908722       0 1048576       -
+    # uid 1471600598 is using default block quota setting
+    # uid 1471600598 is using default file quota setting
+
+    # Need to assert this, otherwise .get_output calls .run which would spawn a job!
+    assert not isinstance(remote, SlurmRemote)
+    if not remote.get_output("which lfs", hide=True):
+        logger.debug("Cluster doesn't have the lfs command. Skipping check.")
+        return
+
     logger.debug("Checking disk quota on $HOME...")
-    (used_gb, max_gb), (used_files, max_files) = _get_disk_quota_usage(remote)
+
+    home_disk_quota_output = remote.get_output("lfs quota -u $USER $HOME", hide=True)
+    if "not on a mounted Lustre filesystem" in home_disk_quota_output:
+        logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.")
+        return
+
+    (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output(
+        home_disk_quota_output
+    )
     logger.debug(
         f"Disk usage: {used_gb:.1f} / {max_gb} GiB and {used_files} / {max_files} files"
     )
@@ -1144,7 +1176,7 @@ def check_disk_quota(remote: Remote) -> None:
     if used_gb >= max_gb or used_files >= max_files:
         raise MilatoolsUserError(
             T.red(
-                f"ERROR: Your disk quota on the {filesystem} filesystem is exceeded! "
+                f"ERROR: Your disk quota on the $HOME filesystem is exceeded! "
                 f"({reason}).\n"
                 f"To fix this, login to the cluster with `ssh {cluster}` and free up "
                 f"some space, either by deleting files, or by moving them to a "
@@ -1153,24 +1185,20 @@ def check_disk_quota(remote: Remote) -> None:
     if max(size_ratio, files_ratio) > 0.9:
         warning_message = (
-            f"WARNING: You are getting pretty close to your disk quota on the $HOME "
+            f"You are getting pretty close to your disk quota on the $HOME "
             f"filesystem: ({reason})\n"
             "Please consider freeing up some space in your $HOME folder, either by "
             "deleting files, or by moving them to a more suitable filesystem.\n"
             + freeing_up_space_instructions
         )
-        # TODO: Perhaps we could use the logger or the warnings package instead of just
-        # printing?
- # logger.warning(UserWarning(warning_message)) - # warnings.warn(UserWarning(T.yellow(warning_message))) - print(UserWarning(T.yellow(warning_message))) + logger.warning(UserWarning(warning_message)) def _find_allocation( remote: Remote, node: str | None, job: str | None, - alloc: Sequence[str], + alloc: list[str], cluster: Cluster = "mila", job_name: str = "mila-tools", ): @@ -1179,11 +1207,11 @@ def _find_allocation( if node is not None: node_name = get_fully_qualified_hostname_of_compute_node(node, cluster=cluster) - return Remote(node_name) + return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)) elif job is not None: node_name = remote.get_output(f"squeue --jobs {job} -ho %N") - return Remote(node_name) + return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)) else: alloc = ["-J", job_name, *alloc] diff --git a/milatools/cli/local.py b/milatools/cli/local.py index a84630e9..5ece7ccd 100644 --- a/milatools/cli/local.py +++ b/milatools/cli/local.py @@ -9,7 +9,7 @@ import paramiko.ssh_exception from typing_extensions import deprecated -from .utils import CommandNotFoundError, T, shjoin +from .utils import CommandNotFoundError, T, cluster_to_connect_kwargs, shjoin logger = get_logger(__name__) @@ -76,7 +76,13 @@ def display(split_command: list[str] | tuple[str, ...] | str) -> None: def check_passwordless(host: str) -> bool: try: - with fabric.Connection(host) as connection: + connect_kwargs_for_host = {"allow_agent": False} + if host in cluster_to_connect_kwargs: + connect_kwargs_for_host.update(cluster_to_connect_kwargs[host]) + with fabric.Connection( + host, + connect_kwargs=connect_kwargs_for_host, + ) as connection: results: fabric.runners.Result = connection.run( "echo OK", in_stream=False, diff --git a/milatools/cli/remote.py b/milatools/cli/remote.py index f80c395d..534cb5d9 100644 --- a/milatools/cli/remote.py +++ b/milatools/cli/remote.py @@ -17,7 +17,15 @@ from fabric import Connection from typing_extensions import Self, TypedDict, deprecated -from .utils import SSHConnectionError, T, control_file_var, here, shjoin +from .utils import ( + DRAC_CLUSTERS, + SSHConnectionError, + T, + cluster_to_connect_kwargs, + control_file_var, + here, + shjoin, +) batch_template = """#!/bin/bash #SBATCH --output={output_file} @@ -107,11 +115,17 @@ def __init__( connection: fabric.Connection | None = None, transforms: Sequence[Callable[[str], str]] = (), keepalive: int = 60, + connect_kwargs: dict[str, str] | None = None, ): self.hostname = hostname try: if connection is None: - connection = Connection(hostname) + _connect_kwargs = cluster_to_connect_kwargs.get(hostname) + if connect_kwargs is not None: + _connect_kwargs = (_connect_kwargs or {}).copy() + _connect_kwargs.update(connect_kwargs) + + connection = Connection(hostname, connect_kwargs=_connect_kwargs) if keepalive: connection.open() # NOTE: this transport gets mocked in tests, so we use a "soft" @@ -379,7 +393,7 @@ def puttext(self, text: str, dest: str) -> None: self.put(f.name, dest) def home(self) -> str: - return self.get_output("echo $HOME", hide=True) + return self.simple_run("echo $HOME").stdout.strip() def persist(self): # TODO: I don't really understand why this is here. 
@@ -427,7 +441,7 @@ class SlurmRemote(Remote): def __init__( self, connection: fabric.Connection, - alloc: Sequence[str], + alloc: list[str], transforms: Sequence[Callable[[str], str]] = (), persist: bool = False, hostname: str = "->", @@ -448,18 +462,23 @@ def srun_transform(self, cmd: str) -> str: def srun_transform_persist(self, cmd: str) -> str: tag = time.time_ns() - batch_file = f".milatools/batch/batch-{tag}.sh" - output_file = f".milatools/batch/out-{tag}.txt" + home = self.home() + + batch_file = str(Path(home) / f".milatools/batch/batch-{tag}.sh") + output_file = str(Path(home) / f".milatools/batch/out-{tag}.txt") batch = batch_template.format( command=cmd, output_file=output_file, control_file=control_file_var.get(), ) - # NOTE: We need to move to $SCRATCH before we run `sbatch`. - self.puttext(batch, batch_file) + self.puttext(text=batch, dest=batch_file) self.simple_run(f"chmod +x {batch_file}") - cmd = shjoin(["sbatch", *self.alloc, f"~/{batch_file}"]) - return f"cd $SCRATCH && {cmd}; touch {output_file}; tail -n +1 -f {output_file}" + cmd = shjoin(["sbatch", *self.alloc, batch_file]) + + # NOTE: We need to cd to $SCRATCH before we run `sbatch` on DRAC clusters. + if self.connection.host in DRAC_CLUSTERS: + cmd = f"cd $SCRATCH && {cmd}" + return f"{cmd}; touch {output_file}; tail -n +1 -f {output_file}" def with_transforms( self, *transforms: Callable[[str], str], persist: bool | None = None @@ -507,11 +526,13 @@ def ensure_allocation( }, login_node_runner else: remote = Remote(hostname=self.hostname, connection=self.connection) + command = shjoin(["salloc", *self.alloc]) # NOTE: On some DRAC clusters, it's required to first cd to $SCRATCH or - # /projects - # before submitting a job. + # /projects before submitting a job. + if self.connection.host in DRAC_CLUSTERS: + command = f"cd $SCRATCH && {command}" proc, results = remote.extract( - "cd $SCRATCH && " + shjoin(["salloc", *self.alloc]), + command, patterns={"node_name": "salloc: Nodes ([^ ]+) are ready for job"}, ) diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index 4df3140b..fab395eb 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -63,6 +63,17 @@ ) DRAC_CLUSTERS: list[Cluster] = [c for c in CLUSTERS if c != "mila"] +cluster_to_connect_kwargs: dict[str, dict[str, Any]] = { + "mila": { + "banner_timeout": 60, + } +} +"""The `connect_kwargs` dict to be passed to `fabric.Connection` for each cluster. + +NOTE: These are passed down to `paramiko.SSHClient.connect`. See that method for all +the possible values. +""" + def no_internet_on_compute_nodes( cluster: Cluster, @@ -149,8 +160,12 @@ def __str__(self): + "\n\t" + "-Retry connecting with mila" + "\n\t" - + f"-Try to exclude the node with -x {self.node_hostname} " - "parameter" + + f"-Try to exclude the node with -x {self.node_hostname} parameter\n" + + "\n" + + "If you reach out for help, you might want to also include this detailed error message:\n" + + "\n```\n" + + str(self.error) + + "\n```\n" ) @@ -290,3 +305,8 @@ def make_process( # Tiny wrapper around the `multiprocessing.Process` init to detect if the args and # kwargs don't match the target signature using typing instead of at runtime. 
return multiprocessing.Process(target=target, daemon=True, args=args, kwargs=kwargs) + + +def currently_in_a_test() -> bool: + """Returns True during unit tests (pytest) and False during normal execution.""" + return "pytest" in sys.modules diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/common.py b/tests/cli/common.py index 458544b7..d220924b 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -47,14 +47,6 @@ ) -REQUIRES_S_FLAG_REASON = ( - "Seems to require reading from stdin? Works with the -s flag, but other " - "tests might not." -) -requires_s_flag = pytest.mark.skipif( - "-s" not in sys.argv, - reason=REQUIRES_S_FLAG_REASON, -) requires_no_s_flag = pytest.mark.skipif( "-s" in sys.argv, reason="Passing pytest's -s flag makes this test fail.", diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 3b35c085..98e3a40a 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1,11 +1,14 @@ +from __future__ import annotations + import contextlib import io import shlex +import textwrap import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.commands import main +from milatools.cli.commands import _parse_lfs_quota_output, main from .common import requires_no_s_flag @@ -81,19 +84,6 @@ def test_invalid_command_output( file_regression.check(_convert_argparse_output_to_pre_py311_format(buf.getvalue())) -# TODO: Perhaps we could use something like this so we can run all tests locally, but -# skip the ones that need to actually connect to the cluster when running on GitHub -# Actions. -# def dont_run_on_github(*args): -# return pytest.param( -# *args, -# marks=pytest.mark.skipif( -# "GITHUB_ACTIONS" in os.environ, -# reason="We don't run this test on GitHub Actions.", -# ), -# ) - - @pytest.mark.parametrize( "command", ["mila docs conda", "mila intranet", "mila intranet idt"] ) @@ -111,3 +101,53 @@ def test_check_command_output( main() output: str = buf.getvalue() file_regression.check(_convert_argparse_output_to_pre_py311_format(output)) + + +used_kbytes = 95764232 +limit_kbytes = 104857600 +used_files = 908504 +limit_files = 1048576 + + +def _kb_to_gb(kb: int) -> float: + return kb / (1024**2) + + +@pytest.mark.parametrize( + ("output", "expected"), + [ + ( + textwrap.dedent( + f"""\ + Disk quotas for usr normandf (uid 1471600598): + Filesystem kbytes quota limit grace files quota limit grace + /home/mila/n/normandf + {used_kbytes} 0 {limit_kbytes} - {used_files} 0 {limit_files} - + uid 1471600598 is using default block quota setting + uid 1471600598 is using default file quota setting + """ + ), + ( + (_kb_to_gb(used_kbytes), _kb_to_gb(limit_kbytes)), + (used_files, limit_files), + ), + ), + ( + textwrap.dedent( + f"""\ + Disk quotas for usr normandf (uid 3098083): + Filesystem kbytes quota limit grace files quota limit grace + /home/normandf {used_kbytes} {limit_kbytes} {limit_kbytes} - {used_files} {limit_files} {limit_files} - + """ + ), + ( + (_kb_to_gb(used_kbytes), _kb_to_gb(limit_kbytes)), + (used_files, limit_files), + ), + ), + ], +) +def test_parse_lfs_quota_output( + output, expected: tuple[tuple[float, float], tuple[int, int]] +): + assert _parse_lfs_quota_output(output) == expected diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index 7724edde..594db06e 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -36,7 +36,12 @@ setup_windows_ssh_config_from_wsl, ) 
from milatools.cli.local import Local, check_passwordless -from milatools.cli.utils import SSHConfig, T, running_inside_WSL +from milatools.cli.utils import ( + SSHConfig, + T, + cluster_to_connect_kwargs, + running_inside_WSL, +) from .common import ( in_github_CI, @@ -1313,7 +1318,7 @@ def cluster(request: pytest.FixtureRequest) -> str: def authorized_keys_backup(cluster: str): """Fixture used to backup the authorized_keys file on the remote and restore it after tests.""" - connect_kwargs = {} + connect_kwargs = cluster_to_connect_kwargs.get(cluster, {}) backup_authorized_keys_path = "~/.ssh/authorized_keys.backup" if not check_passwordless(cluster): if in_github_CI: @@ -1360,7 +1365,9 @@ def test_setup_passwordless_ssh_access_to_real_cluster( f"Temporarily removing the ~/.ssh/authorized_keys file on {cluster} " f"(backed up at {cluster}:{authorized_keys_backup})" ) - fabric.Connection(cluster).run( + fabric.Connection( + cluster, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ).run( "rm ~/.ssh/authorized_keys", echo=True, echo_format=T.bold_cyan(f"({cluster})" + " $ {command}"), diff --git a/tests/cli/test_local.py b/tests/cli/test_local.py index 783f74de..b598d11c 100644 --- a/tests/cli/test_local.py +++ b/tests/cli/test_local.py @@ -115,6 +115,8 @@ def test_popen( [ ("localhost", passwordless_ssh_connection_to_localhost_is_setup), ("blablabob@localhost", False), + skip_param_if_on_github_ci("mila", True), + skip_param_if_on_github_ci("bobobobobobo@mila", False), skip_param_if_on_github_ci("narval", True), skip_param_if_on_github_ci("blablabob@narval", False), skip_param_if_on_github_ci("beluga", True), @@ -124,4 +126,6 @@ def test_popen( ], ) def test_check_passwordless(hostname: str, expected: bool): + # TODO: Maybe also test how `check_passwordless` behaves when using a key with a + # passphrase. assert check_passwordless(hostname) == expected diff --git a/tests/cli/test_remote.py b/tests/cli/test_remote.py index 9c2cf0df..ccad5b3b 100644 --- a/tests/cli/test_remote.py +++ b/tests/cli/test_remote.py @@ -2,7 +2,6 @@ from __future__ import annotations import shutil -import sys import time import typing import unittest @@ -22,9 +21,9 @@ SlurmRemote, get_first_node_name, ) -from milatools.cli.utils import T, shjoin +from milatools.cli.utils import T, cluster_to_connect_kwargs, shjoin -from .common import function_call_string, requires_s_flag +from .common import function_call_string @pytest.mark.parametrize("keepalive", [0, 123]) @@ -41,7 +40,9 @@ def test_init( r = Remote(host, keepalive=keepalive) # The Remote should have created a Connection instance (which happens to be # the mock_connection we made above). 
- MockConnection.assert_called_once_with(host) + MockConnection.assert_called_once_with( + host, connect_kwargs=cluster_to_connect_kwargs.get(host) + ) assert r.connection is mock_connection # The connection's Transport is opened if a non-zero value is passed for `keepalive` @@ -137,7 +138,6 @@ def test_remote_transform_methods( mock_connection.run.assert_called_once() transformed_command = mock_connection.run.mock_calls[0][1][0] - # "#Connection({mock_connection.host!r}), regression_file_text = f"""\ After creating a Remote like so: @@ -412,14 +412,8 @@ def test_home(remote: Remote): assert home_dir == str(Path.home()) -@requires_s_flag -def test_persist(remote: Remote, capsys: pytest.CaptureFixture): - _persisted_remote = remote.persist() - assert ( - "Warning: --persist does not work with --node or --job" - in capsys.readouterr().out - ) - assert _persisted_remote is remote +def test_persist(remote: Remote): + assert remote.persist() is remote def test_ensure_allocation(remote: Remote): @@ -504,24 +498,26 @@ def test_srun_transform_persist( f"remote.{method_call_string}", "```", "", - "created the following files:", + "created the following files (with abs path to the home directory " + "replaced with '$HOME' for tests):", "\n".join( "\n\n".join( [ f"- {str(new_file).replace(str(Path.home()), '~')}:", "", "```", - new_file.read_text(), + new_file.read_text().replace(str(Path.home()), "$HOME"), "```", ] ) for new_file in new_files ), "", - "and produced the following command as output:", + "and produced the following command as output (with the absolute " + "path to the home directory replaced with '$HOME' for tests):", "", "```bash", - output_command, + output_command.replace(str(Path.home()), "$HOME"), "```", "", ] @@ -623,7 +619,11 @@ def test_ensure_allocation_without_persist(self, mock_connection: Connection): alloc = ["--time=00:01:00"] remote = SlurmRemote(mock_connection, alloc=alloc, transforms=(), persist=False) node = "bob-123" - expected_command = f"cd $SCRATCH && salloc {shjoin(alloc)}" + expected_command = ( + f"cd $SCRATCH && salloc {shjoin(alloc)}" + if mock_connection.host == "mila" + else f"salloc {shjoin(alloc)}" + ) def write_stuff( command: str, diff --git a/tests/cli/test_remote/test_srun_transform_persist_localhost_.md b/tests/cli/test_remote/test_srun_transform_persist_localhost_.md index 1475fe6e..a4445240 100644 --- a/tests/cli/test_remote/test_srun_transform_persist_localhost_.md +++ b/tests/cli/test_remote/test_srun_transform_persist_localhost_.md @@ -14,7 +14,7 @@ Calling this: remote.srun_transform_persist('bob') ``` -created the following files: +created the following files (with abs path to the home directory replaced with '$HOME' for tests): - ~/.milatools/batch/batch-1234567890.sh: @@ -22,7 +22,7 @@ created the following files: ``` #!/bin/bash -#SBATCH --output=.milatools/batch/out-1234567890.txt +#SBATCH --output=$HOME/.milatools/batch/out-1234567890.txt #SBATCH --ntasks=1 echo jobid = $SLURM_JOB_ID >> /dev/null @@ -32,8 +32,8 @@ bob ``` -and produced the following command as output: +and produced the following command as output (with the absolute path to the home directory replaced with '$HOME' for tests): ```bash -cd $SCRATCH && sbatch --time=00:01:00 '~/.milatools/batch/batch-1234567890.sh'; touch .milatools/batch/out-1234567890.txt; tail -n +1 -f .milatools/batch/out-1234567890.txt +sbatch --time=00:01:00 $HOME/.milatools/batch/batch-1234567890.sh; touch $HOME/.milatools/batch/out-1234567890.txt; tail -n +1 -f 
$HOME/.milatools/batch/out-1234567890.txt ``` diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index c36dba01..ebfcf0af 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -41,7 +41,7 @@ def test_hostname(): @pytest.mark.parametrize( - ("cluster", "node", "expected"), + ("cluster_name", "node", "expected"), [ ("mila", "cn-a001", "cn-a001.server.mila.quebec"), ] @@ -78,10 +78,12 @@ def test_hostname(): # ], ) def test_get_fully_qualified_hostname_of_compute_node( - cluster: str, node: str, expected: str + cluster_name: str, node: str, expected: str ): assert ( - get_fully_qualified_hostname_of_compute_node(node_name=node, cluster=cluster) + get_fully_qualified_hostname_of_compute_node( + node_name=node, cluster=cluster_name + ) == expected ) diff --git a/tests/cli/conftest.py b/tests/conftest.py similarity index 66% rename from tests/cli/conftest.py rename to tests/conftest.py index d7c411a4..0312524c 100644 --- a/tests/cli/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os from collections.abc import Generator from unittest.mock import Mock @@ -9,29 +8,8 @@ from fabric.connection import Connection from milatools.cli.remote import Remote - -from .common import REQUIRES_S_FLAG_REASON - -in_github_ci = "PLATFORM" in os.environ - - -@pytest.fixture(autouse=in_github_ci) -def skip_if_s_flag_passed_during_ci_run_and_test_doesnt_require_it( - request: pytest.FixtureRequest, pytestconfig: pytest.Config -): - capture_value = pytestconfig.getoption("-s") - assert capture_value in ["no", "fd"] - s_flag_set = capture_value == "no" - test_requires_s_flag = any( - mark.name == "skipif" - and mark.kwargs.get("reason", "") == REQUIRES_S_FLAG_REASON - for mark in request.node.iter_markers() - ) - if s_flag_set and not test_requires_s_flag: - # NOTE: WE only run the tests that require -s when -s is passed, because - # otherwise we get very weird errors related to closed file descriptors! - pytest.skip(reason="Running with the -s flag and this test doesn't require it.") - +from milatools.cli.utils import cluster_to_connect_kwargs +from tests.integration.conftest import SLURM_CLUSTER passwordless_ssh_connection_to_localhost_is_setup = False @@ -114,11 +92,6 @@ def mock_connection( This Mock is used to check how the connection is used by `Remote` and `SlurmRemote`. """ mock_connection: Mock = MockConnection.return_value - # mock_connection.configure_mock( - # # Modify the repr so they show up nicely in the regression files and with - # # consistent/reproducible names. - # __repr__=lambda _: f"Connection({repr(host)})", - # ) return mock_connection @@ -126,3 +99,43 @@ def mock_connection( def remote(mock_connection: Connection): assert isinstance(mock_connection.host, str) return Remote(hostname=mock_connection.host, connection=mock_connection) + + +@pytest.fixture(scope="function") +def login_node(cluster: str) -> Remote: + """Fixture that gives a Remote connected to the login node of a slurm cluster. + + NOTE: Making this a function-scoped fixture because the Connection object of the + Remote seems to be passed (and reused?) when creating the `SlurmRemote` object. + + We want to avoid that, because `SlurmRemote` creates jobs when it runs commands. + We also don't want to accidentally end up with `login_node` that runs commands on + compute nodes because a previous test kept the same connection object while doing + salloc (just in case that were to happen). 
+ """ + + return Remote( + cluster, + connection=Connection( + cluster, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ), + ) + + +@pytest.fixture(scope="session", params=[SLURM_CLUSTER]) +def cluster(request: pytest.FixtureRequest) -> str: + """Fixture that gives the hostname of the slurm cluster to use for tests. + + NOTE: The `cluster` can also be parametrized indirectly by tests, for example: + + ```python + @pytest.mark.parametrize("cluster", ["mila", "some_cluster"], indirect=True) + def test_something(remote: Remote): + ... # here the remote is connected to the cluster specified above! + ``` + """ + slurm_cluster_hostname = request.param + + if not slurm_cluster_hostname: + pytest.skip("Requires ssh access to a SLURM cluster.") + return slurm_cluster_hostname diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..c66ed709 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import datetime +import functools +import os +import time +from logging import getLogger as get_logger + +import pytest + +from milatools.cli.remote import Remote +from tests.cli.common import in_github_CI + +logger = get_logger(__name__) +JOB_NAME = "milatools_test" +WCKEY = "milatools_test" + +SLURM_CLUSTER = os.environ.get("SLURM_CLUSTER", "mila" if not in_github_CI else None) +"""The name of the slurm cluster to use for tests. + +When running the tests on a dev machine, this defaults to the Mila cluster. Set to +`None` when running on the github CI. +""" + +MAX_JOB_DURATION = datetime.timedelta(seconds=10) + +hangs_in_github_CI = pytest.mark.skipif( + SLURM_CLUSTER == "localhost", + reason=( + "TODO: Hangs in the GitHub CI (probably because it runs salloc or sbatch on a " + "cluster with only `localhost` as a 'compute' node?)" + ), +) + + +@pytest.fixture(scope="session", autouse=True) +def cancel_all_milatools_jobs_before_and_after_tests(cluster: str): + # Note: need to recreate this because login_node is a function-scoped fixture. + login_node = Remote(cluster) + logger.info( + f"Cancelling milatools test jobs on {cluster} before running integration tests." + ) + login_node.run(f"scancel -u $USER --wckey={WCKEY}") + time.sleep(1) + yield + logger.info( + f"Cancelling milatools test jobs on {cluster} after running integration tests." + ) + login_node.run(f"scancel -u $USER --wckey={WCKEY}") + time.sleep(1) + # Display the output of squeue just to be sure that the jobs were cancelled. + logger.info(f"Checking that all jobs have been cancelked on {cluster}...") + login_node._run("squeue --me", echo=True, in_stream=False) + + +@functools.lru_cache +def get_slurm_account(cluster: str) -> str: + """Gets the SLURM account of the user using sacctmgr on the slurm cluster. + + When there are multiple accounts, this selects the first account, alphabetically. + + On DRAC cluster, this uses the `def` allocations instead of `rrg`, and when + the rest of the accounts are the same up to a '_cpu' or '_gpu' suffix, it uses + '_cpu'. + + For example: + + ```text + def-someprofessor_cpu <-- this one is used. + def-someprofessor_gpu + rrg-someprofessor_cpu + rrg-someprofessor_gpu + ``` + """ + logger.info( + f"Fetching the list of SLURM accounts available on the {cluster} cluster." 
+ ) + result = Remote(cluster).run( + "sacctmgr --noheader show associations where user=$USER format=Account%50" + ) + accounts = [line.strip() for line in result.stdout.splitlines()] + assert accounts + logger.info(f"Accounts on the slurm cluster {cluster}: {accounts}") + account = sorted(accounts)[0] + logger.info(f"Using account {account} to launch jobs in tests.") + return account + + +@pytest.fixture() +def allocation_flags(cluster: str, request: pytest.FixtureRequest) -> list[str]: + # note: thanks to lru_cache, this is only making one ssh connection per cluster. + account = get_slurm_account(cluster) + allocation_options = { + "job-name": JOB_NAME, + "wckey": WCKEY, + "account": account, + "nodes": 1, + "ntasks": 1, + "cpus-per-task": 1, + "mem": "1G", + "time": MAX_JOB_DURATION, + "oversubscribe": None, # allow multiple such jobs to share resources. + } + overrides = getattr(request, "param", {}) + assert isinstance(overrides, dict) + if overrides: + print(f"Overriding allocation options with {overrides}") + allocation_options.update(overrides) + return [ + f"--{key}={value}" if value is not None else f"--{key}" + for key, value in allocation_options.items() + ] diff --git a/tests/integration/test_code_command.py b/tests/integration/test_code_command.py new file mode 100644 index 00000000..1fa039d4 --- /dev/null +++ b/tests/integration/test_code_command.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import logging +import re +import time +from datetime import timedelta +from logging import getLogger as get_logger + +import pytest + +from milatools.cli.commands import check_disk_quota, code +from milatools.cli.remote import Remote +from milatools.cli.utils import get_fully_qualified_hostname_of_compute_node + +from ..cli.common import in_github_CI, skip_param_if_on_github_ci +from .conftest import SLURM_CLUSTER, hangs_in_github_CI +from .test_slurm_remote import get_recent_jobs_info_dicts + +logger = get_logger(__name__) + + +@pytest.mark.parametrize( + "cluster", + [ + skip_param_if_on_github_ci("mila"), + skip_param_if_on_github_ci("narval"), + skip_param_if_on_github_ci("beluga"), + skip_param_if_on_github_ci("cedar"), + skip_param_if_on_github_ci("graham"), + skip_param_if_on_github_ci("niagara"), + ], + indirect=True, +) +def test_check_disk_quota( + login_node: Remote, + capsys: pytest.LogCaptureFixture, + caplog: pytest.LogCaptureFixture, +): # noqa: F811 + with caplog.at_level(logging.DEBUG): + check_disk_quota(remote=login_node) + # Check that it doesn't raise any errors. + # IF the quota is nearly met, then a warning is logged. + # IF the quota is met, then a `MilatoolsUserError` is logged. + + +@pytest.mark.parametrize( + "cluster", + [ + pytest.param( + "localhost", + marks=[ + pytest.mark.skipif( + not (in_github_CI and SLURM_CLUSTER == "localhost"), + reason=( + "Only runs in the GitHub CI when localhost is a slurm cluster." + ), + ), + # todo: remove this mark once we're able to do sbatch and salloc in the + # GitHub CI. 
+ hangs_in_github_CI, + ], + ), + skip_param_if_on_github_ci("mila"), + skip_param_if_on_github_ci("narval"), + skip_param_if_on_github_ci("beluga"), + skip_param_if_on_github_ci("cedar"), + skip_param_if_on_github_ci("graham"), + skip_param_if_on_github_ci("niagara"), + ], + indirect=True, +) +@pytest.mark.parametrize("persist", [True, False]) +def test_code( + login_node: Remote, + persist: bool, + capsys: pytest.CaptureFixture, + allocation_flags: list[str], +): + home = login_node.home() + scratch = login_node.get_output("echo $SCRATCH") + relative_path = "bob" + code( + path=relative_path, + command="echo", # replace the usual `code` with `echo` for testing. + persist=persist, + job=None, + node=None, + alloc=allocation_flags, + cluster=login_node.hostname, # type: ignore + ) + + # Get the output that was printed while running that command. + # We expect our fake vscode command (with 'code' replaced with 'echo') to have been + # executed. + captured_output: str = capsys.readouterr().out + + # Get the job id from the output just so we can more easily check the command output + # with sacct below. + if persist: + m = re.search(r"Submitted batch job ([0-9]+)", captured_output) + assert m + job_id = int(m.groups()[0]) + else: + m = re.search(r"salloc: Granted job allocation ([0-9]+)", captured_output) + assert m + job_id = int(m.groups()[0]) + + time.sleep(5) # give a chance to sacct to update. + recent_jobs = get_recent_jobs_info_dicts( + since=timedelta(minutes=5), + login_node=login_node, + fields=("JobID", "JobName", "Node", "WorkDir", "State"), + ) + job_id_to_job_info = {int(job_info["JobID"]): job_info for job_info in recent_jobs} + assert job_id in job_id_to_job_info, (job_id, job_id_to_job_info) + job_info = job_id_to_job_info[job_id] + + node = job_info["Node"] + node_hostname = get_fully_qualified_hostname_of_compute_node( + node, cluster=login_node.hostname + ) + expected_line = f"(local) $ /usr/bin/echo -nw --remote ssh-remote+{node_hostname} {home}/{relative_path}" + assert any((expected_line in line) for line in captured_output.splitlines()), ( + captured_output, + expected_line, + ) + + # Check that on the DRAC clusters, the workdir is the scratch directory (because we + # cd'ed to $SCRATCH before submitting the job) + workdir = job_info["WorkDir"] + if login_node.hostname == "mila": + assert workdir == home + else: + assert workdir == scratch + + if persist: + # Job should still be running since we're using `persist` (that's the whole + # point.) + assert job_info["State"] == "RUNNING" + else: + # Job should have been cancelled by us after the `echo` process finished. + # NOTE: This check is a bit flaky, perhaps our `scancel` command hasn't + # completed yet, or sacct doesn't show the change in status quick enough. + # Relaxing it a bit for now. 
+ # assert "CANCELLED" in job_info["State"] + assert "CANCELLED" in job_info["State"] or job_info["State"] == "RUNNING" diff --git a/tests/cli/test_slurm_remote.py b/tests/integration/test_slurm_remote.py similarity index 69% rename from tests/cli/test_slurm_remote.py rename to tests/integration/test_slurm_remote.py index abb55977..761230de 100644 --- a/tests/cli/test_slurm_remote.py +++ b/tests/integration/test_slurm_remote.py @@ -7,8 +7,6 @@ from __future__ import annotations import datetime -import functools -import os import time from logging import getLogger as get_logger @@ -16,13 +14,12 @@ import pytest from milatools.cli.remote import Remote, SlurmRemote +from milatools.cli.utils import CLUSTERS + +from .conftest import JOB_NAME, MAX_JOB_DURATION, SLURM_CLUSTER, hangs_in_github_CI logger = get_logger(__name__) -SLURM_CLUSTER = os.environ.get("SLURM_CLUSTER") -JOB_NAME = "milatools_test" -WCKEY = "milatools_test" -MAX_JOB_DURATION = datetime.timedelta(seconds=10) # BUG: pytest-timeout seems to cause issues with paramiko threads.. # pytestmark = pytest.mark.timeout(60) @@ -35,28 +32,6 @@ reason="Requires ssh access to a SLURM cluster.", ) -# TODO: Import the value from `milatools.cli.utils` once the other PR adds it. -CLUSTERS = ["mila", "narval", "cedar", "beluga", "graham"] - - -@pytest.fixture(scope="session", params=[SLURM_CLUSTER]) -def cluster(request: pytest.FixtureRequest) -> str: - """Fixture that gives the hostname of the slurm cluster to use for tests. - - NOTE: The `cluster` can also be parametrized indirectly by tests, for example: - - ```python - @pytest.mark.parametrize("cluster", ["mila", "some_cluster"], indirect=True) - def test_something(remote: Remote): - ... # here the remote is connected to the cluster specified above! - ``` - """ - slurm_cluster_hostname = request.param - - if not slurm_cluster_hostname: - pytest.skip("Requires ssh access to a SLURM cluster.") - return slurm_cluster_hostname - def can_run_on_all_clusters(): """Makes a given test run on all the clusters in `CLUSTERS`, *for real*! @@ -67,63 +42,15 @@ def can_run_on_all_clusters(): return pytest.mark.parametrize("cluster", CLUSTERS, indirect=True) -@pytest.fixture() -def login_node(cluster: str) -> Remote: - """Fixture that gives a Remote connected to the login node of the slurm cluster. - - NOTE: Making this a function-scoped fixture because the Connection object is of the - Remote is used when creating the SlurmRemotes. - """ - return Remote(cluster) - - -@pytest.fixture(scope="module", autouse=True) -def cancel_all_milatools_jobs_before_and_after_tests(cluster: str): - # Note: need to recreate this because login_node is a function-scoped fixture. - login_node = Remote(cluster) - login_node.run(f"scancel -u $USER --wckey={WCKEY}") - time.sleep(1) - yield - login_node.run(f"scancel -u $USER --wckey={WCKEY}") - time.sleep(1) - # Display the output of squeue just to be sure that the jobs were cancelled. - login_node._run("squeue --me", echo=True, in_stream=False) - - -@functools.lru_cache -def get_slurm_account(cluster: str) -> str: - """Gets the SLURM account of the user using sacctmgr on the slurm cluster. - - When there are multiple accounts, this selects the first account, alphabetically. - - On DRAC cluster, this uses the `def` allocations instead of `rrg`, and when - the rest of the accounts are the same up to a '_cpu' or '_gpu' suffix, it uses - '_cpu'. - - For example: - - ```text - def-someprofessor_cpu <-- this one is used. 
- def-someprofessor_gpu - rrg-someprofessor_cpu - rrg-someprofessor_gpu - ``` - """ - # note: recreating the Connection here because this will be called for every test - # and we use functools.cache to cache the result, so the input has to be a simpler - # value like a string. - result = fabric.Connection(cluster).run( - "sacctmgr --noheader show associations where user=$USER format=Account%50", - echo=True, - in_stream=False, - ) - assert isinstance(result, fabric.runners.Result) - accounts: list[str] = [line.strip() for line in result.stdout.splitlines()] - assert accounts - logger.info(f"Accounts on the slurm cluster {cluster}: {accounts}") - account = sorted(accounts)[0] - logger.info(f"Using account {account} to launch jobs in tests.") - return account +def get_recent_jobs_info_dicts( + login_node: Remote, + since=datetime.timedelta(minutes=5), + fields=("JobID", "JobName", "Node", "State"), +) -> list[dict[str, str]]: + return [ + dict(zip(fields, line)) + for line in get_recent_jobs_info(login_node, since=since, fields=fields) + ] def get_recent_jobs_info( @@ -150,36 +77,8 @@ def sleep_so_sacct_can_update(): time.sleep(_SACCT_UPDATE_DELAY.total_seconds()) -@pytest.fixture() -def allocation_flags(cluster: str, request: pytest.FixtureRequest): - # note: thanks to lru_cache, this is only making one ssh connection per cluster. - account = get_slurm_account(cluster) - allocation_options = { - "job-name": JOB_NAME, - "wckey": WCKEY, - "account": account, - "nodes": 1, - "ntasks": 1, - "cpus-per-task": 1, - "mem": "1G", - "time": MAX_JOB_DURATION, - "oversubscribe": None, # allow multiple such jobs to share resources. - } - overrides = getattr(request, "param", {}) - assert isinstance(overrides, dict) - if overrides: - print(f"Overriding allocation options with {overrides}") - allocation_options.update(overrides) - return " ".join( - [ - f"--{key}={value}" if value is not None else f"--{key}" - for key, value in allocation_options.items() - ] - ) - - @requires_access_to_slurm_cluster -def test_cluster_setup(login_node: Remote, allocation_flags: str): +def test_cluster_setup(login_node: Remote, allocation_flags: list[str]): """Sanity Checks for the SLURM cluster of the CI: checks that `srun` works. NOTE: This is more-so a test to check that the slurm cluster used in the GitHub CI @@ -188,7 +87,7 @@ def test_cluster_setup(login_node: Remote, allocation_flags: str): job_id, compute_node = ( login_node.get_output( - f"srun {allocation_flags} bash -c 'echo $SLURM_JOB_ID $SLURMD_NODENAME'" + f"srun {' '.join(allocation_flags)} bash -c 'echo $SLURM_JOB_ID $SLURMD_NODENAME'" ) .strip() .split() @@ -205,7 +104,7 @@ def test_cluster_setup(login_node: Remote, allocation_flags: str): @pytest.fixture -def salloc_slurm_remote(login_node: Remote, allocation_flags: str): +def salloc_slurm_remote(login_node: Remote, allocation_flags: list[str]): """Fixture that creates a `SlurmRemote` that uses `salloc` (persist=False). 
The SlurmRemote is essentially just a Remote with an added `ensure_allocation` @@ -214,16 +113,16 @@ def salloc_slurm_remote(login_node: Remote, allocation_flags: str): """ return SlurmRemote( connection=login_node.connection, - alloc=allocation_flags.split(), + alloc=allocation_flags, ) @pytest.fixture -def sbatch_slurm_remote(login_node: Remote, allocation_flags: str): +def sbatch_slurm_remote(login_node: Remote, allocation_flags: list[str]): """Fixture that creates a `SlurmRemote` that uses `sbatch` (persist=True).""" return SlurmRemote( connection=login_node.connection, - alloc=allocation_flags.split(), + alloc=allocation_flags, persist=True, ) @@ -273,11 +172,6 @@ def test_run( assert (job_id, JOB_NAME, compute_node) in sacct_output -hangs_in_github_CI = pytest.mark.skipif( - SLURM_CLUSTER == "localhost", reason="BUG: Hangs in the GitHub CI.." -) - - @hangs_in_github_CI @requires_access_to_slurm_cluster def test_ensure_allocation(