# NOTE(review): reconstructed from the '+' lines of a whitespace-mangled patch
# adding milatools/cli/code.py; confirm against the upstream commit.
from __future__ import annotations

import asyncio
import shutil
from logging import getLogger as get_logger
from pathlib import PurePosixPath
from typing import Awaitable

from milatools.cli import console
from milatools.cli.init_command import DRAC_CLUSTERS
from milatools.cli.utils import (
    CommandNotFoundError,
    MilatoolsUserError,
    currently_in_a_test,
    internet_on_compute_nodes,
)
from milatools.utils.compute_node import ComputeNode, salloc, sbatch
from milatools.utils.disk_quota import check_disk_quota
from milatools.utils.local_v2 import LocalV2
from milatools.utils.remote_v2 import RemoteV2
from milatools.utils.vscode_utils import sync_vscode_extensions

logger = get_logger(__name__)


async def code(
    path: str,
    command: str,
    persist: bool,
    job: int | None,
    node: str | None,
    alloc: list[str],
    cluster: str = "mila",
) -> ComputeNode | int:
    """Open a remote VSCode session on a compute node.

    Arguments:
        path: Path to open on the remote machine
        command: Command to use to start vscode (defaults to "code" or the value of \
            $MILATOOLS_CODE_COMMAND)
        persist: Whether the server should persist or not after exiting the terminal.
        job: ID of the job to connect to
        node: Name of the node to connect to
        alloc: Extra options to pass to slurm

    Returns:
        The `ComputeNode` when the allocation is persistent, otherwise the job id of
        the (cancelled) allocation.
    """
    # Check that the `code` command is in the $PATH so that we can use just `code` as
    # the command.
    if not shutil.which(command):
        raise CommandNotFoundError(command)

    if (job or node) and not persist:
        logger.warning("Assuming persist=True since a job or node was specified.")
        persist = True

    # Connect to the cluster's login node.
    login_node = await RemoteV2.connect(cluster)

    relative_path: PurePosixPath | None = None
    # Get $HOME because we have to give the full path to the folder to the code command.
    home = PurePosixPath(
        await login_node.get_output_async("echo $HOME", display=False, hide=True)
    )
    if not path.startswith("/"):
        relative_path = PurePosixPath(path)
        path = str(home if path == "." else home / path)
    elif (_path := PurePosixPath(path)).is_relative_to(home):
        relative_path = _path.relative_to(home)
        console.log(
            f"Hint: you can use a path relative to your $HOME instead of an absolute path.\n"
            f"For example, `mila code {path}` is the same as `mila code {relative_path}`.",
            highlight=True,
            markup=True,
        )

    try:
        await check_disk_quota(login_node)
    except MilatoolsUserError:
        # Raise errors that are meant to be shown to the user (disk quota is reached).
        raise
    except Exception as exc:
        logger.warning(
            f"Unable to check the disk-quota on the {cluster} cluster: {exc}"
        )

    # NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an
    # unknown cluster?
    sync_vscode_extensions_task = None
    if not internet_on_compute_nodes(cluster):
        # Sync the VsCode extensions from the local machine over to the target cluster.
        console.log(
            f"Installing VSCode extensions that are on the local machine on "
            f"{cluster}.",
            style="cyan",
        )
        # todo: use the mila or the local machine as the reference for vscode
        # extensions?
        # TODO: If the remote is a cluster that doesn't yet have `vscode-server`, we
        # could launch vscode at the same time (or before) syncing the vscode extensions?
        sync_vscode_extensions_task = sync_vscode_extensions(
            LocalV2(),
            [login_node],
        )

    compute_node_task: Awaitable[ComputeNode]
    if job or node:
        if job and node:
            logger.warning(
                "Both job ID and node name were specified. Ignoring the node name and "
                "only using the job id."
            )
        job_id_or_node = job or node
        assert job_id_or_node is not None
        compute_node_task = ComputeNode.connect(
            login_node=login_node, job_id_or_node_name=job_id_or_node
        )
    else:
        if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc):
            logger.warning(
                "Warning: When using the DRAC clusters, you usually need to "
                "specify the account to use when submitting a job. You can specify "
                "this in the job resources with `--alloc`, like so: "
                # NOTE(review): the `<account>` placeholder was lost in the mangled
                # patch text; restored here — confirm against upstream.
                "`--alloc --account=<account>`, for example:\n"
                f"mila code some_path --cluster {cluster} --alloc "
                f"--account=your-account-here"
            )
        # Set the job name to `mila-code`. This should not be changed by the user
        # ideally, so we can collect some simple stats about the use of `milatools` on
        # the clusters.
        if any(flag.split("=")[0] in ("-J", "--job-name") for flag in alloc):
            raise MilatoolsUserError(
                "The job name flag (--job-name or -J) should be left unset for now "
                "because we use the job name to measure how many people use `mila "
                "code` on the various clusters. We also make use of the job name when "
                "the call to `salloc` is interrupted before we have a chance to know "
                "the job id."
            )
        job_name = "mila-code"
        alloc = alloc + [f"--job-name={job_name}"]

        if persist:
            compute_node_task = sbatch(
                login_node, sbatch_flags=alloc, job_name=job_name
            )
        else:
            # NOTE: Here we actually need the job name to be known, so that we can
            # scancel jobs if the call is interrupted.
            compute_node_task = salloc(
                login_node, salloc_flags=alloc, job_name=job_name
            )

    if sync_vscode_extensions_task is not None:
        # Sync the vscode extensions at the same time as waiting for the job.
        # Wait until all extensions are done syncing before launching vscode.
        # If any of the tasks failed, we want to raise the exception.
        compute_node, _ = await asyncio.gather(
            compute_node_task,
            sync_vscode_extensions_task,
        )
    else:
        compute_node = await compute_node_task

    await launch_vscode_loop(command, compute_node, path)

    if not persist and not (job or node):
        # Cancel the job if it was not persistent.
        # (--job and --node are used to connect to persistent jobs)
        await compute_node.close_async()
        console.print(f"Ended session on '{compute_node.hostname}'")
        return compute_node.job_id

    console.print("This allocation is persistent and is still active.")
    console.print("To reconnect to this job, run the following:")
    console.print(
        f"  mila code {relative_path or path} "
        + (f"--cluster {cluster} " if cluster != "mila" else "")
        + f"--job {compute_node.job_id}",
        style="bold",
    )
    console.print("To kill this allocation:")
    console.print(f"  ssh {cluster} scancel {compute_node.job_id}", style="bold")
    return compute_node


async def launch_vscode_loop(
    code_command: str, compute_node: ComputeNode, path: str
) -> None:
    """Open VSCode on the compute node, reopening it whenever the user presses Enter.

    Returns when the user interrupts the wait-for-input prompt (Ctrl+C) or, in tests,
    immediately after the first launch.
    """
    while True:
        code_command_to_run = (
            code_command,
            "--new-window",
            "--wait",
            "--remote",
            f"ssh-remote+{compute_node.hostname}",
            path,
        )
        await LocalV2.run_async(code_command_to_run, display=True)
        # TODO: BUG: This now requires two Ctrl+C's instead of one!
        # NOTE(review): the `<Enter>` / `<Ctrl+C>` placeholders were lost in the
        # mangled patch text; restored here — confirm against upstream.
        console.print(
            "The editor was closed. Reopen it with <Enter> or terminate the "
            "process with <Ctrl+C> (maybe twice)."
        )
        if currently_in_a_test():
            # NOTE: This early exit kills the job when it is not persistent.
            break
        try:
            input()
        except KeyboardInterrupt:
            break
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.error(f"Error while waiting for user input: {exc}")
            break
