Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task messaging ssh #4119

Merged
merged 12 commits into from
Mar 23, 2021
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ setting `flow.cylc[scheduler]allow implicit tasks` to `True`.

### Enhancements

[#4119](https://github.com/cylc/cylc-flow/pull/4119) - Reimplement ssh task
communications.

[#4115](https://github.com/cylc/cylc-flow/pull/4115) - Raise an error when
invalid sort keys are provided clients.

Expand Down
4 changes: 3 additions & 1 deletion cylc/flow/cfgspec/globalcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@
''')
Conf('suite definition directory', VDR.V_STRING)
Conf('communication method',
VDR.V_STRING, 'zmq', options=['zmq', 'poll'], desc='''
VDR.V_STRING, 'zmq', options=['zmq', 'poll', 'ssh'], desc='''
The means by which task progress messages are reported back to
the running suite.

Expand All @@ -365,6 +365,8 @@
Direct client-server TCP communication via network ports
poll
The suite polls for the status of tasks (no task messaging)
ssh
Use non-interactive ssh for task communications
''')
# TODO ensure that it is possible to over-ride the following three
# settings in suite config.
Expand Down
3 changes: 3 additions & 0 deletions cylc/flow/job_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ def _write_suite_environment(self, handle, job_conf, run_d):
'\n export CYLC_SUITE_UUID="%s"' % job_conf['uuid_str'])

def _write_task_environment(self, handle, job_conf):
comm_meth = job_conf['platform']['communication method']

handle.write("\n\n # CYLC TASK ENVIRONMENT:")
handle.write(f"\n export CYLC_TASK_COMMS_METHOD={comm_meth}")
handle.write('\n export CYLC_TASK_JOB="%s"' % job_conf['job_d'])
handle.write(
'\n export CYLC_TASK_NAMESPACE_HIERARCHY="%s"' %
Expand Down
5 changes: 4 additions & 1 deletion cylc/flow/network/authorisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ def _authorise(self, *args, user='?', meta=None, **kwargs):
meta = {}
host = meta.get('host', '?')
prog = meta.get('prog', '?')
comms_method = meta.get('comms_method', '?')

# Hardcoded, for new - but much of this functionality can be
# removed more swingingly.
LOG.info(
'[client-command] %s %s@%s:%s', fcn.__name__, user, host, prog)
'[client-command] %s %s://%s@%s:%s',
fcn.__name__, comms_method, user, host, prog
)
return fcn(self, *args, **kwargs)

return _authorise
Expand Down
22 changes: 17 additions & 5 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
get_location,
ZMQSocketBase
)
from cylc.flow.network.client_factory import CommsMeth
from cylc.flow.network.server import PB_METHOD_MAP
from cylc.flow.suite_files import detect_old_contact_file

Expand Down Expand Up @@ -223,15 +224,26 @@ def serial_request(self, command, args=None, timeout=None):
self.loop.run_until_complete(task)
return task.result()

@staticmethod
def get_header() -> dict:
def get_header(self) -> dict:
"""Return "header" data to attach to each request for traceability.

Returns:
dict: dictionary with the header information, such as
program and hostname.
"""
cmd = sys.argv[0]

host = socket.gethostname()
# Identify communication method
comms_method = os.getenv("CLIENT_COMMS_METH", default=CommsMeth.ZMQ)
if (self.host and
(comms_method == CommsMeth.ZMQ) and
(socket.gethostbyname(
self.host) == socket.gethostbyname(socket.gethostname()))):
comms_method = CommsMeth.LOCAL
if len(sys.argv) > 1:
cmd = sys.argv[1]
else:
cmd = sys.argv[0]

cylc_executable_location = which("cylc")
if cylc_executable_location:
Expand All @@ -243,11 +255,11 @@ def get_header() -> dict:

if cmd.startswith(cylc_bin_dir):
cmd = cmd.replace(cylc_bin_dir, '')

return {
'meta': {
'prog': cmd,
'host': socket.gethostname()
'host': host,
'comms_method': comms_method,
}
}

Expand Down
53 changes: 53 additions & 0 deletions cylc/flow/network/client_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# THIS FILE IS PART OF THE CYLC SUITE ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os


class CommsMeth():
"""String literals used for identifying communication methods"""

