Skip to content

Commit

Permalink
Refactor run userscripts
Browse files Browse the repository at this point in the history
- Raise RuntimeErrors when userscripts exit with an error (previously it was just a warning)
- Move run userscript code to payu.fsops out of Experiment class so it's easier to test
- Use shell=True for subprocess commands needing a shell, e.g. file redirections, pipes
- Add tests for running userscripts
  • Loading branch information
jo-basevi committed Jul 1, 2024
1 parent 8a4a1e2 commit ceabfbd
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 38 deletions.
43 changes: 6 additions & 37 deletions payu/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from payu import envmod
from payu.fsops import mkdir_p, make_symlink, read_config, movetree
from payu.fsops import list_archive_dirs
from payu.fsops import run_script_command
from payu.schedulers.pbs import get_job_info, pbs_env_init, get_job_id
from payu.models import index as model_index
import payu.profilers
Expand Down Expand Up @@ -887,44 +888,12 @@ def set_userscript_env_vars(self):
}
)

def run_userscript(self, script_cmd):
# Setup environment variables with current run information
def run_userscript(self, script_cmd : str):
"""Run a user defined script or subcommand at various stages of the
payu submissions"""
self.set_userscript_env_vars()

# First try to interpret the argument as a full command:
try:
sp.check_call(shlex.split(script_cmd))
except EnvironmentError as exc:
# Now try to run the script explicitly
if exc.errno == errno.ENOENT:
cmd = os.path.join(self.control_path, script_cmd)
# Simplistic recursion check
assert os.path.isfile(cmd)
self.run_userscript(cmd)

# If we get a "non-executable" error, then guess the type
elif exc.errno == errno.EACCES:
# TODO: Move outside
ext_cmd = {'.py': sys.executable,
'.sh': '/bin/bash',
'.csh': '/bin/tcsh'}

_, f_ext = os.path.splitext(script_cmd)
shell_name = ext_cmd.get(f_ext)
if shell_name:
print('payu: warning: Assuming that {0} is a {1} script '
'based on the filename extension.'
''.format(os.path.basename(script_cmd),
os.path.basename(shell_name)))
cmd = ' '.join([shell_name, script_cmd])
self.run_userscript(cmd)
else:
# If we can't guess the shell, then abort
raise
except sp.CalledProcessError as exc:
# If the script runs but the output is bad, then warn the user
print('payu: warning: user script \'{0}\' failed (error {1}).'
''.format(script_cmd, exc.returncode))
run_script_command(script_cmd,
control_path=Path(self.control_path))

def sweep(self, hard_sweep=False):
# TODO: Fix the IO race conditions!
Expand Down
93 changes: 93 additions & 0 deletions payu/fsops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
# Delete this once this bug in Lustre is fixed
CHECK_LUSTRE_PATH_LEN = True

# Maps a script's file extension to the interpreter used to run it.
# _run_script consults this table when executing a file directly fails with
# EACCES, i.e. the file could not be executed as-is.
EXTENSION_TO_INTERPRETER = {'.py': sys.executable,
                            '.sh': '/bin/bash',
                            '.csh': '/bin/tcsh'}