sync_vscode_extensions + +from ..__version__ import __version__ if typing.TYPE_CHECKING: - from typing_extensions import Unpack + from typing_extensions import Unpack # pragma: no cover logger = get_logger(__name__) @@ -89,6 +87,8 @@ def main(): try: mila() + except KeyboardInterrupt: + console.print("Exited by user.") except MilatoolsUserError as exc: # These are user errors and should not be reported print("ERROR:", exc, file=sys.stderr) @@ -132,9 +132,18 @@ def main(): def mila(): parser = ArgumentParser(prog="mila", description=__doc__, add_help=True) add_arguments(parser) + verbose, function, args_dict = parse_args(parser) setup_logging(verbose) - return function(**args_dict) + + if inspect.iscoroutinefunction(function): + try: + return asyncio.run(function(**args_dict)) + except KeyboardInterrupt: + console.log("Terminated by user.") + return + else: + return function(**args_dict) def add_arguments(parser: argparse.ArgumentParser): @@ -249,9 +258,13 @@ def add_arguments(parser: argparse.ArgumentParser): help="Node to connect to", metavar="NODE", ) + _add_allocation_options(code_parser) - code_parser.set_defaults(function=code) + if sys.platform == "win32": + code_parser.set_defaults(function=code_v1) + else: + code_parser.set_defaults(function=code) # ----- mila sync vscode-extensions ------ @@ -290,7 +303,7 @@ def add_arguments(parser: argparse.ArgumentParser): "extensions locally. Defaults to all the available SLURM clusters." ), ) - sync_vscode_parser.set_defaults(function=sync_vscode_extensions_with_hostnames) + sync_vscode_parser.set_defaults(function=sync_vscode_extensions) # ----- mila serve ------ @@ -440,13 +453,6 @@ def parse_args(parser: argparse.ArgumentParser) -> tuple[int, Callable, dict[str # replace SEARCH -> "search", REMOTE -> "remote", etc. 
args_dict = _convert_uppercase_keys_to_lowercase(args_dict) - if inspect.iscoroutinefunction(function): - try: - return asyncio.run(function(**args_dict)) - except KeyboardInterrupt: - console.log("Terminated by user.") - return - assert callable(function) return verbose, function, args_dict @@ -556,14 +562,20 @@ def forward( local_proc.kill() -def code( +@deprecated( + "Support for the `mila code` command is now deprecated on Windows machines, as it " + "does not support ssh keys with passphrases or clusters where 2FA is enabled. " + "Please consider switching to the Windows Subsystem for Linux (WSL) to run " + "`mila code`." +) +def code_v1( path: str, command: str, persist: bool, job: int | None, node: str | None, alloc: list[str], - cluster: Cluster = "mila", + cluster: str = "mila", ): """Open a remote VSCode session on a compute node. @@ -576,7 +588,13 @@ def code( node: Node to connect to alloc: Extra options to pass to slurm """ - here = LocalV1() + if command is None: + command = get_code_command() + command_path = shutil.which(command) + if not command_path: + raise CommandNotFoundError(command) + + here = LocalV2() remote = RemoteV1(cluster) if cluster != "mila" and job is None and node is None: @@ -590,46 +608,13 @@ def code( f"--account=your-account-here" ) - if command is None: - command = get_code_command() - try: - check_disk_quota(remote) + check_disk_quota_v1(remote) except MilatoolsUserError: raise except Exception as exc: logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") - if sys.platform == "win32": - print( - "Syncing vscode extensions in the background isn't supported on " - "Windows. Skipping." - ) - elif no_internet_on_compute_nodes(cluster): - # Sync the VsCode extensions from the local machine over to the target cluster. - # TODO: Make this happen in the background (without overwriting the output). 
- run_in_the_background = False - print( - console.log( - f"[cyan]Installing VSCode extensions that are on the local machine on " - f"{cluster}" + (" in the background." if run_in_the_background else ".") - ) - ) - if run_in_the_background: - copy_vscode_extensions_process = make_process( - sync_vscode_extensions_with_hostnames, - # todo: use the mila cluster as the source for vscode extensions? Or - # `localhost`? - source="localhost", - destinations=[cluster], - ) - copy_vscode_extensions_process.start() - else: - sync_vscode_extensions( - LocalV1(), - [cluster], - ) - if node is None: cnode = _find_allocation( remote, @@ -648,51 +633,45 @@ def code( else: node_name = node proc = None - data = None + jobs_on_that_node = remote.get_output( + f"squeue --me --nodelist {node_name} -ho %A", display=True + ).splitlines() + if not jobs_on_that_node: + raise MilatoolsUserError( + f"No jobs are currently running on node {node_name}!" + ) + job_str = jobs_on_that_node[0] + job = int(job_str) + data: NodeNameAndJobidDict = {"node_name": node, "jobid": job_str} if not path.startswith("/"): # Get $HOME because we have to give the full path to code home = remote.home() path = home if path == "." else f"{home}/{path}" - command_path = shutil.which(command) - if not command_path: - raise CommandNotFoundError(command) - # NOTE: Since we have the config entries for the DRAC compute nodes, there is no # need to use the fully qualified hostname here. if cluster == "mila": node_name = get_hostname_to_use_for_compute_node(node_name) - # Try to detect if this is being run from within the Windows Subsystem for Linux. - # If so, then we run `code` through a powershell.exe command to open VSCode without - # issues. - inside_WSL = running_inside_WSL() + # Note: We can't possibly be running inside the WSL (otherwise code(v2) would be used). 
try: while True: - if inside_WSL: - here.run( - "powershell.exe", - "code", - "-nw", - "--remote", - f"ssh-remote+{node_name}", - path, - ) - else: - here.run( + here.run( + ( command_path, "-nw", "--remote", f"ssh-remote+{node_name}", path, - ) + ), + ) print( "The editor was closed. Reopen it with " " or terminate the process with " ) if currently_in_a_test(): - break + raise KeyboardInterrupt input() except KeyboardInterrupt: @@ -704,17 +683,22 @@ def code( if persist: print("This allocation is persistent and is still active.") print("To reconnect to this node:") - print( - T.bold( - f" mila code {path} " - + (f"--cluster={cluster} " if cluster != "mila" else "") - + f"--node {node_name}" - ) + console.print( + f" mila code {path} " + + (f"--cluster={cluster} " if cluster != "mila" else "") + + f"--node {node_name}", + style="bold", ) print("To kill this allocation:") assert data is not None - assert "jobid" in data - print(T.bold(f" ssh {cluster} scancel {data['jobid']}")) + if "jobid" in data: + console.print(f" ssh {cluster} scancel {data['jobid']}", style="bold") + else: + assert "node_name" in data + console.print( + f" ssh {cluster} scancel --me --nodelist {data['node_name']}", + style="bold", + ) def connect(identifier: str, port: int | None): @@ -1188,141 +1172,12 @@ def _standard_server( proc.kill() -def _parse_lfs_quota_output( - lfs_quota_output: str, -) -> tuple[tuple[float, float], tuple[int, int]]: - """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" - lines = lfs_quota_output.splitlines() - - header_line: str | None = None - header_line_index: int | None = None - for index, line in enumerate(lines): - if ( - len(line_parts := line.strip().split()) == 9 - and line_parts[0].lower() == "filesystem" - ): - header_line = line - header_line_index = index - break - assert header_line - assert header_line_index is not None - - values_line_parts: list[str] = [] - # The next line may overflow to two (or maybe even more?) 
lines if the name of the - # $HOME dir is too long. - for content_line in lines[header_line_index + 1 :]: - additional_values = content_line.strip().split() - assert len(values_line_parts) < 9 - values_line_parts.extend(additional_values) - if len(values_line_parts) == 9: - break - - assert len(values_line_parts) == 9, values_line_parts - ( - _filesystem, - used_kbytes, - _quota_kbytes, - limit_kbytes, - _grace_kbytes, - files, - _quota_files, - limit_files, - _grace_files, - ) = values_line_parts - - used_gb = int(used_kbytes.strip()) / (1024**2) - max_gb = int(limit_kbytes.strip()) / (1024**2) - used_files = int(files.strip()) - max_files = int(limit_files.strip()) - return (used_gb, max_gb), (used_files, max_files) - - -def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: - cluster = remote.hostname - - # NOTE: This is what the output of the command looks like on the Mila cluster: - # - # Disk quotas for usr normandf (uid 1471600598): - # Filesystem kbytes quota limit grace files quota limit grace - # /home/mila/n/normandf - # 95747836 0 104857600 - 908722 0 1048576 - - # uid 1471600598 is using default block quota setting - # uid 1471600598 is using default file quota setting - - # Need to assert this, otherwise .get_output calls .run which would spawn a job! - assert not isinstance(remote, SlurmRemote) - if not remote.get_output("which lfs", hide=True): - logger.debug("Cluster doesn't have the lfs command. Skipping check.") - return - - console.log("Checking disk quota on $HOME...") - - home_disk_quota_output = remote.get_output("lfs quota -u $USER $HOME", hide=True) - if "not on a mounted Lustre filesystem" in home_disk_quota_output: - logger.debug("Cluster doesn't use lustre on $HOME filesystem. 
Skipping check.") - return - - (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output( - home_disk_quota_output - ) - - def get_colour(used: float, max: float) -> str: - return "red" if used >= max else "orange" if used / max > 0.7 else "green" - - disk_usage_style = get_colour(used_gb, max_gb) - num_files_style = get_colour(used_files, max_files) - from rich.text import Text - - console.log( - "Disk usage:", - Text(f"{used_gb:.2f} / {max_gb:.2f} GiB", style=disk_usage_style), - "and", - Text(f"{used_files} / {max_files} files", style=num_files_style), - markup=False, - ) - size_ratio = used_gb / max_gb - files_ratio = used_files / max_files - reason = ( - f"{used_gb:.1f} / {max_gb} GiB" - if size_ratio > files_ratio - else f"{used_files} / {max_files} files" - ) - - freeing_up_space_instructions = ( - "For example, temporary files (logs, checkpoints, etc.) can be moved to " - "$SCRATCH, while files that need to be stored for longer periods can be moved " - "to $ARCHIVE or to a shared project folder under /network/projects.\n" - "Visit https://docs.mila.quebec/Information.html#storage to learn more about " - "how to best make use of the different filesystems available on the cluster." - ) - - if used_gb >= max_gb or used_files >= max_files: - raise MilatoolsUserError( - T.red( - f"ERROR: Your disk quota on the $HOME filesystem is exceeded! 
" - f"({reason}).\n" - f"To fix this, login to the cluster with `ssh {cluster}` and free up " - f"some space, either by deleting files, or by moving them to a " - f"suitable filesystem.\n" + freeing_up_space_instructions - ) - ) - if max(size_ratio, files_ratio) > 0.9: - warning_message = ( - f"You are getting pretty close to your disk quota on the $HOME " - f"filesystem: ({reason})\n" - "Please consider freeing up some space in your $HOME folder, either by " - "deleting files, or by moving them to a more suitable filesystem.\n" - + freeing_up_space_instructions - ) - logger.warning(UserWarning(warning_message)) - - def _find_allocation( remote: RemoteV1, node: str | None, job: int | str | None, alloc: list[str], - cluster: Cluster = "mila", + cluster: str = "mila", job_name: str = "mila-tools", ): if (node is not None) + (job is not None) + bool(alloc) > 1: @@ -1336,8 +1191,9 @@ def _find_allocation( elif job is not None: node_name = remote.get_output(f"squeue --jobs {job} -ho %N") + node_hostname = get_hostname_to_use_for_compute_node(node_name, cluster=cluster) return RemoteV1( - node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + node_hostname, connect_kwargs=cluster_to_connect_kwargs.get(cluster) ) else: diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index 0fb8d14e..715598ba 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -86,17 +86,15 @@ def currently_in_a_test() -> bool: return "pytest" in sys.modules -def no_internet_on_compute_nodes( - cluster: Cluster, -) -> TypeGuard[ClusterWithoutInternetOnCNodes]: +def internet_on_compute_nodes(cluster: str) -> TypeGuard[ClusterWithInternetOnCNodes]: if cluster not in CLUSTERS: warnings.warn( UserWarning( - f"Unknown cluster {cluster}. Assuming that compute nodes do not have " - f"internet access on this cluster for now." + f"Unknown cluster {cluster}. Assuming that compute nodes of this " + f"cluster do NOT have access to the internet." 
async def get_queued_milatools_job_ids(
    login_node: RemoteV2, job_name: str | None = "mila-code"
) -> set[int]:
    """Return the ids of the user's queued jobs, optionally filtered by job name.

    When `job_name` is None, all of the user's jobs are listed (no --name filter).
    """
    squeue_command = "squeue --noheader --me --format=%A"
    if job_name is not None:
        squeue_command += f" --name={job_name}"
    squeue_output = await login_node.get_output_async(squeue_command)
    return {int(line) for line in squeue_output.splitlines()}
async def check_disk_quota(remote: RemoteV2) -> None:
    """Checks that the disk quota isn't exceeded on the remote $HOME filesystem.

    The check is skipped when the cluster has no `lfs` command (not a Lustre
    filesystem). Otherwise the output of `lfs quota` on $HOME is parsed and the
    shared helper raises a user-facing error when the quota is exceeded.
    """
    # NOTE: This is what the output of the command looks like on the Mila cluster:
    #
    # Disk quotas for usr normandf (uid 1471600598):
    #      Filesystem  kbytes   quota   limit   grace   files   quota   limit   grace
    # /home/mila/n/normandf
    #                 95747836       0 104857600       -  908722       0 1048576       -
    # uid 1471600598 is using default block quota setting
    # uid 1471600598 is using default file quota setting
    lfs_path = await remote.get_output_async("which lfs", display=False, hide=True)
    if not lfs_path:
        logger.debug("Cluster doesn't have the lfs command. Skipping check.")
        return
    console.log("Checking disk quota on $HOME...")
    quota_output = await remote.get_output_async(
        "lfs quota -u $USER $HOME", display=False, hide=True
    )
    _check_disk_quota_common_part(quota_output, cluster=remote.hostname)
Skipping check.") + return + console.log("Checking disk quota on $HOME...") + home_disk_quota_output = remote.get_output( + "lfs quota -u $USER $HOME", display=False, hide=True + ) + _check_disk_quota_common_part(home_disk_quota_output, cluster=remote.hostname) + + +def _parse_lfs_quota_output( + lfs_quota_output: str, +) -> tuple[tuple[float, float], tuple[int, int]]: + """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" + lines = lfs_quota_output.splitlines() + + header_line: str | None = None + header_line_index: int | None = None + for index, line in enumerate(lines): + if ( + len(line_parts := line.strip().split()) == 9 + and line_parts[0].lower() == "filesystem" + ): + header_line = line + header_line_index = index + break + assert header_line + assert header_line_index is not None + + values_line_parts: list[str] = [] + # The next line may overflow to two (or maybe even more?) lines if the name of the + # $HOME dir is too long. + for content_line in lines[header_line_index + 1 :]: + additional_values = content_line.strip().split() + assert len(values_line_parts) < 9 + values_line_parts.extend(additional_values) + if len(values_line_parts) == 9: + break + + assert len(values_line_parts) == 9, values_line_parts + ( + _filesystem, + used_kbytes, + _quota_kbytes, + limit_kbytes, + _grace_kbytes, + files, + _quota_files, + limit_files, + _grace_files, + ) = values_line_parts + + used_gb = int(used_kbytes.strip()) / (1024**2) + max_gb = int(limit_kbytes.strip()) / (1024**2) + used_files = int(files.strip()) + max_files = int(limit_files.strip()) + return (used_gb, max_gb), (used_files, max_files) + + +def _check_disk_quota_common_part(home_disk_quota_output: str, cluster: str): + if "not on a mounted Lustre filesystem" in home_disk_quota_output: + logger.debug("Cluster doesn't use lustre on $HOME filesystem. 
Skipping check.") + return + + (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output( + home_disk_quota_output + ) + + def get_colour(used: float, max: float) -> str: + return "red" if used >= max else "orange" if used / max > 0.7 else "green" + + disk_usage_style = get_colour(used_gb, max_gb) + num_files_style = get_colour(used_files, max_files) + from rich.text import Text + + console.log( + "Disk usage:", + Text(f"{used_gb:.2f} / {max_gb:.2f} GiB", style=disk_usage_style), + "and", + Text(f"{used_files} / {max_files} files", style=num_files_style), + markup=False, + ) + size_ratio = used_gb / max_gb + files_ratio = used_files / max_files + reason = ( + f"{used_gb:.1f} / {max_gb} GiB" + if size_ratio > files_ratio + else f"{used_files} / {max_files} files" + ) + + freeing_up_space_instructions = ( + "For example, temporary files (logs, checkpoints, etc.) can be moved to " + "$SCRATCH, while files that need to be stored for longer periods can be moved " + "to $ARCHIVE or to a shared project folder under /network/projects.\n" + "Visit https://docs.mila.quebec/Information.html#storage to learn more about " + "how to best make use of the different filesystems available on the cluster." + ) + + if used_gb >= max_gb or used_files >= max_files: + raise MilatoolsUserError( + T.red( + f"ERROR: Your disk quota on the $HOME filesystem is exceeded! 
" + f"({reason}).\n" + f"To fix this, login to the cluster with `ssh {cluster}` and free up " + f"some space, either by deleting files, or by moving them to a " + f"suitable filesystem.\n" + freeing_up_space_instructions + ) + ) + if max(size_ratio, files_ratio) > 0.9: + warning_message = ( + f"You are getting pretty close to your disk quota on the $HOME " + f"filesystem: ({reason})\n" + "Please consider freeing up some space in your $HOME folder, either by " + "deleting files, or by moving them to a more suitable filesystem.\n" + + freeing_up_space_instructions + ) + logger.warning(UserWarning(warning_message)) diff --git a/milatools/utils/local_v2.py b/milatools/utils/local_v2.py index 4694f00e..2944c157 100644 --- a/milatools/utils/local_v2.py +++ b/milatools/utils/local_v2.py @@ -193,9 +193,17 @@ async def run_async( stdin=asyncio.subprocess.PIPE, shell=False, ) + if input: logger.debug(f"Sending {input=!r} to the subprocess' stdin.") - stdout, stderr = await proc.communicate(input.encode() if input else None) + try: + stdout, stderr = await proc.communicate(input.encode() if input else None) + except asyncio.CancelledError: + logger.debug(f"Got interrupted while calling {proc}.communicate({input=}).") + # This is a fix for ugly error trace on interrupt: https://bugs.python.org/issue43884 + if transport := getattr(proc, "_transport", None): + transport.close() # type: ignore + raise assert proc.returncode is not None if proc.returncode != 0: @@ -208,7 +216,7 @@ async def run_async( logger.debug(message) if not warn: if stderr: - logger.error(stderr) + logger.error(stderr.decode()) raise subprocess.CalledProcessError( returncode=proc.returncode, cmd=program_and_args, diff --git a/milatools/utils/parallel_progress.py b/milatools/utils/parallel_progress.py index 71e717d7..aff12597 100644 --- a/milatools/utils/parallel_progress.py +++ b/milatools/utils/parallel_progress.py @@ -1,11 +1,14 @@ from __future__ import annotations -import multiprocessing -import time 
-from concurrent.futures import Future, ThreadPoolExecutor +import asyncio +import functools from logging import getLogger as get_logger -from multiprocessing.managers import DictProxy -from typing import Iterable, Protocol, TypedDict, TypeVar +from typing import ( + Coroutine, + Protocol, + TypedDict, + TypeVar, +) from rich.progress import ( BarColumn, @@ -32,156 +35,203 @@ class ProgressDict(TypedDict): info: NotRequired[str] -class TaskFn(Protocol[OutT_co]): - """Protocol for a function that can be run in parallel and reports its progress. +class ReportProgressFn(Protocol): + """A function to be called inside a task to show information in the progress bar.""" + + def __call__(self, progress: int, total: int, info: str | None = None) -> None: + ... # pragma: no cover - The function should periodically set a dict containing info about it's progress in - the `progress_dict` at key `task_id`. For example: - ```python - def _example_task_fn(progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID): - import random - import time - progress_dict[task_id] = {"progress": 0, "total": len_of_task, "info": "Starting."} +def report_progress( + progress: int, + total: int, + info: str | None = None, + *, + task_id: TaskID, + progress_dict: dict[TaskID, ProgressDict], +): + if info is not None: + progress_dict[task_id] = {"progress": progress, "total": total, "info": info} + else: + progress_dict[task_id] = {"progress": progress, "total": total} - len_of_task = random.randint(3, 20) # take some random length of time - for n in range(len_of_task): - time.sleep(1) # sleep for a bit to simulate work - progress_dict[task_id] = {"progress": n + 1, "total": len_of_task} - progress_dict[task_id] = {"progress": len_of_task, "total": len_of_task, "info": "Done."} - return f"Some result for task {task_id}." +class AsyncTaskFn(Protocol[OutT_co]): + """Protocol for a function that can be run in parallel and reports its progress. 
- for result in parallel_progress_bar([_example_task_fn, _example_task_fn]): - print(result) + The function can (should) periodically report info about it's progress by calling + the `report_progress` function. For example: """ def __call__( - self, task_progress_dict: DictProxy[TaskID, ProgressDict], task_id: TaskID - ) -> OutT_co: - ... + self, report_progress: ReportProgressFn + ) -> Coroutine[None, None, OutT_co]: + ... # pragma: no cover -def parallel_progress_bar( - task_fns: list[TaskFn[OutT_co]], +async def run_async_tasks_with_progress_bar( + async_task_fns: list[AsyncTaskFn[OutT_co]], task_descriptions: list[str] | None = None, overall_progress_task_description: str = "[green]All jobs progress:", - n_workers: int = 8, -) -> Iterable[OutT_co]: - """Adapted from https://www.deanmontgomery.com/2022/03/24/rich-progress-and- - multiprocessing/ - - TODO: Double-check that using a ThreadPoolExecutor here actually makes sense and - that the calls over SSH can be done in parallel. + _show_elapsed_time: bool = True, +) -> list[OutT_co]: + """Run a sequence of async tasks in "parallel" and display a progress bar. + + Adapted from the example at: + + https://www.deanmontgomery.com/2022/03/24/rich-progress-and-multiprocessing/ + + NOTE: This differs from the usual progress bar: the results are returned as a list + (all at the same time) instead of one at a time. + + >>> import pytest, sys + >>> if sys.platform.startswith('win'): + ... pytest.skip("This doctest doesn't work properly on Windows.") + >>> async def example_task_fn(report_progress: ReportProgressFn, len_of_task: int): + ... import random + ... report_progress(progress=0, total=len_of_task, info="Starting.") + ... for n in range(len_of_task): + ... await asyncio.sleep(1) # sleep for a bit to simulate work + ... report_progress(progress=n + 1, total=len_of_task, info="working...") + ... report_progress(progress=len_of_task, total=len_of_task, info="Done.") + ... return f"Done after {len_of_task} seconds." 
+ >>> import functools + >>> tasks = [functools.partial(example_task_fn, len_of_task=i) for i in range(1, 4)] + >>> import time + >>> start_time = time.time() + >>> results = asyncio.run(run_async_tasks_with_progress_bar(tasks, _show_elapsed_time=False)) + ✓ All jobs progress: 6/6 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + ✓ Task 0 - Done. 1/1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + ✓ Task 1 - Done. 2/2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + ✓ Task 2 - Done. 3/3 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 + >>> results + ['Done after 1 seconds.', 'Done after 2 seconds.', 'Done after 3 seconds.'] + >>> f"Finished all tasks in {round(time.time() - start_time)} seconds." + 'Finished all tasks in 3 seconds.' """ if task_descriptions is None: - task_descriptions = [f"Task {i}" for i in range(len(task_fns))] - - assert task_fns - assert len(task_fns) == len(task_descriptions) - - futures: dict[TaskID, Future[OutT_co]] = {} - num_yielded_results: int = 0 - - # NOTE: Could also use a ProcessPoolExecutor here: - # executor = ProcessPoolExecutor(max_workers=n_workers) - executor = ThreadPoolExecutor( - max_workers=n_workers, thread_name_prefix="mila_sync_worker" - ) - manager = multiprocessing.Manager() - progress = Progress( + task_descriptions = [f"Task {i}" for i in range(len(async_task_fns))] + columns = [ SpinnerColumn(finished_text="[green]✓"), TextColumn("[progress.description]{task.description}"), MofNCompleteColumn(), BarColumn(bar_width=None), TaskProgressColumn(), + *([TimeElapsedColumn()] if _show_elapsed_time else []), TimeRemainingColumn(), - TimeElapsedColumn(), + ] + progress = Progress( + *columns, console=console, transient=False, refresh_per_second=10, expand=False, ) - with executor, manager, progress: - # We share some state between our main process and our worker - # functions - _progress_dict: DictProxy[TaskID, ProgressDict] = manager.dict() - - overall_progress_task = progress.add_task( - 
overall_progress_task_description, - visible=True, - start=True, + + _progress_dict: dict[TaskID, ProgressDict] = {} + tasks: dict[TaskID, asyncio.Task[OutT_co]] = {} + + overall_progress_task = progress.add_task( + overall_progress_task_description, + visible=True, + start=True, + ) + # iterate over the jobs we need to run + for task_description, async_task_fn in zip(task_descriptions, async_task_fns): + # NOTE: Could set visible=false so we don't have a lot of bars all at once. + task_id = progress.add_task( + description=task_description, + visible=False, + start=False, ) + report_progress_fn = functools.partial( + report_progress, task_id=task_id, progress_dict=_progress_dict + ) + coroutine = async_task_fn(report_progress=report_progress_fn) + + tasks[task_id] = asyncio.create_task(coroutine, name=task_description) + + update_pbar_task = asyncio.create_task( + update_progress_bar( + progress, + tasks=tasks, + task_descriptions=task_descriptions, + progress_dict=_progress_dict, + overall_progress_task=overall_progress_task, + ), + name=update_progress_bar.__name__, + ) + try: + with progress: + await asyncio.gather( + *[*tasks.values(), update_pbar_task], return_exceptions=True + ) + except (KeyboardInterrupt, asyncio.CancelledError) as err: + logger.warning(f"Received {type(err).__name__}, cancelling tasks.") + for task in tasks.values(): + task.cancel() + update_pbar_task.cancel() + raise + + return [task.result() for task in tasks.values()] + + +async def update_progress_bar( + progress: Progress, + tasks: dict[TaskID, asyncio.Task[OutT_co]], + progress_dict: dict[TaskID, ProgressDict], + task_descriptions: list[str], + overall_progress_task: TaskID, +): + assert len(task_descriptions) == len(tasks) + _started_task_ids: list[TaskID] = [] + while True: + total_progress = 0 + total_task_lengths = 0 + + for (task_id, task), task_description in zip(tasks.items(), task_descriptions): + if task_id not in progress_dict: + # No progress reported yet by the task 
function. + continue + + update_data = progress_dict[task_id] + task_progress = update_data["progress"] + task_total = update_data["total"] + + # Start the task in the progress bar when the first update is received. + # This allows us to have a nice per-task elapsed time instead of the + # same elapsed time in all tasks. + if task_id not in _started_task_ids and task_progress > 0: + # Note: calling `start_task` multiple times doesn't cause issues, + # but we're still doing this just to be explicit. + progress.start_task(task_id) + _started_task_ids.append(task_id) + + progress.update( + task_id=task_id, + completed=task_progress, + total=task_total, + description=task_description + + ( + " - Done." + if task.done() + else (f" - {info}" if (info := update_data.get("info")) else "") + ), + visible=True, + ) + + total_progress += task_progress + total_task_lengths += task_total - # iterate over the jobs we need to run - for task_name, task_fn in zip(task_descriptions, task_fns): - # NOTE: Could set visible=false so we don't have a lot of bars all at once. - task_id = progress.add_task( - description=task_name, visible=True, start=False + if total_progress or total_task_lengths: + progress.update( + task_id=overall_progress_task, + completed=total_progress, + total=total_task_lengths, + visible=True, ) - futures[task_id] = executor.submit(task_fn, _progress_dict, task_id) - - _started_task_ids: list[TaskID] = [] - - # monitor the progress: - while num_yielded_results < len(futures): - total_progress = 0 - total_task_lengths = 0 - - for (task_id, future), task_description in zip( - futures.items(), task_descriptions - ): - if task_id not in _progress_dict: - # No progress reported yet by the task function. - continue - - update_data = _progress_dict[task_id] - task_progress = update_data["progress"] - total = update_data["total"] - - # Start the task in the progress bar when the first update is received. 
- # This allows us to have a nice per-task elapsed time instead of the - # same elapsed time in all tasks. - if task_id not in _started_task_ids and task_progress > 0: - # Note: calling `start_task` multiple times doesn't cause issues, - # but we're still doing this just to be explicit. - progress.start_task(task_id) - _started_task_ids.append(task_id) - - # Update the progress bar for this task: - progress.update( - task_id=task_id, - completed=task_progress, - total=total, - description=task_description - + (f" - {info}" if (info := update_data.get("info")) else ""), - visible=True, - ) - total_progress += task_progress - total_task_lengths += total - - if total_progress or total_task_lengths: - progress.update( - task_id=overall_progress_task, - completed=total_progress, - total=total_task_lengths, - visible=True, - ) - - next_task_id_to_yield, next_future_to_resolve = list(futures.items())[ - num_yielded_results - ] - if next_future_to_resolve.done(): - logger.debug(f"Task {next_task_id_to_yield} is done, yielding result.") - yield next_future_to_resolve.result() - num_yielded_results += 1 - - try: - time.sleep(0.01) - except KeyboardInterrupt: - logger.info( - "Received keyboard interrupt, cancelling tasks that haven't started yet." 
- ) - for future in futures.values(): - future.cancel() - break + + if all(task.done() for task in tasks.values()): + break + + await asyncio.sleep(0.10) diff --git a/milatools/utils/vscode_utils.py b/milatools/utils/vscode_utils.py index bcd587b3..c63380cc 100644 --- a/milatools/utils/vscode_utils.py +++ b/milatools/utils/vscode_utils.py @@ -1,36 +1,33 @@ from __future__ import annotations +import asyncio import functools import os -import shlex import shutil import subprocess import sys import textwrap from logging import getLogger as get_logger from pathlib import Path -from typing import Literal, Sequence +from typing import Sequence from milatools.cli.utils import ( - CLUSTERS, + CommandNotFoundError, batched, stripped_lines_of, ) -from milatools.utils.local_v1 import LocalV1 +from milatools.utils.local_v2 import LocalV2 from milatools.utils.parallel_progress import ( - DictProxy, - ProgressDict, - TaskFn, - TaskID, - parallel_progress_bar, + AsyncTaskFn, + ReportProgressFn, + run_async_tasks_with_progress_bar, ) -from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 logger = get_logger(__name__) -def running_inside_WSL() -> bool: +def _running_inside_WSL() -> bool: return sys.platform == "linux" and bool(shutil.which("powershell.exe")) @@ -46,7 +43,7 @@ def get_expected_vscode_settings_json_path() -> Path: / "User" / "settings.json" ) - if running_inside_WSL(): + if _running_inside_WSL(): # Need to get the Windows Home directory, not the WSL one! 
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'") return Path( @@ -60,130 +57,111 @@ def get_code_command() -> str: return os.environ.get("MILATOOLS_CODE_COMMAND", "code") -def get_vscode_executable_path() -> str | None: - return shutil.which(get_code_command()) +def _get_local_vscode_executable_path(code_command: str | None = None) -> str: + if code_command is None: + code_command = get_code_command() + + code_command_path = shutil.which(code_command) + if not code_command_path: + raise CommandNotFoundError(code_command) + return code_command_path def vscode_installed() -> bool: - return bool(get_vscode_executable_path()) + try: + _ = _get_local_vscode_executable_path() + except CommandNotFoundError: + return False + return True -def sync_vscode_extensions_with_hostnames( - source: str, - destinations: list[str], -): - destinations = list(destinations) - if source in destinations: - if source == "mila" and destinations == CLUSTERS: - logger.info("Assuming you want to sync from mila to all DRAC/CC clusters.") - else: - logger.warning( - f"{source=} is also in the destinations to sync to. " f"Removing it." - ) - destinations.remove(source) +async def sync_vscode_extensions( + source: str | LocalV2 | RemoteV2, + destinations: Sequence[str | LocalV2 | RemoteV2], +) -> dict[str, list[str]]: + """Syncs vscode extensions between `source` all all the clusters in `dest`. - if len(set(destinations)) != len(destinations): - raise ValueError(f"{destinations=} contains duplicate hostnames!") + This creates an async task for each cluster in `dest` and displays a progress bar. + Returns the extensions that were installed on each cluster. 
+ """ + if isinstance(source, str): + if source == "localhost": + source = LocalV2() + else: + source = await RemoteV2.connect(source) - source_obj = LocalV1() if source == "localhost" else RemoteV2(source) - return sync_vscode_extensions(source_obj, destinations) + destinations = _remove_source_from_destinations(source, destinations) + if not destinations: + logger.info("No destinations to sync extensions to!") + return {} -def sync_vscode_extensions( - source: str | LocalV1 | RemoteV2, - dest_clusters: Sequence[str | LocalV1 | RemoteV2], -): - """Syncs vscode extensions between `source` all all the clusters in `dest`. + source_extensions = await _get_vscode_extensions(source) - This spawns a thread for each cluster in `dest` and displays a parallel progress bar - for the syncing of vscode extensions to each cluster. - """ - if isinstance(source, LocalV1): - source_hostname = "localhost" - source_extensions = get_local_vscode_extensions() - elif isinstance(source, RemoteV2): - source_hostname = source.hostname - code_server_executable = find_code_server_executable( - source, remote_vscode_server_dir="~/.vscode-server" - ) - if not code_server_executable: - raise RuntimeError( - f"The vscode-server executable was not found on {source.hostname}." 
- ) - source_extensions = get_remote_vscode_extensions(source, code_server_executable) - else: - assert isinstance(source, str) - source_hostname = source - source = RemoteV2(source) - - task_hostnames: list[str] = [] - task_fns: list[TaskFn[ProgressDict]] = [] + tasks: list[AsyncTaskFn[list[str]]] = [] task_descriptions: list[str] = [] - for dest_remote in dest_clusters: - dest_hostname: str - - if dest_remote == "localhost": - dest_hostname = dest_remote # type: ignore - dest_remote = LocalV1() # pickleable - elif isinstance(dest_remote, LocalV1): - dest_hostname = "localhost" - dest_remote = dest_remote # again, pickleable - elif isinstance(dest_remote, RemoteV2): - dest_hostname = dest_remote.hostname - dest_remote = dest_remote # pickleable - elif isinstance(dest_remote, RemoteV1): - # We unfortunately can't pass this kind of object to another process or - # thread because it uses `fabric.Connection` which don't appear to be - # pickleable. This means we will have to re-connect in the subprocess. - dest_hostname = dest_remote.hostname - dest_remote = None - else: - assert isinstance(dest_remote, str) - # The dest_remote is a hostname. Try to connect to it with a reusable SSH - # control socket so we can get the 2FA prompts out of the way in advance. - # NOTE: We could fallback to using the `Remote` class with paramiko inside - # the subprocess if this doesn't work, but it would suck because it messes - # up the UI, and you need to press 1 in the terminal to get the 2FA prompt, - # which screws up the progress bars. 
- dest_hostname = dest_remote - dest_remote = RemoteV2(hostname=dest_hostname) - - task_hostnames.append(dest_hostname) - task_fns.append( + dest_hostnames = [ + dest if isinstance(dest, str) else dest.hostname for dest in destinations + ] + for dest_runner, dest_hostname in zip(destinations, dest_hostnames): + tasks.append( functools.partial( - install_vscode_extensions_task_function, - dest_hostname=dest_hostname, + _install_vscode_extensions_task_function, source_extensions=source_extensions, - remote=dest_remote, - source_name=source_hostname, + remote=dest_runner, + source_name=source.hostname, ) ) - task_descriptions.append(f"{source_hostname} -> {dest_hostname}") + task_descriptions.append(f"{source.hostname} -> {dest_hostname}") - results: dict[str, ProgressDict] = {} + results = await run_async_tasks_with_progress_bar( + async_task_fns=tasks, + task_descriptions=task_descriptions, + overall_progress_task_description="[green]Syncing vscode extensions:", + ) + return {hostname: result for hostname, result in zip(dest_hostnames, results)} + + +def _remove_source_from_destinations( + source: LocalV2 | RemoteV2, destinations: Sequence[str | LocalV2 | RemoteV2] +): + dest_hostnames = [ + dest if isinstance(dest, str) else dest.hostname for dest in destinations + ] + if source.hostname in dest_hostnames: + logger.debug(f"{source.hostname!r} is also in the destinations, removing it.") + destinations = list(destinations) + destinations.pop(dest_hostnames.index(source.hostname)) + + if len(set(dest_hostnames)) != len(dest_hostnames): + raise ValueError(f"{dest_hostnames=} contains duplicate hostnames!") + return destinations - for hostname, result in zip( - task_hostnames, - parallel_progress_bar( - task_fns=task_fns, - task_descriptions=task_descriptions, - overall_progress_task_description="[green]Syncing vscode extensions:", - ), - ): - results[hostname] = result - return results + +async def _get_vscode_extensions( + source: LocalV2 | RemoteV2, +) -> dict[str, 
str]: + if isinstance(source, LocalV2): + code_server_executable = _get_local_vscode_executable_path(code_command=None) + else: + code_server_executable = await _find_code_server_executable( + source, remote_vscode_server_dir="~/.vscode-server" + ) + if not code_server_executable: + raise RuntimeError( + f"The vscode-server executable was not found on {source.hostname}." + ) + return await _get_vscode_extensions_dict(source, code_server_executable) -def install_vscode_extensions_task_function( - task_progress_dict: DictProxy[TaskID, ProgressDict], - task_id: TaskID, - dest_hostname: str | Literal["localhost"], +async def _install_vscode_extensions_task_function( + report_progress: ReportProgressFn, source_extensions: dict[str, str], - remote: RemoteV2 | LocalV1 | None, + remote: str | RemoteV2 | LocalV2, source_name: str, verbose: bool = False, -) -> ProgressDict: +) -> list[str]: """Installs vscode extensions on the remote cluster. 1. Finds the `code-server` executable on the remote; @@ -192,38 +170,37 @@ def install_vscode_extensions_task_function( extensions on the source; 4. Install the extensions that are missing or out of date on the remote, updating the progress dict as it goes. + + + Returns the list of installed extensions, in the form 'extension_name@version'. """ + installed: list[str] = [] def _update_progress( progress: int, status: str, total: int = len(source_extensions) ): - # Show progress to the parent process by setting an item in the task progress - # dict. 
- progress_dict: ProgressDict = { - "progress": progress, - "total": total, - "info": textwrap.shorten(status, 50, placeholder="..."), - } - task_progress_dict[task_id] = progress_dict - return progress_dict - - if remote is None: - if dest_hostname == "localhost": - remote = LocalV1() + info = textwrap.shorten(status, 50, placeholder="...") + report_progress(progress=progress, total=total, info=info) + + dest_hostname = remote if isinstance(remote, str) else remote.hostname + + if isinstance(remote, str): + if remote == "localhost": + remote = LocalV2() else: _update_progress(0, "Connecting...") - remote = RemoteV2(dest_hostname) + remote = await RemoteV2.connect(remote) - if isinstance(remote, LocalV1): - assert dest_hostname == "localhost" - code_server_executable = get_vscode_executable_path() - assert code_server_executable - extensions_on_dest = get_local_vscode_extensions() + if isinstance(remote, LocalV2): + code_server_executable = _get_local_vscode_executable_path() + _update_progress(0, status="fetching installed extensions...") + extensions_on_dest = await _get_vscode_extensions_dict( + remote, code_server_executable + ) else: - dest_hostname = remote.hostname remote_vscode_server_dir = "~/.vscode-server" _update_progress(0, f"Looking for code-server in {remote_vscode_server_dir}") - code_server_executable = find_code_server_executable( + code_server_executable = await _find_code_server_executable( remote, remote_vscode_server_dir=remote_vscode_server_dir, ) @@ -232,20 +209,22 @@ def _update_progress( f"The vscode-server executable was not found on {remote.hostname}." f"Skipping syncing extensions to {remote.hostname}." ) - return _update_progress( + _update_progress( # IDEA: Use a progress of `-1` to signify an error, and use a "X" # instead of a checkmark? 
progress=0, total=0, status="code-server executable not found!", ) + return installed + _update_progress(0, status="fetching installed extensions...") - extensions_on_dest = get_remote_vscode_extensions( + extensions_on_dest = await _get_vscode_extensions_dict( remote, code_server_executable ) logger.debug(f"{len(source_extensions)=}, {len(extensions_on_dest)=}") - to_install = extensions_to_install( + to_install = _extensions_to_install( source_extensions, extensions_on_dest, source_name=source_name, @@ -266,92 +245,71 @@ def _update_progress( total=len(to_install), status=f"Installing {extension_name}", ) - result = install_vscode_extension( + extension = f"{extension_name}@{extension_version}" + result = await _install_vscode_extension( remote, - code_server_executable, - extension=f"{extension_name}@{extension_version}", + code_server_executable=code_server_executable, + extension=extension, verbose=verbose, ) - except KeyboardInterrupt: - return _update_progress( + if result.returncode != 0: + logger.debug( + f"Unable to install extension {extension} on {dest_hostname}: {result.stderr}" + ) + else: + installed.append(extension) + except (KeyboardInterrupt, asyncio.CancelledError): + _update_progress( progress=index, total=len(to_install), status="Interrupted.", ) + return installed - if result.returncode != 0: - logger.debug(f"{dest_hostname}: {result.stderr}") - - return _update_progress( + _update_progress( progress=len(to_install), total=len(to_install), status="Done.", ) + return installed -def install_vscode_extension( - remote: LocalV1 | RemoteV2, +async def _install_vscode_extension( + remote: LocalV2 | RemoteV2, code_server_executable: str, extension: str, verbose: bool = False, ): - command = ( - code_server_executable, - "--install-extension", - extension, + command = f"{code_server_executable} --install-extension {extension}" + result = await remote.run_async( + command, + display=verbose, + warn=True, + hide=not verbose, ) - if isinstance(remote, 
RemoteV2): - result = remote.run( - shlex.join(command), - display=verbose, - warn=True, - hide=not verbose, - ) - else: - result = remote.run( - *command, - capture_output=not verbose, - display_command=verbose, - ) if result.stdout: logger.debug(result.stdout) return result -def get_local_vscode_extensions() -> dict[str, str]: - output = subprocess.run( - ( - get_vscode_executable_path() or get_code_command(), - "--list-extensions", - "--show-versions", - ), - shell=False, - check=True, - capture_output=True, - text=True, - ).stdout.strip() - return parse_vscode_extensions_versions(stripped_lines_of(output)) - - -def get_remote_vscode_extensions( - remote: RemoteV1 | RemoteV2, - remote_code_server_executable: str, +async def _get_vscode_extensions_dict( + remote: RemoteV2 | LocalV2, + code_server_executable: str, ) -> dict[str, str]: """Returns the list of isntalled extensions and the path to the code-server executable.""" - remote_extensions = parse_vscode_extensions_versions( + return _parse_vscode_extensions_versions( stripped_lines_of( - remote.get_output( - f"{remote_code_server_executable} --list-extensions --show-versions", + await remote.get_output_async( + f"{code_server_executable} --list-extensions --show-versions", display=False, hide=True, ) ) ) - return remote_extensions -def extensions_to_install( +def _extensions_to_install( source_extensions: dict[str, str], dest_extensions: dict[str, str], source_name: str, @@ -386,8 +344,8 @@ def extensions_to_install( return extensions_to_install_on_dest -def find_code_server_executable( - remote: RemoteV1 | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" +async def _find_code_server_executable( + remote: RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" ) -> str | None: """Find the most recent `code-server` executable on the remote. 
@@ -395,13 +353,14 @@ def find_code_server_executable( """ cluster = remote.hostname # TODO: When doing this for the first time on a remote cluster, this file might not - # be present until the vscode window has opened and installed the vscode server on the - # remote! Perhaps we should wait a little bit until it finishes installing somehow? + # be present until the vscode window has opened and installed the vscode server on + # the remote! Perhaps we should wait a little bit until it finishes installing + # somehow? find_code_server_executables_command = ( f"find {remote_vscode_server_dir} -name code-server -executable -type f" ) code_server_executables = stripped_lines_of( - remote.get_output( + await remote.get_output_async( find_code_server_executables_command, display=False, warn=True, @@ -412,6 +371,9 @@ def find_code_server_executable( logger.warning(f"Unable to find any code-server executables on {cluster}.") return None + # Now that we have the list of vscode-server executables, we get the version of + # each. + # Run a single fused command over SSH instead of one command for each executable. # Each executable outputs 3 lines: # ``` @@ -420,7 +382,7 @@ def find_code_server_executable( # x64 # ``` remote_version_command_output = stripped_lines_of( - remote.get_output( + await remote.get_output_async( find_code_server_executables_command + " -exec {} --version \\;", display=False, hide=True, @@ -448,6 +410,8 @@ def find_code_server_executable( f"{cluster}." ) # Use the most recent vscode-server executable. + # TODO: Should we instead use the one that is closest to the version of the local + # editor? 
most_recent_code_server_executable = max( code_server_executable_versions.keys(), key=code_server_executable_versions.__getitem__, @@ -455,7 +419,7 @@ def find_code_server_executable( return most_recent_code_server_executable -def parse_vscode_extensions_versions( +def _parse_vscode_extensions_versions( list_extensions_output_lines: list[str], ) -> dict[str, str]: extensions = [line for line in list_extensions_output_lines if "@" in line] diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index b833664a..46e7163d 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -8,7 +8,8 @@ import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.commands import _parse_lfs_quota_output, main +from milatools.cli.commands import main +from milatools.utils.disk_quota import _parse_lfs_quota_output from .common import requires_no_s_flag diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 4ed177b4..afa4c3ff 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -10,6 +10,7 @@ from milatools.cli.utils import ( get_fully_qualified_name, get_hostname_to_use_for_compute_node, + internet_on_compute_nodes, make_process, qn, randname, @@ -114,3 +115,28 @@ def test_make_process(): # run the syncing of vscode extensions in the background during `mila code`. 
assert not process.daemon assert not process.is_alive() + + +@pytest.mark.parametrize( + ("cluster", "expected"), + [ + ("mila", True), + ("narval", False), + ("beluga", False), + ("graham", False), + ("cedar", True), + ], +) +def test_internet_on_compute_nodes(cluster: str, expected: bool): + assert internet_on_compute_nodes(cluster) == expected + + +@pytest.mark.parametrize( + "cluster", + [ + "unknown_cluster", + ], +) +def test_internet_on_compute_nodes_unknown_cluster(cluster: str): + with pytest.warns(UserWarning, match="Unknown cluster"): + assert not internet_on_compute_nodes(cluster) diff --git a/tests/conftest.py b/tests/conftest.py index f74a8177..e08c5d56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,10 +17,10 @@ import rich from fabric.connection import Connection -import milatools.cli +import milatools.cli.code import milatools.cli.commands -import milatools.utils import milatools.utils.compute_node +import milatools.utils.disk_quota import milatools.utils.local_v2 import milatools.utils.parallel_progress import milatools.utils.remote_v2 @@ -60,15 +60,19 @@ def use_wider_console_during_tests(monkeypatch: pytest.MonkeyPatch): test_console = rich.console.Console( record=True, width=200, log_time=False, log_path=False ) + monkeypatch.setattr(milatools.cli, "console", test_console) monkeypatch.setitem(globals(), "console", test_console) + for module in [ milatools.cli.commands, milatools.utils.compute_node, milatools.utils.local_v2, milatools.utils.parallel_progress, milatools.utils.remote_v2, + milatools.utils.disk_quota, test_parallel_progress, + milatools.cli.code, ]: # These modules import the console from milatools.cli before this runs, so we # need to patch them also. 
diff --git a/tests/integration/test_code.py b/tests/integration/test_code.py new file mode 100644 index 00000000..64ce63c5 --- /dev/null +++ b/tests/integration/test_code.py @@ -0,0 +1,455 @@ +from __future__ import annotations + +import asyncio +import contextlib +import datetime +import re +import shutil +import sys +from datetime import timedelta +from logging import getLogger as get_logger +from unittest.mock import AsyncMock, Mock + +import pytest +import pytest_asyncio +from pytest_regressions.file_regression import FileRegressionFixture + +from milatools.cli.code import code +from milatools.cli.commands import code_v1 +from milatools.cli.utils import ( + CommandNotFoundError, + MilatoolsUserError, + get_hostname_to_use_for_compute_node, + removesuffix, +) +from milatools.utils import disk_quota +from milatools.utils.compute_node import ( + ComputeNode, + get_queued_milatools_job_ids, + salloc, +) +from milatools.utils.disk_quota import check_disk_quota, check_disk_quota_v1 +from milatools.utils.remote_v1 import RemoteV1 +from milatools.utils.remote_v2 import RemoteV2 + +from ..conftest import job_name, launches_jobs +from .test_slurm_remote import get_recent_jobs_info_dicts + +logger = get_logger(__name__) + + +async def _get_job_info( + job_id: int, + login_node: RemoteV2, + fields: tuple[str, ...] = ("JobID", "JobName", "Node", "WorkDir", "State"), +) -> dict: + return dict( + zip( + fields, + ( + await login_node.get_output_async( + f"sacct --noheader --allocations --user=$USER --jobs {job_id} " + "--format=" + ",".join(f"{field}%40" for field in fields), + display=False, + hide=True, + ) + ) + .strip() + .split(), + ) + ) + + +@launches_jobs +@pytest.mark.slow +@pytest.mark.asyncio +@pytest.mark.parametrize("persist", [True, False], ids=["sbatch", "salloc"]) +@pytest.mark.parametrize( + job_name.__name__, + # Don't set the `--job-name` in the `allocation_flags` fixture + # (this is necessary for `mila code` to work properly). 
+ [None], + ids=[""], + indirect=True, +) +async def test_code( + login_node_v2: RemoteV2, + persist: bool, + capsys: pytest.CaptureFixture, + allocation_flags: list[str], + file_regression: FileRegressionFixture, + slurm_account_on_cluster: str, +): + if login_node_v2.hostname == "localhost": + pytest.skip( + "TODO: This test doesn't yet work with the slurm cluster spun up in the GitHub CI." + ) + + home = await login_node_v2.get_output_async("echo $HOME") + scratch = await login_node_v2.get_output_async("echo $SCRATCH") + + start = datetime.datetime.now() - timedelta(minutes=5) + jobs_before = get_recent_jobs_info_dicts( + login_node_v2, since=datetime.datetime.now() - start + ) + jobs_before = { + int(job_info["JobID"]): job_info + for job_info in jobs_before + if job_info["JobName"] == "mila-code" + } + + relative_path = "bob" + + with contextlib.redirect_stderr(sys.stdout): + logger.info(f"{'sbatch' if persist else 'salloc'} flags: {allocation_flags}") + compute_node_or_job_id = await code( + path=relative_path, + command="echo", # replace the usual `code` with `echo` for testing. + # idea: Could probably also return the process ID of the `code` editor? + persist=persist, + job=None, + node=None, + alloc=allocation_flags, + cluster=login_node_v2.hostname, # type: ignore + ) + + # Get the output that was printed while running that command. + captured_output = capsys.readouterr().out + + node_hostname: str | None = None + if persist: + assert isinstance(compute_node_or_job_id, ComputeNode) + compute_node = compute_node_or_job_id + assert compute_node is not None + job_id = compute_node.job_id + node_hostname = compute_node.hostname + + else: + assert isinstance(compute_node_or_job_id, int) + job_id = compute_node_or_job_id + + await asyncio.sleep(5) # give a chance to sacct to update. 
+ + job_info = await _get_job_info( + job_id=job_id, + login_node=login_node_v2, + fields=("JobID", "JobName", "Node", "WorkDir", "State"), + ) + if node_hostname is None: + node_hostname = get_hostname_to_use_for_compute_node( + job_info["Node"], cluster=login_node_v2.hostname + ) + assert node_hostname and node_hostname != "None" + + # Check that the workdir is the scratch directory (because we cd'ed to $SCRATCH + # before submitting the job) + workdir = job_info["WorkDir"] + assert workdir == scratch + try: + if persist: + # Job should still be running since we're using `persist` (that's the whole + # point.) + # NOTE: There's a fixture that scancel's all our jobs spawned during unit tests + # so there's no issue of lingering jobs on the cluster after the tests run/fail. + assert job_info["State"] == "RUNNING" + await compute_node.close_async() + else: + # NOTE: Job is actually in the `COMPLETED` state because we exited cleanly (by + # passing `exit\n` to the salloc subprocess.) + assert job_info["State"] == "COMPLETED" + finally: + login_node_v2.run(f"scancel {job_id}", display=True) + + def filter_captured_output(captured_output: str) -> str: + # Remove information that may vary between runs from the regression test files. + def filter_line(line: str) -> str: + if ( + regex := re.compile( + r"Disk usage: \d+\.\d+ / \d+\.\d+ GiB and \d+ / \d+ files" + ) + ).match(line): + # IDEA: Use regex to go from this: + # Disk usage: 66.56 / 100.00 GiB and 789192 / 1048576 files + # to this: + # Disk usage: X / LIMIT GiB and X / LIMIT files + line = regex.sub("Disk usage: X / LIMIT GiB and X / LIMIT files", line) + + # If the line ends with an elapsed time, replace it with something constant. + line = re.sub(r"\d+:\d+:\d+$", "H:MM:SS", line) + # In the progress bar for syncing vscode extensions, there might be one with + # N/N (which depends on how many extensions were missing). Replace it with a + # constant. 
+ line = re.sub(r" \d+/\d+ ", " N/N ", line) + + return ( + line.rstrip() + .replace(str(job_id), "JOB_ID") + .replace(node_hostname, "COMPUTE_NODE") + .replace(home, "$HOME") + .replace( + "salloc: Pending job allocation JOB_ID", + "salloc: Granted job allocation JOB_ID", + ) + .replace( + f"--account={slurm_account_on_cluster}", "--account=SLURM_ACCOUNT" + ) + ) + + return "\n".join(filter_line(line) for line in captured_output.splitlines()) + + file_regression.check(filter_captured_output(captured_output)) + + +@pytest.mark.parametrize("use_v1", [False, True], ids=["code", "code_v1"]) +@pytest.mark.asyncio +async def test_code_without_code_command_in_path( + monkeypatch: pytest.MonkeyPatch, use_v1: bool +): + """Test the case where `mila code` is run without having vscode installed.""" + + def mock_which(command: str) -> str | None: + assert command == "code" # pretend like vscode isn't installed. + return None + + monkeypatch.setattr(shutil, shutil.which.__name__, Mock(side_effect=mock_which)) + if use_v1: + with pytest.raises(CommandNotFoundError): + code_v1( + path="bob", + command="code", + persist=False, + job=None, + node=None, + alloc=[], + cluster="bob", + ) + else: + with pytest.raises(CommandNotFoundError): + await code( + path="bob", + command="code", + persist=False, + job=None, + node=None, + alloc=[], + cluster="bob", + ) + + +@pytest_asyncio.fixture(scope="session") +async def existing_job( + cluster: str, + login_node_v2: RemoteV2, + allocation_flags: list[str], + job_name: str, +) -> ComputeNode: + """Gets a compute node connecting to a running job on the cluster. + + This avoids making an allocation if possible, by reusing an already-running job with + the name `job_name` if it exists. + """ + if cluster == "localhost": + pytest.skip( + "This test doesn't yet work with the slurm cluster spun up in the GitHub CI." 
+ ) + + existing_test_jobs_on_cluster = await get_queued_milatools_job_ids( + login_node_v2, job_name=job_name + ) + # todo: filter to use only the ones that are expected to be up for a little while + # longer (e.g. 2-3 minutes) + for job_id in existing_test_jobs_on_cluster: + try: + # Note: Connecting to a compute node runs a command with `srun`, so it will + # raise an error if the job is no longer running. + compute_node = await ComputeNode.connect(login_node_v2, job_id) + except Exception as exc: + logger.debug(f"Unable to reuse job {job_id}: {exc}") + else: + logger.info( + f"Reusing existing test job with name {job_name} on the cluster: {job_id}" + ) + return compute_node + logger.info( + "Unable to find existing test jobs on the cluster. Allocating a new one." + ) + compute_node = await salloc( + login_node_v2, salloc_flags=allocation_flags, job_name=job_name + ) + return compute_node + + +@pytest.fixture +def doesnt_create_new_jobs_fixture(capsys: pytest.CaptureFixture): + yield + out, err = capsys.readouterr() + assert "Submitted batch job" not in out + assert "Submitted batch job" not in err + assert "salloc: Pending job allocation" not in out + assert "salloc: Pending job allocation" not in err + assert "salloc: Granted job allocation" not in out + assert "salloc: Granted job allocation" not in err + + +doesnt_create_new_jobs = pytest.mark.usefixtures( + doesnt_create_new_jobs_fixture.__name__ +) + + +@doesnt_create_new_jobs +@pytest.mark.parametrize("use_v1", [False, True], ids=["code", "code_v1"]) +@pytest.mark.parametrize( + ("use_node_name", "use_job_id"), + [(True, False), (False, True), (True, True)], + ids=["node", "job", "both"], +) +@pytest.mark.asyncio +async def test_code_with_existing_job( + cluster: str, + existing_job: ComputeNode, + use_job_id: bool, + use_node_name: bool, + use_v1: bool, + capsys: pytest.CaptureFixture, + monkeypatch: pytest.MonkeyPatch, +): + """Test using `mila code --job `""" + + path = "bob" + assert use_job_id or 
use_node_name + + job: int | None = existing_job.job_id if use_job_id else None + node: str | None = None + if use_node_name: + hostname = existing_job.hostname + # We actually need to pass `cn-a001` (node name) as --node, not the entire + # hostname! + node = removesuffix(hostname, ".server.mila.quebec") + + if not use_v1: + + async def _mock_close_async(): + return + + monkeypatch.setattr( + ComputeNode, + ComputeNode.close_async.__name__, + mock_close_async := AsyncMock( + spec=ComputeNode.close_async, side_effect=_mock_close_async + ), + ) + + def _mock_close(): + return + + monkeypatch.setattr( + ComputeNode, + ComputeNode.close.__name__, + mock_close := Mock(spec=ComputeNode.close, side_effect=_mock_close), + ) + + compute_node_or_job_id = await code( + path=path, + command="echo", # replace the usual `code` with `echo` for testing. + persist=True, # todo: Doesn't really make sense to pass --persist when using --job or --node! + job=job, + node=node, + alloc=[], + cluster=cluster, + ) + assert isinstance(compute_node_or_job_id, ComputeNode) + assert compute_node_or_job_id.job_id == existing_job.job_id + assert compute_node_or_job_id.hostname == existing_job.hostname + mock_close_async.assert_not_called() + mock_close.assert_not_called() + else: + node: str | None = None + if use_node_name: + hostname = existing_job.hostname + # We actually need to pass `cn-a001` (node name) as --node, not the entire + # hostname! + node = removesuffix(hostname, ".server.mila.quebec") + + code_v1( + path=path, + command="echo", # replace the usual `code` with `echo` for testing. + persist=True, # so this doesn't try to cancel the running job on exit. + job=job, + node=node, + alloc=[], + cluster=cluster, + ) + # BUG: here it prints that it ends the session, but it *doesn't* end the job. + # This is correct, but misleading. We should probably print something else. 
+ ended_session_string = f"Ended session on {existing_job.hostname!r}" + output = capsys.readouterr().out + assert ended_session_string not in output + + +@doesnt_create_new_jobs +@pytest.mark.asyncio +@pytest.mark.parametrize("use_v1", [False, True], ids=["v2", "v1"]) +async def test_code_with_disk_quota_reached( + monkeypatch: pytest.MonkeyPatch, use_v1: bool +): + if use_v1: + from milatools.cli import commands + + # Makes the test slightly quicker to run. + monkeypatch.setattr(commands, RemoteV1.__name__, Mock(spec=RemoteV1)) + + def _mock_check_disk_quota_v1(remote: RemoteV1 | RemoteV2): + raise MilatoolsUserError( + "ERROR: Your disk quota on the $HOME filesystem is exceeded! " + ) + + mock_check_disk_quota = Mock( + spec=check_disk_quota_v1, side_effect=_mock_check_disk_quota_v1 + ) + monkeypatch.setattr( + disk_quota, check_disk_quota_v1.__name__, mock_check_disk_quota + ) + monkeypatch.setattr( + commands, check_disk_quota_v1.__name__, mock_check_disk_quota + ) + with pytest.raises(MilatoolsUserError): + code_v1( + path="bob", + command="echo", # replace the usual `code` with `echo` for testing. + persist=False, + job=None, + node="bobobo", # to avoid accidental sallocs. + alloc=[], + cluster="bob", + ) + mock_check_disk_quota.assert_called_once() + else: + from milatools.cli import code + + # Makes the test quicker to run by avoiding connecting to the cluster. + monkeypatch.setattr(code, RemoteV2.__name__, Mock(spec=RemoteV2)) + + async def _mock_check_disk_quota(remote: RemoteV1 | RemoteV2): + raise MilatoolsUserError( + "ERROR: Your disk quota on the $HOME filesystem is exceeded! 
" + ) + + mock_check_disk_quota = AsyncMock( + spec=check_disk_quota, side_effect=_mock_check_disk_quota + ) + monkeypatch.setattr( + disk_quota, check_disk_quota.__name__, mock_check_disk_quota + ) + + monkeypatch.setattr(code, check_disk_quota.__name__, mock_check_disk_quota) + with pytest.raises(MilatoolsUserError): + await code.code( + path="bob", + command="echo", # replace the usual `code` with `echo` for testing. + persist=False, + job=None, + node="bobobo", # to avoid accidental sallocs. + alloc=[], + cluster="bob", + ) + mock_check_disk_quota.assert_called_once() diff --git a/tests/integration/test_code/test_code_beluga__salloc_.txt b/tests/integration/test_code/test_code_beluga__salloc_.txt new file mode 100644 index 00000000..57ab9c77 --- /dev/null +++ b/tests/integration/test_code/test_code_beluga__salloc_.txt @@ -0,0 +1,13 @@ +Checking disk quota on $HOME... +Disk usage: X / LIMIT GiB and X / LIMIT files +Installing VSCode extensions that are on the local machine on beluga. +(beluga) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code +salloc: Pending job allocation JOB_ID +Waiting for job JOB_ID to start. +✓ Syncing vscode extensions: N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +✓ localhost -> beluga - Done. N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). 
+Ended session on 'COMPUTE_NODE' \ No newline at end of file diff --git a/tests/integration/test_code/test_code_beluga__sbatch_.txt b/tests/integration/test_code/test_code_beluga__sbatch_.txt new file mode 100644 index 00000000..e15dc160 --- /dev/null +++ b/tests/integration/test_code/test_code_beluga__sbatch_.txt @@ -0,0 +1,18 @@ +Checking disk quota on $HOME... +Disk usage: X / LIMIT GiB and X / LIMIT files +Installing VSCode extensions that are on the local machine on beluga. +(beluga) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap +'srun sleep 7d' +JOB_ID + +✓ Syncing vscode extensions: N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +✓ localhost -> beluga - Done. N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +This allocation is persistent and is still active. +To reconnect to this job, run the following: + mila code bob --cluster beluga --job JOB_ID +To kill this allocation: + ssh beluga scancel JOB_ID \ No newline at end of file diff --git a/tests/integration/test_code/test_code_cedar__salloc_.txt b/tests/integration/test_code/test_code_cedar__salloc_.txt new file mode 100644 index 00000000..2b013778 --- /dev/null +++ b/tests/integration/test_code/test_code_cedar__salloc_.txt @@ -0,0 +1,11 @@ +Checking disk quota on $HOME... 
+Disk usage: X / LIMIT GiB and X / LIMIT files +(cedar) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code +salloc: NOTE: Your memory request of 1024M was likely submitted as 1G. Please note that Slurm interprets memory requests denominated in G as multiples of 1024M, not 1000M. +salloc: Pending job allocation JOB_ID +Waiting for job JOB_ID to start. +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +Ended session on 'COMPUTE_NODE' \ No newline at end of file diff --git a/tests/integration/test_code/test_code_cedar__sbatch_.txt b/tests/integration/test_code/test_code_cedar__sbatch_.txt new file mode 100644 index 00000000..98c7702e --- /dev/null +++ b/tests/integration/test_code/test_code_cedar__sbatch_.txt @@ -0,0 +1,17 @@ +Checking disk quota on $HOME... +Disk usage: X / LIMIT GiB and X / LIMIT files +(cedar) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap +'srun sleep 7d' +JOB_ID + +sbatch: NOTE: Your memory request of 1024M was likely submitted as 1G. Please note that Slurm interprets memory requests denominated in G as multiples of 1024M, not 1000M. + +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +This allocation is persistent and is still active. 
+To reconnect to this job, run the following: + mila code bob --cluster cedar --job JOB_ID +To kill this allocation: + ssh cedar scancel JOB_ID \ No newline at end of file diff --git a/tests/integration/test_code/test_code_graham__salloc_.txt b/tests/integration/test_code/test_code_graham__salloc_.txt new file mode 100644 index 00000000..3ed442f1 --- /dev/null +++ b/tests/integration/test_code/test_code_graham__salloc_.txt @@ -0,0 +1,13 @@ +Checking disk quota on $HOME... +Installing VSCode extensions that are on the local machine on graham. +(graham) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code +salloc: NOTE: Your memory request of 1024M was likely submitted as 1G. Please note that Slurm interprets memory requests denominated in G as multiples of 1024M, not 1000M. +salloc: Pending job allocation JOB_ID +Waiting for job JOB_ID to start. +✓ Syncing vscode extensions: N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +✓ localhost -> graham - Done. N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +Ended session on 'COMPUTE_NODE' \ No newline at end of file diff --git a/tests/integration/test_code/test_code_graham__sbatch_.txt b/tests/integration/test_code/test_code_graham__sbatch_.txt new file mode 100644 index 00000000..050bf45d --- /dev/null +++ b/tests/integration/test_code/test_code_graham__sbatch_.txt @@ -0,0 +1,19 @@ +Checking disk quota on $HOME... 
+Installing VSCode extensions that are on the local machine on graham. +(graham) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap +'srun sleep 7d' +JOB_ID + +sbatch: NOTE: Your memory request of 1024M was likely submitted as 1G. Please note that Slurm interprets memory requests denominated in G as multiples of 1024M, not 1000M. + +✓ Syncing vscode extensions: N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +✓ localhost -> graham - Done. N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +This allocation is persistent and is still active. +To reconnect to this job, run the following: + mila code bob --cluster graham --job JOB_ID +To kill this allocation: + ssh graham scancel JOB_ID \ No newline at end of file diff --git a/tests/integration/test_code/test_code_mila__salloc_.txt b/tests/integration/test_code/test_code_mila__salloc_.txt new file mode 100644 index 00000000..14a6e463 --- /dev/null +++ b/tests/integration/test_code/test_code_mila__salloc_.txt @@ -0,0 +1,13 @@ +Checking disk quota on $HOME... 
+Disk usage: X / LIMIT GiB and X / LIMIT files +(mila) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code +salloc: -------------------------------------------------------------------------------------------------- +salloc: # Using default long partition +salloc: -------------------------------------------------------------------------------------------------- +salloc: Granted job allocation JOB_ID +Waiting for job JOB_ID to start. +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +Ended session on 'COMPUTE_NODE' \ No newline at end of file diff --git a/tests/integration/test_code/test_code_mila__sbatch_.txt b/tests/integration/test_code/test_code_mila__sbatch_.txt new file mode 100644 index 00000000..249496de --- /dev/null +++ b/tests/integration/test_code/test_code_mila__sbatch_.txt @@ -0,0 +1,19 @@ +Checking disk quota on $HOME... +Disk usage: X / LIMIT GiB and X / LIMIT files +(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep +7d' +JOB_ID + +sbatch: -------------------------------------------------------------------------------------------------- +sbatch: # Using default long partition +sbatch: -------------------------------------------------------------------------------------------------- + +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +This allocation is persistent and is still active. 
+To reconnect to this job, run the following: + mila code bob --job JOB_ID +To kill this allocation: + ssh mila scancel JOB_ID \ No newline at end of file diff --git a/tests/integration/test_code/test_code_narval__salloc_.txt b/tests/integration/test_code/test_code_narval__salloc_.txt new file mode 100644 index 00000000..035f4abc --- /dev/null +++ b/tests/integration/test_code/test_code_narval__salloc_.txt @@ -0,0 +1,13 @@ +Checking disk quota on $HOME... +Disk usage: X / LIMIT GiB and X / LIMIT files +Installing VSCode extensions that are on the local machine on narval. +(narval) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code +salloc: Pending job allocation JOB_ID +Waiting for job JOB_ID to start. +✓ Syncing vscode extensions: N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +✓ localhost -> narval - Done. N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +Ended session on 'COMPUTE_NODE' \ No newline at end of file diff --git a/tests/integration/test_code/test_code_narval__sbatch_.txt b/tests/integration/test_code/test_code_narval__sbatch_.txt new file mode 100644 index 00000000..6bb5e09e --- /dev/null +++ b/tests/integration/test_code/test_code_narval__sbatch_.txt @@ -0,0 +1,18 @@ +Checking disk quota on $HOME... +Disk usage: X / LIMIT GiB and X / LIMIT files +Installing VSCode extensions that are on the local machine on narval. 
+(narval) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap +'srun sleep 7d' +JOB_ID + +✓ Syncing vscode extensions: N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +✓ localhost -> narval - Done. N/N ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 H:MM:SS +(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob +--new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob + +The editor was closed. Reopen it with or terminate the process with (maybe twice). +This allocation is persistent and is still active. +To reconnect to this job, run the following: + mila code bob --cluster narval --job JOB_ID +To kill this allocation: + ssh narval scancel JOB_ID \ No newline at end of file diff --git a/tests/integration/test_code_command.py b/tests/integration/test_code_v1.py similarity index 77% rename from tests/integration/test_code_command.py rename to tests/integration/test_code_v1.py index 21e83b2d..bafc7410 100644 --- a/tests/integration/test_code_command.py +++ b/tests/integration/test_code_v1.py @@ -1,15 +1,13 @@ from __future__ import annotations -import logging import re -import subprocess import time from datetime import timedelta from logging import getLogger as get_logger import pytest -from milatools.cli.commands import check_disk_quota, code +from milatools.cli.commands import code_v1 from milatools.cli.utils import get_hostname_to_use_for_compute_node from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 @@ -20,30 +18,11 @@ logger = get_logger(__name__) -@pytest.mark.slow -def test_check_disk_quota( - login_node: RemoteV1 | 
RemoteV2, - capsys: pytest.LogCaptureFixture, - caplog: pytest.LogCaptureFixture, -): - if login_node.hostname.startswith("graham") or login_node.hostname == "localhost": - with pytest.raises(subprocess.CalledProcessError): - check_disk_quota(remote=login_node) - else: - with caplog.at_level(logging.DEBUG): - check_disk_quota(remote=login_node) - # TODO: Maybe figure out a way to actually test this, (not just by running it and - # expecting no errors). - # Check that it doesn't raise any errors. - # IF the quota is nearly met, then a warning is logged. - # IF the quota is met, then a `MilatoolsUserError` is logged. - - @pytest.mark.slow @launches_jobs @PARAMIKO_SSH_BANNER_BUG @pytest.mark.parametrize("persist", [True, False]) -def test_code( +def test_code_v1( login_node: RemoteV1 | RemoteV2, persist: bool, capsys: pytest.CaptureFixture, @@ -56,7 +35,7 @@ def test_code( home = login_node.run("echo $HOME", display=False, hide=True).stdout.strip() scratch = login_node.get_output("echo $SCRATCH") relative_path = "bob" - code( + code_v1( path=relative_path, command="echo", # replace the usual `code` with `echo` for testing. persist=persist, @@ -96,7 +75,7 @@ def test_code( node_hostname = get_hostname_to_use_for_compute_node( node, cluster=login_node.hostname ) - expected_line = f"(local) $ /usr/bin/echo -nw --remote ssh-remote+{node_hostname} {home}/{relative_path}" + expected_line = f"(localhost) $ /usr/bin/echo -nw --remote ssh-remote+{node_hostname} {home}/{relative_path}" assert any((expected_line in line) for line in captured_output.splitlines()), ( captured_output, expected_line, @@ -117,6 +96,11 @@ def test_code( # completed yet, or sacct doesn't show the change in status quick enough. # Relaxing it a bit for now. 
# assert "CANCELLED" in job_info["State"] - assert "CANCELLED" in job_info["State"] or job_info["State"] == "RUNNING" + assert "CANCELLED" in job_info["State"] or job_info["State"] in [ + "RUNNING", + # fixme: Not sure why this is the case, but the function is being + # deprecated anyway. + "COMPLETED", + ] finally: login_node.run(f"scancel {job_id}", display=True) diff --git a/tests/integration/test_slurm_remote.py b/tests/integration/test_slurm_remote.py index 7fd174e9..92d5c824 100644 --- a/tests/integration/test_slurm_remote.py +++ b/tests/integration/test_slurm_remote.py @@ -101,7 +101,9 @@ def fabric_connection_to_login_node(login_node: RemoteV1 | RemoteV2): @pytest.fixture def salloc_slurm_remote( - fabric_connection_to_login_node: fabric.Connection, allocation_flags: list[str] + fabric_connection_to_login_node: fabric.Connection, + allocation_flags: list[str], + job_name: str, ): """Fixture that creates a `SlurmRemote` that uses `salloc` (persist=False). @@ -111,18 +113,20 @@ def salloc_slurm_remote( """ return SlurmRemote( connection=fabric_connection_to_login_node, - alloc=allocation_flags, + alloc=allocation_flags + ["--job-name", job_name], ) @pytest.fixture def sbatch_slurm_remote( - fabric_connection_to_login_node: fabric.Connection, allocation_flags: list[str] + fabric_connection_to_login_node: fabric.Connection, + allocation_flags: list[str], + job_name: str, ): """Fixture that creates a `SlurmRemote` that uses `sbatch` (persist=True).""" return SlurmRemote( connection=fabric_connection_to_login_node, - alloc=allocation_flags, + alloc=allocation_flags + ["--job-name", job_name], persist=True, ) diff --git a/tests/integration/test_sync_command.py b/tests/integration/test_sync_command.py index ed27bea5..31f19b4a 100644 --- a/tests/integration/test_sync_command.py +++ b/tests/integration/test_sync_command.py @@ -1,19 +1,22 @@ from __future__ import annotations import importlib +import inspect +import subprocess from logging import getLogger as 
get_logger from typing import Callable -from unittest.mock import Mock +from unittest.mock import ANY, AsyncMock, Mock import pytest from typing_extensions import ParamSpec -from milatools.utils.local_v1 import LocalV1 +from milatools.utils import vscode_utils +from milatools.utils.local_v2 import LocalV2 from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.vscode_utils import ( - extensions_to_install, - find_code_server_executable, - install_vscode_extensions_task_function, + _extensions_to_install, + _find_code_server_executable, + _install_vscode_extensions_task_function, sync_vscode_extensions, ) @@ -25,6 +28,7 @@ logger = get_logger(__name__) +@pytest.mark.slow @pytest.mark.parametrize( "source", [ @@ -39,10 +43,12 @@ "cluster", ], ) -def test_sync_vscode_extensions( +@pytest.mark.asyncio +async def test_sync_vscode_extensions( source: str, dest: str, cluster: str, + login_node_v2: RemoteV2, monkeypatch: pytest.MonkeyPatch, ): if source == "cluster": @@ -53,31 +59,56 @@ def test_sync_vscode_extensions( if source == dest: pytest.skip("Source and destination are the same.") - def _mock_and_patch( - wraps: Callable, - _mock_type: Callable[P, Mock] = Mock, - *mock_args: P.args, - **mock_kwargs: P.kwargs, - ): + def mock_and_patch(wraps: Callable, *mock_args, **mock_kwargs): + mock_kwargs = mock_kwargs.copy() mock_kwargs["wraps"] = wraps + _mock_type = AsyncMock if inspect.iscoroutinefunction(wraps) else Mock mock = _mock_type(*mock_args, **mock_kwargs) module = importlib.import_module(wraps.__module__) monkeypatch.setattr(module, wraps.__name__, mock) return mock - mock_task_function = _mock_and_patch(wraps=install_vscode_extensions_task_function) + mock_task_function = mock_and_patch( + wraps=_install_vscode_extensions_task_function, + ) + extension, version = "ms-python.python", "v2024.0.1" + # Make it so we only need to install this particular extension. 
- mock_extensions_to_install = _mock_and_patch( - wraps=extensions_to_install, return_value={"ms-python.python": "v2024.0.1"} + mock_extensions_to_install = mock_and_patch( + wraps=_extensions_to_install, + return_value={extension: version}, + ) + mock_find_code_server_executable = mock_and_patch( + wraps=_find_code_server_executable, + ) + from milatools.utils.vscode_utils import _install_vscode_extension + + mock_install_extension = AsyncMock( + spec=_install_vscode_extension, + return_value=subprocess.CompletedProcess( + args=["..."], + returncode=0, + stdout=f"Successfully installed {extension}@{version}", + ), + ) + monkeypatch.setattr( + vscode_utils, _install_vscode_extension.__name__, mock_install_extension ) - mock_find_code_server_executable = _mock_and_patch( - wraps=find_code_server_executable, + + # Avoid actually installing this (possibly outdated?) extension. + extensions_per_cluster = await sync_vscode_extensions( + source=LocalV2() if source == "localhost" else login_node_v2, + destinations=[dest], ) + assert extensions_per_cluster == {dest: [f"{extension}@{version}"]} - sync_vscode_extensions( - source=LocalV1() if source == "localhost" else RemoteV2(source), - dest_clusters=[dest], + mock_install_extension.assert_called_once_with( + LocalV2() if dest == "localhost" else login_node_v2, + code_server_executable=ANY, + extension=f"{extension}@{version}", + verbose=ANY, ) + mock_task_function.assert_called_once() mock_extensions_to_install.assert_called_once() if source == "localhost": diff --git a/tests/utils/test_compute_node.py b/tests/utils/test_compute_node.py index c9259324..52961bed 100644 --- a/tests/utils/test_compute_node.py +++ b/tests/utils/test_compute_node.py @@ -139,6 +139,8 @@ async def get_new_job_ids() -> set[int]: await asyncio.sleep(0.1) return new_job_ids + allocation_flags = allocation_flags + ["--job-name", job_name] + # Check that a job allocation was indeed created.
# NOTE: Assuming that it takes more time for the job to be allocated than it takes for # the job to show up in `squeue`. diff --git a/tests/utils/test_disk_quota.py b/tests/utils/test_disk_quota.py new file mode 100644 index 00000000..7b792759 --- /dev/null +++ b/tests/utils/test_disk_quota.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import inspect +import logging +import subprocess +from typing import Callable + +import pytest + +from milatools.utils.disk_quota import check_disk_quota, check_disk_quota_v1 +from milatools.utils.remote_v1 import RemoteV1 +from milatools.utils.remote_v2 import RemoteV2 + + +@pytest.mark.slow +@pytest.mark.asyncio +@pytest.mark.parametrize("check_disk_quota_fn", [check_disk_quota, check_disk_quota_v1]) +async def test_check_disk_quota( + login_node_v2: RemoteV2, + caplog: pytest.LogCaptureFixture, + check_disk_quota_fn: Callable[[RemoteV1 | RemoteV2], None], +): + # TODO: Figure out a way to actually test this, (not just by running it and + # expecting no errors). + # Check that it doesn't raise any errors. + # IF the quota is nearly met, then a warning is logged. + # IF the quota is met, then a `MilatoolsUserError` is logged. 
+ async def _check_disk_quota(): + if inspect.iscoroutinefunction(check_disk_quota_fn): + await check_disk_quota_fn(login_node_v2) + else: + check_disk_quota_fn(login_node_v2) + + if ( + login_node_v2.hostname.startswith("graham") + or login_node_v2.hostname == "localhost" + ): + with pytest.raises(subprocess.CalledProcessError): + await _check_disk_quota() + + else: + with caplog.at_level(logging.DEBUG): + await _check_disk_quota() diff --git a/tests/utils/test_parallel_progress.py b/tests/utils/test_parallel_progress.py index 0e5e3245..78138de7 100644 --- a/tests/utils/test_parallel_progress.py +++ b/tests/utils/test_parallel_progress.py @@ -1,20 +1,20 @@ from __future__ import annotations +import asyncio import functools import time from logging import getLogger as get_logger from typing import TypeVar +import pytest from pytest_regressions.file_regression import FileRegressionFixture from milatools.cli import console from milatools.cli.utils import removesuffix from milatools.utils.parallel_progress import ( - DictProxy, - ProgressDict, - TaskFn, - TaskID, - parallel_progress_bar, + AsyncTaskFn, + ReportProgressFn, + run_async_tasks_with_progress_bar, ) from ..cli.common import xfails_on_windows @@ -24,28 +24,20 @@ OutT = TypeVar("OutT") -def _task_fn( - task_progress_dict: DictProxy[TaskID, ProgressDict], - task_id: TaskID, +async def _async_task_fn( + report_progress: ReportProgressFn, + task_id: int, task_length: int, result: OutT, ) -> OutT: - task_progress_dict[task_id] = { - "progress": 0, - "total": task_length, - "info": "Starting task.", - } + report_progress(0, task_length, "Starting task.") for n in range(task_length): - time.sleep(1.0) # sleep for a bit to simulate work + await asyncio.sleep(1.0) # sleep for a bit to simulate work logger.debug(f"Task {task_id} is {n+1}/{task_length} done.") - task_progress_dict[task_id] = {"progress": n + 1, "total": task_length} + report_progress(n + 1, task_length) - task_progress_dict[task_id] = { - "progress": 
task_length, - "total": task_length, - "info": "Done.", - } + report_progress(task_length, task_length, "Done.") return result @@ -54,34 +46,29 @@ def _task_fn( reason="Output is weird on windows? something to do with linebreaks perhaps.", strict=True, ) -def test_parallel_progress_bar(file_regression: FileRegressionFixture): +@pytest.mark.asyncio +async def test_async_progress_bar(file_regression: FileRegressionFixture): num_tasks = 4 task_length = 5 task_lengths = [task_length for _ in range(num_tasks)] task_results = [i for i in range(num_tasks)] - task_fns: list[TaskFn[int]] = [ - # pylance doesn't sees this as `Partial[int]` because it doesn't "save" the rest - # of the signature. Ignoring the type error here. - functools.partial(_task_fn, task_length=task_length, result=result) # type: ignore - for task_length, result in zip(task_lengths, task_results) + task_fns: list[AsyncTaskFn[int]] = [ + functools.partial( + _async_task_fn, task_id=i, task_length=task_length, result=result + ) + for i, (task_length, result) in enumerate(zip(task_lengths, task_results)) ] start_time = time.time() - - console.begin_capture() - - time_to_results: list[float] = [] - results: list[int] = [] - for result in parallel_progress_bar(task_fns, n_workers=num_tasks): - results.append(result) - time_to_result = time.time() - start_time - time_to_results.append(time_to_result) - + with console.capture() as capture: + # NOTE: the results are returned as a list (all at the same time). + results = await run_async_tasks_with_progress_bar( + task_fns, _show_elapsed_time=False + ) assert results == task_results - all_output = console.end_capture() - + all_output = capture.get() # Remove the elapsed column since its values can vary a little bit between runs. 
all_output_without_elapsed = "\n".join( removesuffix(line, last_part).rstrip() @@ -97,3 +84,67 @@ def test_parallel_progress_bar(file_regression: FileRegressionFixture): # All tasks sleep for `task_length` seconds, so the total time should still be # roughly `task_length` seconds. assert total_time_seconds < 2 * task_length + + +@pytest.mark.asyncio +async def test_interrupt_progress_bar(): + """Test the case where one of the tasks raises an exception.""" + num_tasks = 4 + task_length = 5 + task_lengths = [task_length for _ in range(num_tasks)] + task_results = [i for i in range(num_tasks)] + + task_fns: list[AsyncTaskFn] = [ + functools.partial( + _async_task_fn, task_id=i, task_length=task_length, result=result + ) + for i, (task_length, result) in enumerate(zip(task_lengths, task_results)) + ] + + # todo: seems not possible to raise KeyboardInterrupt, it seems to mess with + # pytest-asyncio. Would be good to test it though. + exception_type = asyncio.CancelledError + + async def _task_that_raises_an_exception( + report_progress: ReportProgressFn, + task_length: int, + ): + report_progress(0, task_length, "Starting task.") + # Raise an exception midway through the task. + await asyncio.sleep(task_length / 2) + report_progress( + task_length // 2, + task_length, + f"Done sleeping, about to raise a {exception_type.__name__}.", + ) + raise exception_type() + + task_that_raises_an_exception = functools.partial( + _task_that_raises_an_exception, task_length=task_length + ) + + results = None + with pytest.raises(exception_type): + results = await run_async_tasks_with_progress_bar( + task_fns + [task_that_raises_an_exception] + ) + # Results was not set. + assert results is None + + # Other test case: the interrupt is raised from "outside the progress bar". + # Check that the "outside" task raising an exception doesn't cancel the tasks in + # the "progress bar group". 
+ + async def _raise_after(delay: int): + await asyncio.sleep(delay) + raise exception_type() + + results, exception = await asyncio.gather( + run_async_tasks_with_progress_bar(task_fns), + _raise_after(1), + return_exceptions=True, + ) + # The result from the progress bar should be there, and the exception from the other + # task is also there. + assert results == task_results + assert isinstance(exception, exception_type) diff --git a/tests/utils/test_parallel_progress/test_parallel_progress_bar.txt b/tests/utils/test_parallel_progress/test_async_progress_bar.txt similarity index 87% rename from tests/utils/test_parallel_progress/test_parallel_progress_bar.txt rename to tests/utils/test_parallel_progress/test_async_progress_bar.txt index cf817d7b..713aa400 100644 --- a/tests/utils/test_parallel_progress/test_parallel_progress_bar.txt +++ b/tests/utils/test_parallel_progress/test_async_progress_bar.txt @@ -1,5 +1,5 @@ -✓ All jobs progress: 20/20 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 0 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 1 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 2 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 3 - Done. 
5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 \ No newline at end of file +✓ All jobs progress: 20/20 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% +✓ Task 0 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% +✓ Task 1 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% +✓ Task 2 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% +✓ Task 3 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% \ No newline at end of file diff --git a/tests/utils/test_vscode_utils.py b/tests/utils/test_vscode_utils.py index 53c05af6..1253b69c 100644 --- a/tests/utils/test_vscode_utils.py +++ b/tests/utils/test_vscode_utils.py @@ -1,44 +1,45 @@ from __future__ import annotations -import getpass -import multiprocessing +import functools import shutil import sys from logging import getLogger as get_logger -from multiprocessing.managers import DictProxy from pathlib import Path -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest - -from milatools.cli.utils import running_inside_WSL -from milatools.utils.local_v1 import LocalV1 -from milatools.utils.parallel_progress import ProgressDict +import pytest_asyncio + +from milatools.cli.utils import MilatoolsUserError, running_inside_WSL +from milatools.utils import 
vscode_utils +from milatools.utils.local_v2 import LocalV2 +from milatools.utils.parallel_progress import ( + ProgressDict, + report_progress, +) from milatools.utils.remote_v1 import RemoteV1 -from milatools.utils.remote_v2 import RemoteV2, UnsupportedPlatformError +from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.vscode_utils import ( - extensions_to_install, - find_code_server_executable, + _extensions_to_install, + _find_code_server_executable, + _get_local_vscode_executable_path, + _get_vscode_extensions, + _get_vscode_extensions_dict, + _install_vscode_extension, + _install_vscode_extensions_task_function, get_code_command, get_expected_vscode_settings_json_path, - get_local_vscode_extensions, - get_remote_vscode_extensions, - get_vscode_executable_path, - install_vscode_extension, - install_vscode_extensions_task_function, sync_vscode_extensions, - sync_vscode_extensions_with_hostnames, vscode_installed, ) -from tests.integration.conftest import skip_if_not_already_logged_in from ..cli.common import ( in_github_CI, in_self_hosted_github_CI, requires_ssh_to_localhost, skip_if_on_github_cloud_CI, - xfails_on_windows, ) +from .test_remote_v2 import uses_remote_v2 logger = get_logger(__name__) @@ -69,10 +70,6 @@ def test_vscode_installed_with_env_var( strict=True, ) -uses_remote_v2 = xfails_on_windows( - raises=UnsupportedPlatformError, reason="Uses RemoteV2", strict=True -) - @requires_vscode def test_get_expected_vscode_settings_json_path(): @@ -91,62 +88,64 @@ def test_running_inside_WSL(): def test_get_vscode_executable_path(): - code = get_vscode_executable_path() if vscode_installed(): - assert code is not None and Path(code).exists() + code = _get_local_vscode_executable_path() + assert Path(code).exists() else: - assert code is None + with pytest.raises( + MilatoolsUserError, match="Command 'code' does not exist locally." 
+ ): + _get_local_vscode_executable_path() @pytest.fixture def mock_find_code_server_executable(monkeypatch: pytest.MonkeyPatch): - """Makes it so we use the local `code` executable instead of `code-server`.""" + """Makes it so we use the local `code` executable instead of `code-server`. + + This makes it possible to treat the `code` executable on localhost just like the + `code-server` executable on remote machines because they have mostly the same CLI. + """ import milatools.utils.vscode_utils - mock_find_code_server_executable = Mock( - spec=find_code_server_executable, return_value=get_vscode_executable_path() + mock_find_code_server_executable = AsyncMock( + spec=_find_code_server_executable, + return_value=_get_local_vscode_executable_path(), ) monkeypatch.setattr( milatools.utils.vscode_utils, - find_code_server_executable.__name__, + _find_code_server_executable.__name__, mock_find_code_server_executable, ) return mock_find_code_server_executable -@xfails_on_windows(raises=UnsupportedPlatformError, reason="Uses RemoteV2", strict=True) +@uses_remote_v2 @requires_vscode @requires_ssh_to_localhost -def test_sync_vscode_extensions_in_parallel_with_hostnames( - monkeypatch: pytest.MonkeyPatch, +@pytest.mark.asyncio +async def test_sync_vscode_extensions( + mock_find_code_server_executable: Mock, monkeypatch: pytest.MonkeyPatch ): - import milatools.utils.vscode_utils - - # Make it so we use the local `code` executable instead of `code-server`. + # Skip the check that removes the source from the destinations. monkeypatch.setattr( - milatools.utils.vscode_utils, - find_code_server_executable.__name__, - Mock( - spec=find_code_server_executable, return_value=get_vscode_executable_path() - ), - ) - sync_vscode_extensions_with_hostnames( - # Make the destination slightly different so it actually gets wrapped as a - # `Remote(v2)` object. 
- "localhost", - destinations=[f"{getpass.getuser()}@localhost"], + vscode_utils, + vscode_utils._remove_source_from_destinations.__name__, + lambda source, destinations: destinations, ) - -@requires_vscode -@requires_ssh_to_localhost -def test_sync_vscode_extensions_in_parallel(): - results = sync_vscode_extensions(LocalV1(), dest_clusters=[LocalV1()]) - assert results == {"localhost": {"info": "Done.", "progress": 0, "total": 0}} + remote = await RemoteV2.connect("localhost") + results = await sync_vscode_extensions( + remote, + # Make the destination slightly different to avoid the duplicate hostname + # detection that happens in `sync_vscode_extensions`. + destinations=[remote], + ) + assert results == {"localhost": []} + mock_find_code_server_executable.assert_called() -@pytest.fixture -def vscode_extensions( +@pytest_asyncio.fixture +async def vscode_extensions( request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch ) -> tuple[dict[str, str], dict[str, str], dict[str, str]]: """Returns a dict of vscode extension names and versions to be installed locally. @@ -154,7 +153,7 @@ def vscode_extensions( Here we pretend like some local vscode extensions are missing by patching the function that returns the local extensions to return only part of its actual result. 
""" - all_extensions = get_local_vscode_extensions() + all_extensions = await _get_vscode_extensions(LocalV2()) installed_extensions = all_extensions.copy() num_missing_extensions = 3 @@ -167,13 +166,13 @@ def vscode_extensions( # `localhost` is the source, so it has all the extensions # the "remote" (just to localhost during tests) is missing some extensions - mock_remote_extensions = Mock( - spec=get_remote_vscode_extensions, - return_value=(installed_extensions, str(get_vscode_executable_path())), + mock_remote_extensions = AsyncMock( + spec=_get_vscode_extensions_dict, + return_value=(installed_extensions, str(_get_local_vscode_executable_path())), ) monkeypatch.setattr( milatools.utils.vscode_utils, - get_remote_vscode_extensions.__name__, + _get_vscode_extensions_dict.__name__, mock_remote_extensions, ) @@ -212,49 +211,53 @@ def _remote(hostname: str): @uses_remote_v2 @requires_ssh_to_localhost @requires_vscode -def test_install_vscode_extensions_task_function( +@pytest.mark.asyncio +async def test_install_vscode_extensions_task_function( installed_extensions: dict[str, str], missing_extensions: dict[str, str], mock_find_code_server_executable: Mock, ): - with multiprocessing.Manager() as manager: - from milatools.utils.parallel_progress import TaskID + from milatools.utils.parallel_progress import TaskID - logger.debug(f"{len(installed_extensions)=}, {len(missing_extensions)=}") - # Pretend like we don't already have these extensions locally. + logger.debug(f"{len(installed_extensions)=}, {len(missing_extensions)=}") + # Pretend like we don't already have these extensions locally. 
- task_progress_dict: DictProxy[TaskID, ProgressDict] = manager.dict() - - _fake_remote = _remote("localhost") - - result = install_vscode_extensions_task_function( - task_progress_dict=task_progress_dict, + task_progress_dict: dict[TaskID, ProgressDict] = {} + _fake_remote = await RemoteV2.connect("localhost") + result = await _install_vscode_extensions_task_function( + report_progress=functools.partial( + report_progress, + progress_dict=task_progress_dict, task_id=TaskID(0), - dest_hostname="fake_cluster", - source_extensions=missing_extensions, - remote=_fake_remote, - source_name="localhost", - ) - mock_find_code_server_executable.assert_called_once_with( - _fake_remote, remote_vscode_server_dir="~/.vscode-server" - ) - - assert result == { - "info": "Done.", - "progress": len(missing_extensions), - "total": len(missing_extensions), - } - assert task_progress_dict[TaskID(0)] == result + ), + source_extensions=missing_extensions, + remote=_fake_remote, + source_name="localhost", + ) + mock_find_code_server_executable.assert_called_once_with( + _fake_remote, remote_vscode_server_dir="~/.vscode-server" + ) + + assert result == [ + f"{ext_name}@{ext_version}" + for ext_name, ext_version in missing_extensions.items() + ] + assert task_progress_dict[TaskID(0)] == { + "info": "Done.", + "progress": len(missing_extensions), + "total": len(missing_extensions), + } @uses_remote_v2 @requires_ssh_to_localhost @requires_vscode -def test_install_vscode_extension(missing_extensions: dict[str, str]): +@pytest.mark.asyncio +async def test_install_vscode_extension(missing_extensions: dict[str, str]): extension_name, version = next(iter(missing_extensions.items())) - result = install_vscode_extension( - remote=_remote("localhost"), - code_server_executable=str(get_vscode_executable_path()), + result = await _install_vscode_extension( + remote=(await RemoteV2.connect("localhost")), + code_server_executable=str(_get_local_vscode_executable_path()), 
extension=f"{extension_name}@{version}", verbose=False, ) @@ -267,8 +270,10 @@ def test_install_vscode_extension(missing_extensions: dict[str, str]): @requires_ssh_to_localhost @requires_vscode -def test_get_local_vscode_extensions(): - local_extensions = get_local_vscode_extensions() +@pytest.mark.asyncio +async def test_get_local_vscode_extensions(): + local_extensions = await _get_vscode_extensions(LocalV2()) + assert local_extensions and all( isinstance(ext, str) and isinstance(version, str) for ext, version in local_extensions.items() @@ -278,19 +283,19 @@ def test_get_local_vscode_extensions(): @uses_remote_v2 @requires_ssh_to_localhost @requires_vscode -def test_get_remote_vscode_extensions(): +@pytest.mark.asyncio +async def test_get_remote_vscode_extensions(mock_find_code_server_executable): # We make it so this calls the local `code` command over SSH to localhost, # therefore the "remote" extensions are the same as the local extensions. - fake_remote = _remote("localhost") + fake_remote = await RemoteV2.connect("localhost") - local_vscode_executable = get_vscode_executable_path() + local_vscode_executable = _get_local_vscode_executable_path() assert local_vscode_executable is not None - fake_remote_extensions = get_remote_vscode_extensions( - fake_remote, remote_code_server_executable=local_vscode_executable + fake_remote_extensions = await _get_vscode_extensions_dict( + fake_remote, code_server_executable=local_vscode_executable ) - # Because of the mocking we did above, this should be true - assert fake_remote_extensions == get_local_vscode_extensions() + assert fake_remote_extensions == await _get_vscode_extensions(LocalV2()) @requires_vscode @@ -299,7 +304,7 @@ def test_extensions_to_install( installed_extensions: dict[str, str], missing_extensions: dict[str, str], ): - to_install = extensions_to_install( + to_install = _extensions_to_install( source_extensions=all_extensions, dest_extensions=installed_extensions, source_name="foo", @@ -308,40 +313,35 
@@ def test_extensions_to_install( assert to_install == missing_extensions -@uses_remote_v2 +# TODO: This test assumes that the `code-server` executable already exists for you (the +# dev) on the slurm cluster used in tests.. This is not ideal! +# Perhaps we could remove `vscode-server` from one of the clusters before running the +# tests? However this sounds a bit dangerous. +@pytest.mark.slow +@skip_if_on_github_cloud_CI +@pytest.mark.asyncio @pytest.mark.parametrize( - ("cluster", "remote_vscode_server_dir", "should_exist"), + ("remote_vscode_server_dir", "should_exist"), [ - pytest.param( - "localhost", - "~/vscode", - False, - marks=[ - skip_if_on_github_cloud_CI, - requires_ssh_to_localhost, - requires_vscode, - ], - ), - pytest.param( - "mila", + ( "~/.vscode-server", - True, - marks=[ - skip_if_on_github_cloud_CI, - skip_if_not_already_logged_in("mila"), - ], + True, # todo: Replace this hard-coded value with something smarter. ), + ("~/.vscode-server-dir-that-doesnt-exist", False), ], ) -def test_find_code_server_executable( - cluster: str, remote_vscode_server_dir: str, should_exist: bool +async def test_find_code_server_executable( + login_node_v2: RemoteV2, remote_vscode_server_dir: str, should_exist: bool ): - code_server_exe_path = find_code_server_executable( - RemoteV2(cluster), remote_vscode_server_dir=remote_vscode_server_dir + # NOTE: The `find` command in $HOME takes a very long time to run! + code_server_exe_path = await _find_code_server_executable( + login_node_v2, + remote_vscode_server_dir=remote_vscode_server_dir, ) if not should_exist: assert code_server_exe_path is None else: - assert code_server_exe_path and code_server_exe_path.startswith( - code_server_exe_path - ) + assert code_server_exe_path + remote_home = await login_node_v2.get_output_async("echo $HOME") + expected_dir = remote_vscode_server_dir.replace("~", remote_home) + assert code_server_exe_path.startswith(expected_dir)