SSH = 'ssh'
ZMQ = 'zmq'
LOCAL = 'local' # used for local commands


def get_comms_method():
""""Return Communication Method from environment variable, default zmq"""

return os.getenv('CYLC_TASK_COMMS_METHOD', CommsMeth.ZMQ)


def get_runtime_client(comms_method, workflow, timeout=None):
"""Return client for the provided communication method.

Args:
comm_method: communication method
workflow: workflow name
"""

if comms_method == CommsMeth.SSH:
from cylc.flow.network.ssh_client import SuiteRuntimeClient
else:
from cylc.flow.network.client import SuiteRuntimeClient
return SuiteRuntimeClient(workflow, timeout=timeout)


def get_client(workflow, timeout=None):
"""Get communication method and return correct SuiteRuntimeClient"""

comms_method = get_comms_method()
return get_runtime_client(comms_method, workflow, timeout=timeout)
102 changes: 102 additions & 0 deletions cylc/flow/network/ssh_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# THIS FILE IS PART OF THE CYLC SUITE ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import json
import os
from typing import Union

from cylc.flow.exceptions import ClientError
from cylc.flow.network.client_factory import CommsMeth
from cylc.flow.network import get_location
from cylc.flow.remote import _remote_cylc_cmd
from cylc.flow.suite_files import load_contact_file, ContactFileFields


class SuiteRuntimeClient():
"""Client to the workflow server communication using ssh.

Determines host from the contact file unless provided.

Args:
suite (str):
Name of the suite to connect to.
timeout (float):
Set the default timeout in seconds.
host (str):
The host where the flow is running if known.
"""
def __init__(
self,
suite: str,
host: str = None,
timeout: Union[float, str] = None
):
self.suite = suite

if not host:
self.host, _, _ = get_location(suite)

def send_request(self, command, args=None, timeout=None):
"""Send a request, using ssh.

Determines ssh_cmd, cylc_path and login_shell settings from the contact
file.

Converts message to JSON and sends this to stdin. Executes the Cylc
command, then deserialises the output.

Use ``__call__`` to call this method.

Args:
command (str): The name of the endpoint to call.
args (dict): Arguments to pass to the endpoint function.
timeout (float): Override the default timeout (seconds).
Raises:
ClientError: Coverall, on error from function call
Returns:
object: Deserialized output from function called.
"""
# Set environment variable to determine the communication for use on
# the scheduler
os.environ["CLIENT_COMMS_METH"] = CommsMeth.SSH
cmd = ["client"]
if timeout:
cmd += [f'comms_timeout={timeout}']
cmd += [self.suite, command]
contact = load_contact_file(self.suite)
ssh_cmd = contact[ContactFileFields.SCHEDULER_SSH_COMMAND]
login_shell = contact[ContactFileFields.SCHEDULER_USE_LOGIN_SHELL]
cylc_path = contact[ContactFileFields.SCHEDULER_CYLC_PATH]
cylc_path = None if cylc_path == 'None' else cylc_path
if not args:
args = {}
message = json.dumps(args)
proc = _remote_cylc_cmd(
cmd,
host=self.host,
stdin_str=message,
ssh_cmd=ssh_cmd,
remote_cylc_path=cylc_path,
ssh_login_shell=login_shell,
capture_process=True)

out, err = (f.decode() for f in proc.communicate())
return_code = proc.wait()
if return_code:
raise ClientError(err, f"return-code={return_code}")
return json.loads(out)

