Skip to content

Commit

Permalink
(only) add type hints (#75)
Browse files Browse the repository at this point in the history
* (only) add type hints

Signed-off-by: Fabrice Normandin <[email protected]>

* Adapted tests that were changed by explicit args

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix some other tests that had explicit args added

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored Dec 7, 2023
1 parent cd4a359 commit 4fc25c4
Show file tree
Hide file tree
Showing 8 changed files with 490 additions and 183 deletions.
10 changes: 6 additions & 4 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def main():
# These are user errors and should not be reported
print("ERROR:", exc, file=sys.stderr)
except SSHConnectionError as err:
# These are errors coming from paramiko's failure to connect to the host
# These are errors coming from paramiko's failure to connect to the
# host
print("ERROR:", f"{err}", file=sys.stderr)
except Exception:
print(T.red(traceback.format_exc()), file=sys.stderr)
Expand Down Expand Up @@ -633,6 +634,7 @@ def code(
print("To reconnect to this node:")
print(T.bold(f" mila code {path} --node {node_name}"))
print("To kill this allocation:")
assert "jobid" in data
print(T.bold(f" ssh mila scancel {data['jobid']}"))


Expand Down Expand Up @@ -914,8 +916,8 @@ def _standard_server(
path: str | None,
*,
program: str,
installers,
command,
installers: dict[str, str],
command: str,
profile: str | None,
persist: bool,
port: int | None,
Expand Down Expand Up @@ -1168,7 +1170,7 @@ def check_disk_quota(remote: Remote) -> None:


def _find_allocation(
remote,
remote: Remote,
node: str | None,
job: str | None,
alloc: Sequence[str],
Expand Down
80 changes: 44 additions & 36 deletions milatools/cli/local.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,73 @@
from __future__ import annotations

import shlex
import subprocess
from subprocess import CompletedProcess
from typing import IO, Any

from typing_extensions import deprecated

from .utils import CommandNotFoundError, T, shjoin


class Local:
def display(self, args):
def display(self, args: list[str] | tuple[str, ...]) -> None:
print(T.bold_green("(local) $ ", shjoin(args)))

def silent_get(self, *args, **kwargs):
return subprocess.check_output(
args,
universal_newlines=True,
**kwargs,
)
def silent_get(self, *cmd: str) -> str:
return subprocess.check_output(cmd, universal_newlines=True)

def get(self, *args, **kwargs):
self.display(args)
return subprocess.check_output(
args,
universal_newlines=True,
**kwargs,
)
@deprecated("This isn't used and will probably be removed. Don't start using it.")
def get(self, *cmd: str) -> str:
self.display(cmd)
return subprocess.check_output(cmd, universal_newlines=True)

def run(self, *args, **kwargs):
self.display(args)
def run(
self,
*cmd: str,
stdout: int | IO[Any] | None = None,
stderr: int | IO[Any] | None = None,
capture_output: bool = False,
) -> CompletedProcess[str]:
self.display(cmd)
try:
return subprocess.run(
args,
cmd,
stdout=stdout,
stderr=stderr,
capture_output=capture_output,
universal_newlines=True,
**kwargs,
)
except FileNotFoundError as e:
if e.filename == args[0]:
raise CommandNotFoundError(e.filename)
else:
raise
if e.filename == cmd[0]:
raise CommandNotFoundError(e.filename) from e
raise

def popen(self, *args, **kwargs):
self.display(args)
def popen(
self,
*cmd: str,
stdout: int | IO[Any] | None = None,
stderr: int | IO[Any] | None = None,
) -> subprocess.Popen:
self.display(cmd)
return subprocess.Popen(
args,
universal_newlines=True,
**kwargs,
cmd, stdout=stdout, stderr=stderr, universal_newlines=True
)

def check_passwordless(self, host):
def check_passwordless(self, host: str) -> bool:
results = self.run(
"ssh",
"-oPreferredAuthentications=publickey",
host,
"echo OK",
*shlex.split(f"ssh -oPreferredAuthentications=publickey {host} 'echo OK'"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if results.returncode != 0:
if "Permission denied" in results.stderr:
return False
else:
print(results.stdout)
print(results.stderr)
exit(f"Failed to connect to {host}, could not understand error")
print(results.stdout)
print(results.stderr)
exit(f"Failed to connect to {host}, could not understand error")
# TODO: Perhaps we could actually check the output of the command here!
# elif "OK" in results.stdout:
else:
print("# OK")
return True
Loading

0 comments on commit 4fc25c4

Please sign in to comment.