From 2a703fa2160d1b02afab677a67dd4268e1fffbb5 Mon Sep 17 00:00:00 2001 From: Alex Tercete Date: Tue, 3 Dec 2024 15:36:41 +0000 Subject: [PATCH] driver/sshdriver: read from tools SSHDriver now supports the `tools` key in the config to customize the path to `ssh`, `scp`, `sshfs` and `rsync`. It falls back to loading them from the PATH, so there's no change in behaviour if the tools aren't specified. Signed-off-by: Alex Tercete --- labgrid/driver/sshdriver.py | 40 +++++++++++++++++++++++------------ man/labgrid-device-config.5 | 16 ++++++++++++++ man/labgrid-device-config.rst | 16 ++++++++++++++ tests/test_sshdriver.py | 22 +++++++++++++++++++ 4 files changed, 80 insertions(+), 14 deletions(-) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index b9efe7c06..1d0635abe 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -41,6 +41,15 @@ class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol): def __attrs_post_init__(self): super().__attrs_post_init__() self._keepalive = None + self._ssh = self._get_tool("ssh") + self._scp = self._get_tool("scp") + self._sshfs = self._get_tool("sshfs") + self._rsync = self._get_tool("rsync") + + def _get_tool(self, name): + if self.target.env: + return self.target.env.config.get_tool(name) + return name def _get_username(self): """Get the username from this class or from NetworkService""" @@ -105,7 +114,7 @@ def _start_own_master_once(self, timeout): self.tmpdir, f'control-{self.networkservice.address}' ) - args = ["ssh", "-f", *self.ssh_prefix, "-x", "-o", f"ConnectTimeout={timeout}", + args = [self._ssh, "-f", *self.ssh_prefix, "-x", "-o", f"ConnectTimeout={timeout}", "-o", "ControlPersist=300", "-o", "UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no", "-o", "ServerAliveInterval=15", "-MN", "-S", control.replace('%', '%%'), "-p", @@ -203,7 +212,7 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): if not self._check_keepalive(): raise ExecutionError("Keepalive no longer running") - complete_cmd = ["ssh", "-x", *self.ssh_prefix, + complete_cmd = [self._ssh, "-x", *self.ssh_prefix, "-p", str(self.networkservice.port), "-l", self._get_username(), self.networkservice.address ] + cmd.split(" ") @@ -238,7 +247,7 @@ def interact(self, cmd=None): if not self._check_keepalive(): raise ExecutionError("Keepalive no longer running") - complete_cmd = ["ssh", "-x", *self.ssh_prefix, + complete_cmd = [self._ssh, "-x", *self.ssh_prefix, "-t", self.networkservice.address ] @@ -252,7 +261,7 @@ def interact(self, cmd=None): @contextlib.contextmanager def _forward(self, forward): - cmd = ["ssh", *self.ssh_prefix, + cmd = [self._ssh, *self.ssh_prefix, "-O", "forward", forward, self.networkservice.address ] @@ -261,7 +270,7 @@ def _forward(self, forward): try: yield finally: - cmd = ["ssh", *self.ssh_prefix, + cmd = [self._ssh, *self.ssh_prefix, "-O", "cancel", forward, self.networkservice.address ] @@ -361,7 +370,8 @@ def scp(self, *, src, dst): if dst.startswith(':'): dst = '_' + dst - complete_cmd = ["scp", + complete_cmd = [self._scp, + "-S", self._ssh, "-F", "none", "-o", f"ControlPath={self.control.replace('%', '%%')}", src, dst, @@ -391,12 +401,12 @@ def rsync(self, *, src, dst, extra=[]): if dst.startswith(':'): dst = '_' + dst - ssh_cmd = ["ssh", + ssh_cmd = [self._ssh, "-F", "none", "-o", f"ControlPath={self.control.replace('%', '%%')}", ] - complete_cmd = ["rsync", + complete_cmd = [self._rsync, "-v", f"--rsh={' '.join(ssh_cmd)}", "-rlpt", # --recursive --links --perms --times @@ -417,7 +427,7 @@ def sshfs(self, *, path, mountpoint): if not self._check_keepalive(): raise ExecutionError("Keepalive no longer running") - complete_cmd = ["sshfs", + complete_cmd = [self._sshfs, "-F", "none", "-f", "-o", f"ControlPath={self.control.replace('%', '%%')}", @@ -445,7 +455,7 @@ def get_status(self): @cached_property def _ssh_version(self): - version = subprocess.run(["ssh", "-V"], capture_output=True, text=True) + version = subprocess.run([self._ssh, "-V"], capture_output=True, text=True) version = re.match(r"^OpenSSH_(\d+)\.(\d+)", version.stderr) return tuple(int(x) for x in version.groups()) @@ -472,7 +482,8 @@ def _scp_supports_explicit_scp_mode(self): @step(args=['filename', 'remotepath']) def put(self, filename, remotepath=''): transfer_cmd = [ - "scp", + self._scp, + "-S", self._ssh, *self.ssh_prefix, "-P", str(self.networkservice.port), "-r", @@ -502,7 +513,8 @@ def put(self, filename, remotepath=''): @step(args=['filename', 'destination']) def get(self, filename, destination="."): transfer_cmd = [ - "scp", + self._scp, + "-S", self._ssh, *self.ssh_prefix, "-P", str(self.networkservice.port), "-r", @@ -530,7 +542,7 @@ def get(self, filename, destination="."): def _cleanup_own_master(self): """Exit the controlmaster and delete the tmpdir""" - complete_cmd = f"ssh -x -o ControlPath={self.control.replace('%', '%%')} -O exit -p {self.networkservice.port} -l {self._get_username()} {self.networkservice.address}".split(' ') # pylint: disable=line-too-long + complete_cmd = f"{self._ssh} -x -o ControlPath={self.control.replace('%', '%%')} -O exit -p {self.networkservice.port} -l {self._get_username()} {self.networkservice.address}".split(' ') # pylint: disable=line-too-long res = subprocess.call( complete_cmd, stdin=subprocess.DEVNULL, @@ -547,7 +559,7 @@ def _cleanup_own_master(self): def _start_keepalive(self): """Starts a keepalive connection via the own or external master.""" - args = ["ssh", *self.ssh_prefix, self.networkservice.address, "cat"] + args = [self._ssh, *self.ssh_prefix, self.networkservice.address, "cat"] assert self._keepalive is None self._keepalive = subprocess.Popen( diff --git a/man/labgrid-device-config.5 b/man/labgrid-device-config.5 index fd9e05c2e..d58aa2a3b 100644 --- a/man/labgrid-device-config.5 +++ b/man/labgrid-device-config.5 @@ -132,6 +132,14 @@ See: .TP +.B \fBrsync\fP +Path to the rsync binary, used by the SSHDriver. +See: +.TP +.B \fBscp\fP +Path to the scp binary, used by the SSHDriver. +See: +.TP .B \fBsd\-mux\-ctrl\fP Path to the sd\-mux\-ctrl binary, used by the USBSDWireDriver. See: @@ -140,6 +148,14 @@ See: Path to the sispmctl binary, used by the SiSPMPowerDriver. See: .TP +.B \fBssh\fP +Path to the ssh binary, used by the SSHDriver. +See: +.TP +.B \fBsshfs\fP +Path to the sshfs binary, used by the SSHDriver. +See: +.TP .B \fBuhubctl\fP Path to the uhubctl binary, used by the USBPowerDriver. See: diff --git a/man/labgrid-device-config.rst b/man/labgrid-device-config.rst index ba0156830..d8a79947d 100644 --- a/man/labgrid-device-config.rst +++ b/man/labgrid-device-config.rst @@ -131,6 +131,14 @@ TOOLS KEYS Path to the rk-usb-loader binary, used by the RKUSBDriver. See: https://git.pengutronix.de/cgit/barebox/tree/scripts/rk-usb-loader.c +``rsync`` + Path to the rsync binary, used by the SSHDriver. + See: https://github.com/rsyncproject/rsync + +``scp`` + Path to the scp binary, used by the SSHDriver. + See: https://github.com/openssh/openssh-portable + ``sd-mux-ctrl`` Path to the sd-mux-ctrl binary, used by the USBSDWireDriver. See: https://git.tizen.org/cgit/tools/testlab/sd-mux/ @@ -139,6 +147,14 @@ TOOLS KEYS Path to the sispmctl binary, used by the SiSPMPowerDriver. See: https://sispmctl.sourceforge.net/ +``ssh`` + Path to the ssh binary, used by the SSHDriver. + See: https://github.com/openssh/openssh-portable + +``sshfs`` + Path to the sshfs binary, used by the SSHDriver. + See: https://github.com/libfuse/sshfs + ``uhubctl`` Path to the uhubctl binary, used by the USBPowerDriver. See: https://github.com/mvp/uhubctl diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index c11cdabd6..4c233a834 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -1,6 +1,7 @@ import pytest import socket +from labgrid import Environment from labgrid.driver import SSHDriver, ExecutionError from labgrid.exceptions import NoResourceFoundError from labgrid.resource import NetworkService @@ -58,6 +59,27 @@ def test_run_check_raise(target, ssh_driver_mocked_and_activated, mocker): assert res == (['error'], [], 1) target.deactivate(s) +def test_default_tools(target): + NetworkService(target, "service", "1.2.3.4", "root") + s = SSHDriver(target, "ssh") + assert [s._ssh, s._scp, s._sshfs, s._rsync] == ["ssh", "scp", "sshfs", "rsync"] + +def test_custom_tools(target, tmpdir): + p = tmpdir.join("config.yaml") + p.write( + """ + tools: + ssh: "/path/to/ssh" + scp: "/path/to/scp" + sshfs: "/path/to/sshfs" + rsync: "/path/to/rsync" + """ + ) + target.env = Environment(str(p)) + NetworkService(target, "service", "1.2.3.4", "root") + s = SSHDriver(target, "ssh") + assert [s._ssh, s._scp, s._sshfs, s._rsync] == [f"/path/to/{t}" for t in ("ssh", "scp", "sshfs", "rsync")] + @pytest.fixture(scope='function') def ssh_localhost(target, pytestconfig): name = pytestconfig.getoption("--ssh-username")