diff --git a/payu/experiment.py b/payu/experiment.py index e0a06706..c66e8845 100644 --- a/payu/experiment.py +++ b/payu/experiment.py @@ -28,6 +28,7 @@ from payu import envmod from payu.fsops import mkdir_p, make_symlink, read_config, movetree from payu.fsops import list_archive_dirs +from payu.fsops import run_script_command from payu.schedulers.pbs import get_job_info, pbs_env_init, get_job_id from payu.models import index as model_index import payu.profilers @@ -887,44 +888,12 @@ def set_userscript_env_vars(self): } ) - def run_userscript(self, script_cmd): - # Setup environment variables with current run information + def run_userscript(self, script_cmd : str): + """Run a user defined script or subcommand at various stages of the + payu submissions""" self.set_userscript_env_vars() - - # First try to interpret the argument as a full command: - try: - sp.check_call(shlex.split(script_cmd)) - except EnvironmentError as exc: - # Now try to run the script explicitly - if exc.errno == errno.ENOENT: - cmd = os.path.join(self.control_path, script_cmd) - # Simplistic recursion check - assert os.path.isfile(cmd) - self.run_userscript(cmd) - - # If we get a "non-executable" error, then guess the type - elif exc.errno == errno.EACCES: - # TODO: Move outside - ext_cmd = {'.py': sys.executable, - '.sh': '/bin/bash', - '.csh': '/bin/tcsh'} - - _, f_ext = os.path.splitext(script_cmd) - shell_name = ext_cmd.get(f_ext) - if shell_name: - print('payu: warning: Assuming that {0} is a {1} script ' - 'based on the filename extension.' - ''.format(os.path.basename(script_cmd), - os.path.basename(shell_name))) - cmd = ' '.join([shell_name, script_cmd]) - self.run_userscript(cmd) - else: - # If we can't guess the shell, then abort - raise - except sp.CalledProcessError as exc: - # If the script runs but the output is bad, then warn the user - print('payu: warning: user script \'{0}\' failed (error {1}).' - ''.format(script_cmd, exc.returncode)) + run_script_command(script_cmd, + control_path=Path(self.control_path)) def sweep(self, hard_sweep=False): # TODO: Fix the IO race conditions! diff --git a/payu/fsops.py b/payu/fsops.py index cc1c09f7..8d480929 100644 --- a/payu/fsops.py +++ b/payu/fsops.py @@ -27,6 +27,11 @@ # Delete this once this bug in Lustre is fixed CHECK_LUSTRE_PATH_LEN = True +# File extensions to script interpreters +EXTENSION_TO_INTERPRETER = {'.py': sys.executable, + '.sh': '/bin/bash', + '.csh': '/bin/tcsh'} + def mkdir_p(path): """Create a new directory; ignore if it already exists.""" @@ -229,3 +234,91 @@ def list_archive_dirs(archive_path: Union[Path, str], dirs.sort(key=lambda d: int(d.lstrip(dir_type))) return dirs + + +def run_script_command(script_cmd: str, control_path: Path) -> None: + """Run a user defined script or command. + + Parameters + ---------- + script_cmd : string + String of user-script command defined in configuration file + control_path : Path + The control directory of the experiment + + Raises + ------ + RuntimeError + If there's was an error running the user-script + """ + try: + _run_script(script_cmd, control_path) + except Exception as e: + error_msg = f"User defined script/command failed to run: {script_cmd}" + raise RuntimeError(error_msg) from e + +def needs_subprocess_shell(command: str) -> bool: + """Check if command contains shell specific values. For example, file + redirections, pipes or logical operators. + + Parameters + ---------- + command: string + String of command to run in subprocess call + + Returns + ------- + bool + Returns True if command requires a subprocess shell, False otherwise + """ + shell_values = ['>', '<', '|', '&&', '$', '`'] + for value in shell_values: + if value in command: + return True + return False + +def _run_script(script_cmd: str, control_path: Path) -> None: + """Helper recursive function to attempt running a script command. + + Parameters + ---------- + script_cmd : string + The script command to attempt to run in subprocess call + control_path: Path + The control directory to use for resolving relative filepaths, if file + is not found + """ + # First try to interpret the argument as a full command: + try: + if needs_subprocess_shell(script_cmd): + subprocess.check_call(script_cmd, shell=True) + else: + subprocess.check_call(shlex.split(script_cmd)) + except EnvironmentError as e: + # Now try to run the script explicitly + if e.errno == errno.ENOENT: + # Check if script is a file in the control directory + cmd = control_path / script_cmd + if cmd.is_file(): + _run_script(str(cmd), control_path) + else: + raise + + # If we get a "non-executable" error, then guess the type + elif e.errno == errno.EACCES: + _, file_ext = os.path.splitext(script_cmd) + shell_name = EXTENSION_TO_INTERPRETER.get(file_ext, None) + if shell_name: + print('payu: warning: Assuming that {0} is a {1} ' + 'script based on the filename extension.' + ''.format(os.path.basename(script_cmd), + os.path.basename(shell_name))) + + cmd = f'{shell_name} {script_cmd}' + _run_script(cmd, control_path) + else: + raise + + # Otherwise re-raise the error + else: + raise diff --git a/test/test_payu.py b/test/test_payu.py index b1184fc5..5b4beac8 100644 --- a/test/test_payu.py +++ b/test/test_payu.py @@ -12,7 +12,7 @@ import payu.laboratory import payu.envmod -from .common import testdir, tmpdir, ctrldir, labdir, workdir +from .common import testdir, tmpdir, ctrldir, labdir, workdir, cd from .common import make_exe, make_inputs, make_restarts, make_all_files @@ -301,3 +301,116 @@ def test_list_archive_dirs(): # Clean up test archive shutil.rmtree(tmp_archive) shutil.rmtree(tmp_archive_2) + + +@pytest.fixture +def script_control_dir(): + # Create a temporary control directory + control_dir = tmpdir / 'script_control_dir' + control_dir.mkdir() + + yield control_dir + + # Tear down + shutil.rmtree(control_dir) + + +def test_run_userscript_python_script(script_control_dir: Path): + # Create a simple python script + python_script = script_control_dir / 'test_script.py' + with open(python_script, 'w') as f: + f.writelines([ + f"with open('{script_control_dir}/output.txt', 'w') as f:\n", + " f.write('Test Python user script')" + ]) + + # Test run userscript + payu.fsops.run_script_command('test_script.py', script_control_dir) + + # Check script output + with open((script_control_dir / 'output.txt'), 'r') as f: + assert f.read() == "Test Python user script" + + +def test_run_userscript_bash_script(script_control_dir: Path): + # Create a simple bash script + bash_script = script_control_dir / 'test_script.sh' + with open(bash_script, 'w') as f: + f.writelines([ + '#!/bin/bash\n', + 'echo -n "Test bash user script" > output.txt' + ]) + + # Test execute script + with cd(script_control_dir): + payu.fsops.run_script_command('./test_script.sh', script_control_dir) + + # Check script output + with open((script_control_dir / 'output.txt'), 'r') as f: + assert f.read() == "Test bash user script" + + +def test_userscript_unknown_extension(script_control_dir: Path): + # Create a text file + text_file = script_control_dir / 'test_txt.txt' + text_file.touch() + + # Test user script raises an error + with pytest.raises(RuntimeError): + payu.fsops.run_script_command(str(text_file), script_control_dir) + + +def test_userscript_non_existent_file(script_control_dir: Path): + # Test user script raises an error + with pytest.raises(RuntimeError): + payu.fsops.run_script_command('unknown_userscript.sh', + script_control_dir) + + +def test_userscript_non_existent_file(script_control_dir: Path): + # Test userscript raises an error + with pytest.raises(RuntimeError): + payu.fsops.run_script_command('unknown_userscript.sh', + script_control_dir) + + +def test_run_userscript_python_script_eror(script_control_dir): + # Create a python script that'll exit with an error + python_script = script_control_dir / 'test_script.py' + with open(python_script, 'w') as f: + f.write('raise ValueError("Test that script exits with error")') + + # Test userscript raises an error + with pytest.raises(RuntimeError): + payu.fsops.run_script_command('test_script.py', + script_control_dir) + + +def test_run_userscript_command(script_control_dir): + # Create a simple command + cmd = 'echo -n "some_data" > test.txt' + + # Test payu userscript + with cd(script_control_dir): + payu.fsops.run_script_command(cmd, script_control_dir) + + # Check userscript output + with open((script_control_dir / 'test.txt'), 'r') as f: + content = f.read() + assert content == "some_data" + + +@pytest.mark.parametrize("command, expected", [ + ('echo "Some Data" > test.txt', True), + ('ls -l | grep "test"', True), + ('cmd1 && cmd2', True), + ('cmd1 || cmd2', True), + ('echo "Data"', False), + ('ls -l', False), + ('some_python_script.py', False), + ('/bin/bash script.sh', False), + ('echo $PAYU_ENV_VALUE', True), + ('echo `date`', True) +]) +def test_needs_shell(command, expected): + assert payu.fsops.needs_subprocess_shell(command) == expected