def mkdir_p(path):
"""Create a new directory; ignore if it already exists."""
Expand Down Expand Up @@ -229,3 +234,91 @@ def list_archive_dirs(archive_path: Union[Path, str],

dirs.sort(key=lambda d: int(d.lstrip(dir_type)))
return dirs


def run_script_command(script_cmd: str, control_path: Path) -> None:
    """Run a user defined script or command.

    Any exception raised while attempting to run the command is wrapped in
    a RuntimeError, with the original exception chained as its cause.

    Parameters
    ----------
    script_cmd : string
        String of user-script command defined in configuration file
    control_path : Path
        The control directory of the experiment

    Raises
    ------
    RuntimeError
        If there was an error running the user-script
    """
    try:
        _run_script(script_cmd, control_path)
    except Exception as exc:
        raise RuntimeError(
            f"User defined script/command failed to run: {script_cmd}"
        ) from exc

def needs_subprocess_shell(command: str) -> bool:
    """Check if command contains shell specific values.

    For example, file redirections, pipes, logical operators, variable
    expansion or command substitution. The check is a plain substring
    search, so a marker appearing anywhere in the command (including inside
    quotes) counts as a match. ``||`` is detected because it contains ``|``.

    Parameters
    ----------
    command: string
        String of command to run in subprocess call

    Returns
    -------
    bool
        Returns True if command requires a subprocess shell, False otherwise
    """
    shell_values = ('>', '<', '|', '&&', '$', '`')
    return any(value in command for value in shell_values)

def _run_script(script_cmd: str, control_path: Path) -> None:
    """Helper recursive function to attempt running a script command.

    The command is first run as given. If that fails because the executable
    cannot be found, it is retried as a file inside the control directory.
    If it fails because the file is not executable, an interpreter is chosen
    from the file extension and the command is retried through it.

    Parameters
    ----------
    script_cmd : string
        The script command to attempt to run in subprocess call
    control_path: Path
        The control directory to use for resolving relative filepaths, if file
        is not found

    Raises
    ------
    OSError
        If the command could not be started and no fallback applies
    subprocess.CalledProcessError
        If the command ran but exited with a non-zero return code
    """
    # First try to interpret the argument as a full command:
    try:
        if needs_subprocess_shell(script_cmd):
            subprocess.check_call(script_cmd, shell=True)
        else:
            subprocess.check_call(shlex.split(script_cmd))
    except OSError as e:
        # Now try to run the script explicitly
        if e.errno == errno.ENOENT:
            # Check if script is a file in the control directory
            script_path = control_path / script_cmd
            # Quote the path so that whitespace or shell characters in the
            # control directory survive shlex.split / the shell on retry
            retry_cmd = shlex.quote(str(script_path))
            # Recursion guard: ENOENT is also raised for an existing file,
            # e.g. when its shebang interpreter is missing. Retrying the
            # exact command that just failed would never terminate.
            if retry_cmd != script_cmd and script_path.is_file():
                _run_script(retry_cmd, control_path)
            else:
                raise

        # If we get a "non-executable" error, then guess the type
        elif e.errno == errno.EACCES:
            # The first token is the file that could not be executed; any
            # remaining tokens are its arguments and must not affect the
            # extension lookup
            script_file = shlex.split(script_cmd)[0]
            _, file_ext = os.path.splitext(script_file)
            shell_name = EXTENSION_TO_INTERPRETER.get(file_ext)
            if not shell_name:
                # If we can't guess the interpreter, then abort
                raise

            print('payu: warning: Assuming that {0} is a {1} '
                  'script based on the filename extension.'
                  ''.format(os.path.basename(script_file),
                            os.path.basename(shell_name)))

            # Keep script_cmd verbatim so existing quoting is preserved
            cmd = f'{shlex.quote(shell_name)} {script_cmd}'
            _run_script(cmd, control_path)

        # Otherwise re-raise the error
        else:
            raise
115 changes: 114 additions & 1 deletion test/test_payu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import payu.laboratory
import payu.envmod

from .common import testdir, tmpdir, ctrldir, labdir, workdir
from .common import testdir, tmpdir, ctrldir, labdir, workdir, cd
from .common import make_exe, make_inputs, make_restarts, make_all_files


Expand Down Expand Up @@ -301,3 +301,116 @@ def test_list_archive_dirs():
# Clean up test archive
shutil.rmtree(tmp_archive)
shutil.rmtree(tmp_archive_2)


@pytest.fixture
def script_control_dir():
    """Yield a fresh temporary control directory and remove it afterwards."""
    ctrl_path = tmpdir / 'script_control_dir'
    ctrl_path.mkdir()

    yield ctrl_path

    # Tear down: remove the directory and anything the test wrote into it
    shutil.rmtree(ctrl_path)


def test_run_userscript_python_script(script_control_dir: Path):
    """A python script named relative to the control directory is run."""
    # Script that writes a known string to output.txt in the control dir
    script_source = ''.join([
        f"with open('{script_control_dir}/output.txt', 'w') as f:\n",
        " f.write('Test Python user script')",
    ])
    (script_control_dir / 'test_script.py').write_text(script_source)

    # Run it by bare filename
    payu.fsops.run_script_command('test_script.py', script_control_dir)

    # The script's output file must hold the expected text
    output = (script_control_dir / 'output.txt').read_text()
    assert output == "Test Python user script"


def test_run_userscript_bash_script(script_control_dir: Path):
    """A bash script invoked as ./name from the control directory is run."""
    # Script that writes a known string to output.txt in the working dir
    script_source = ''.join([
        '#!/bin/bash\n',
        'echo -n "Test bash user script" > output.txt',
    ])
    (script_control_dir / 'test_script.sh').write_text(script_source)

    # Run from inside the control directory so relative paths resolve there
    with cd(script_control_dir):
        payu.fsops.run_script_command('./test_script.sh', script_control_dir)

    # The script's output file must hold the expected text
    output = (script_control_dir / 'output.txt').read_text()
    assert output == "Test bash user script"


def test_userscript_unknown_extension(script_control_dir: Path):
    """An empty .txt file passed as the command raises a RuntimeError."""
    plain_file = script_control_dir / 'test_txt.txt'
    plain_file.touch()

    with pytest.raises(RuntimeError):
        payu.fsops.run_script_command(str(plain_file), script_control_dir)


def test_userscript_non_existent_file(script_control_dir: Path):
    """A script name that does not exist anywhere raises a RuntimeError.

    This test was previously defined twice with identical bodies; the second
    definition shadowed the first, so only a single copy is kept.
    """
    with pytest.raises(RuntimeError):
        payu.fsops.run_script_command('unknown_userscript.sh',
                                      script_control_dir)


def test_run_userscript_python_script_eror(script_control_dir):
    """A python script that raises makes run_script_command raise too."""
    # Script whose only statement raises, so it exits with an error
    failing_script = script_control_dir / 'test_script.py'
    failing_script.write_text(
        'raise ValueError("Test that script exits with error")'
    )

    with pytest.raises(RuntimeError):
        payu.fsops.run_script_command('test_script.py',
                                      script_control_dir)


def test_run_userscript_command(script_control_dir):
    """A command using a file redirection writes the expected file."""
    redirect_cmd = 'echo -n "some_data" > test.txt'

    # Run inside the control directory so test.txt is created there
    with cd(script_control_dir):
        payu.fsops.run_script_command(redirect_cmd, script_control_dir)

    written = (script_control_dir / 'test.txt').read_text()
    assert written == "some_data"


@pytest.mark.parametrize("command, expected", [
    # Redirections, pipes and logical operators
    ('echo "Some Data" > test.txt', True),
    ('ls -l | grep "test"', True),
    ('cmd1 && cmd2', True),
    ('cmd1 || cmd2', True),
    # Plain commands with no shell syntax
    ('echo "Data"', False),
    ('ls -l', False),
    ('some_python_script.py', False),
    ('/bin/bash script.sh', False),
    # Variable expansion and command substitution
    ('echo $PAYU_ENV_VALUE', True),
    ('echo `date`', True),
])
def test_needs_shell(command, expected):
    """needs_subprocess_shell flags exactly the commands using shell syntax."""
    detected = payu.fsops.needs_subprocess_shell(command)
    assert detected == expected

0 comments on commit ceabfbd

Please sign in to comment.