__call__ = send_request
3 changes: 2 additions & 1 deletion cylc/flow/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ def _construct_ssh_cmd(

for envvar in [
'CYLC_CONF_PATH',
'CYLC_COVERAGE'
'CYLC_COVERAGE',
'CLIENT_COMMS_METH'
]:
if envvar in os.environ:
command.append(
Expand Down
10 changes: 7 additions & 3 deletions cylc/flow/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,14 +1039,18 @@ def _configure_contact(self):
process_str,
fields.PUBLISH_PORT:
str(self.publisher.port),
fields.SSH_USE_LOGIN_SHELL:
str(get_platform()['use login shell']),
fields.SUITE_RUN_DIR_ON_SUITE_HOST:
self.suite_run_dir,
fields.UUID:
self.uuid_str.value,
fields.VERSION:
CYLC_VERSION
CYLC_VERSION,
fields.SCHEDULER_SSH_COMMAND:
str(get_platform()['ssh command']),
fields.SCHEDULER_CYLC_PATH:
str(get_platform()['cylc path']),
fields.SCHEDULER_USE_LOGIN_SHELL:
str(get_platform()['use login shell'])
}
# fmt: on
suite_files.dump_contact_file(self.suite, contact_data)
Expand Down
4 changes: 2 additions & 2 deletions cylc/flow/scripts/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
from cylc.flow.broadcast_report import (
get_broadcast_bad_options_report, get_broadcast_change_report)
from cylc.flow.cfgspec.suite import SPEC, upg
from cylc.flow.network.client import SuiteRuntimeClient
from cylc.flow.network.client_factory import get_client
from cylc.flow.parsec.config import ParsecConfig
from cylc.flow.parsec.validate import cylc_config_validate

Expand Down Expand Up @@ -296,7 +296,7 @@ def get_option_parser():
def main(_, options, suite):
"""Implement cylc broadcast."""
suite = os.path.normpath(suite)
pclient = SuiteRuntimeClient(suite, timeout=options.comms_timeout)
pclient = get_client(suite, timeout=options.comms_timeout)

mutation_kwargs = {
'request_string': MUTATION,
Expand Down
1 change: 0 additions & 1 deletion cylc/flow/scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import json
import sys

from google.protobuf.json_format import MessageToDict

from cylc.flow.option_parsers import CylcOptionParser as COP
Expand Down
4 changes: 2 additions & 2 deletions cylc/flow/scripts/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
# Display the state of all tasks in a particular cycle point:
$ cylc dump -t SUITE | grep 2010082406"""

from cylc.flow.network.client_factory import get_client
import sys
import json

from graphene.utils.str_converters import to_snake_case

from cylc.flow.exceptions import CylcError
from cylc.flow.option_parsers import CylcOptionParser as COP
from cylc.flow.network.client import SuiteRuntimeClient
from cylc.flow.terminal import cli_function

TASK_SUMMARY_FRAGMENT = '''
Expand Down Expand Up @@ -167,7 +167,7 @@ def get_option_parser():

@cli_function(get_option_parser)
def main(_, options, suite):
pclient = SuiteRuntimeClient(suite, timeout=options.comms_timeout)
pclient = get_client(suite, timeout=options.comms_timeout)

if options.sort_by_cycle:
sort_args = {'keys': ['cyclePoint', 'name']}
Expand Down
5 changes: 2 additions & 3 deletions cylc/flow/scripts/ext_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@

from cylc.flow import LOG
from cylc.flow.exceptions import CylcError, ClientError
from cylc.flow.network.client_factory import get_client
from cylc.flow.option_parsers import CylcOptionParser as COP
from cylc.flow.network.client import SuiteRuntimeClient
from cylc.flow.terminal import cli_function


Expand Down Expand Up @@ -92,8 +92,7 @@ def get_option_parser():
def main(parser, options, suite, event_msg, event_id):
suite = os.path.normpath(suite)
LOG.info('Send to suite %s: "%s" (%s)', suite, event_msg, event_id)

pclient = SuiteRuntimeClient(suite, timeout=options.comms_timeout)
pclient = get_client(suite, timeout=options.comms_timeout)

max_n_tries = int(options.max_n_tries)
retry_intvl_secs = float(options.retry_intvl_secs)
Expand Down
4 changes: 2 additions & 2 deletions cylc/flow/scripts/get_suite_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
To find the version you've invoked at the command line see "cylc version".
"""

from cylc.flow.network.client_factory import get_client
from cylc.flow.option_parsers import CylcOptionParser as COP
from cylc.flow.network.client import SuiteRuntimeClient
from cylc.flow.terminal import cli_function

QUERY = '''
Expand All @@ -47,7 +47,7 @@ def get_option_parser():

@cli_function(get_option_parser)
def main(parser, options, suite):
pclient = SuiteRuntimeClient(suite, timeout=options.comms_timeout)
pclient = get_client(suite, timeout=options.comms_timeout)

query_kwargs = {
'request_string': QUERY,
Expand Down
Loading