Skip to content

Commit 5a3d3e7

Browse files
authored
Add an option to disable using the distributed scheduler from the diagnostic script (#3787)
1 parent c84868e commit 5a3d3e7

File tree

3 files changed

+52
-10
lines changed

3 files changed

+52
-10
lines changed

esmvaltool/diag_scripts/shared/_base.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def __init__(self, cfg):
181181
if not os.path.exists(self._log_file):
182182
self.table = {}
183183
else:
184-
with open(self._log_file, 'r') as file:
184+
with open(self._log_file, 'r', encoding='utf-8') as file:
185185
self.table = yaml.safe_load(file)
186186

187187
def log(self, filename, record):
@@ -212,8 +212,8 @@ def log(self, filename, record):
212212
if isinstance(filename, Path):
213213
filename = str(filename)
214214
if filename in self.table:
215-
raise KeyError(
216-
"Provenance record for {} already exists.".format(filename))
215+
msg = f"Provenance record for {filename} already exists."
216+
raise KeyError(msg)
217217

218218
self.table[filename] = record
219219

@@ -222,7 +222,7 @@ def _save(self):
222222
dirname = os.path.dirname(self._log_file)
223223
if not os.path.exists(dirname):
224224
os.makedirs(dirname)
225-
with open(self._log_file, 'w') as file:
225+
with open(self._log_file, 'w', encoding='utf-8') as file:
226226
yaml.safe_dump(self.table, file)
227227

228228
def __enter__(self):
@@ -253,9 +253,8 @@ def select_metadata(metadata, **attributes):
253253
"""
254254
selection = []
255255
for attribs in metadata:
256-
if all(a in attribs and (
257-
attribs[a] == attributes[a] or attributes[a] == '*')
258-
for a in attributes):
256+
if all(a in attribs and v in (attribs[a], '*')
257+
for a, v in attributes.items()):
259258
selection.append(attribs)
260259
return selection
261260

@@ -424,7 +423,7 @@ def get_cfg(filename=None):
424423
"""Read diagnostic script configuration from settings.yml."""
425424
if filename is None:
426425
filename = sys.argv[1]
427-
with open(filename) as file:
426+
with open(filename, encoding='utf-8') as file:
428427
cfg = yaml.safe_load(file)
429428
return cfg
430429

@@ -441,7 +440,7 @@ def _get_input_data_files(cfg):
441440

442441
input_files = {}
443442
for filename in metadata_files:
444-
with open(filename) as file:
443+
with open(filename, encoding='utf-8') as file:
445444
metadata = yaml.safe_load(file)
446445
input_files.update(metadata)
447446

@@ -469,6 +468,10 @@ def main(cfg):
469468
with run_diagnostic() as cfg:
470469
main(cfg)
471470
471+
To prevent the diagnostic script from using the Dask Distributed scheduler,
472+
set ``no_distributed: true`` in the diagnostic script definition in the
473+
recipe or in the resulting settings.yml file.
474+
472475
The `cfg` dict passed to `main` contains the script configuration that
473476
can be used with the other functions in this module.
474477
"""
@@ -568,7 +571,9 @@ def main(cfg):
568571
logger.info("Removing %s from previous run.", provenance_file)
569572
os.remove(provenance_file)
570573

571-
if not args.no_distributed and 'scheduler_address' in cfg:
574+
use_distributed = not (args.no_distributed
575+
or cfg.get('no_distributed', False))
576+
if use_distributed and 'scheduler_address' in cfg:
572577
try:
573578
client = distributed.Client(cfg['scheduler_address'])
574579
except OSError as exc:

esmvaltool/recipes/recipe_eady_growth_rate.yml

+3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ diagnostics:
4949
scripts:
5050
annual_eady_growth_rate:
5151
script: primavera/eady_growth_rate/eady_growth_rate.py
52+
no_distributed: true
5253
time_statistic: 'annual_mean'
5354

5455

@@ -63,6 +64,7 @@ diagnostics:
6364
scripts:
6465
summer_eady_growth_rate:
6566
script: primavera/eady_growth_rate/eady_growth_rate.py
67+
no_distributed: true
6668
time_statistic: 'seasonal_mean'
6769

6870
winter_egr:
@@ -76,5 +78,6 @@ diagnostics:
7678
scripts:
7779
winter_eady_growth_rate:
7880
script: primavera/eady_growth_rate/eady_growth_rate.py
81+
no_distributed: true
7982
time_statistic: 'seasonal_mean'
8083
plot_levels: [70000]

tests/unit/diag_scripts/shared/test_base.py

+34
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,40 @@ def test_run_diagnostic(tmp_path, monkeypatch):
367367
assert 'example_setting' in cfg
368368

369369

370+
@pytest.mark.parametrize("no_distributed", [False, True])
371+
def test_run_diagnostic_configures_dask(
372+
tmp_path,
373+
monkeypatch,
374+
mocker,
375+
no_distributed,
376+
):
377+
378+
settings = create_settings(tmp_path)
379+
scheduler_address = "tcp://127.0.0.1:38789"
380+
settings["scheduler_address"] = scheduler_address
381+
if no_distributed:
382+
settings["no_distributed"] = True
383+
settings_file = write_settings(settings)
384+
385+
monkeypatch.setattr(sys, 'argv', ['', settings_file])
386+
387+
# Create files created by ESMValCore
388+
for filename in ('log.txt', 'profile.bin', 'resource_usage.txt'):
389+
file = Path(settings['run_dir']) / filename
390+
file.touch()
391+
392+
mocker.patch.object(shared._base.distributed, "Client")
393+
394+
with shared.run_diagnostic() as cfg:
395+
assert 'example_setting' in cfg
396+
397+
if no_distributed:
398+
shared._base.distributed.Client.assert_not_called()
399+
else:
400+
shared._base.distributed.Client.assert_called_once_with(
401+
scheduler_address)
402+
403+
370404
@pytest.mark.parametrize('flag', ['-l', '--log-level'])
371405
def test_run_diagnostic_log_level(tmp_path, monkeypatch, flag):
372406
"""Test if setting the log level from the command line works."""

0 commit comments

Comments
 (0)