Skip to content

Commit

Permalink
FIX: namespaced outputs in BaseRestartWorkChain (#4961)
Browse files Browse the repository at this point in the history
The `BaseRestartWorkChain` did not return an `output_namespace` of
its `_process_class` as described in #4623. It happened because in its
`results` method, only the output keys are obtained from the call to
`node.get_outgoing` (checked and returned by the parent WorkChain).

This was changed for a call to `exposed_outputs`, which instead returns
the whole nested namespace. The `out_many` method is not used here
in order to make a post-check for ports that allows to keep the original
exit code check and report.

Cherry-pick: e1abe0a
  • Loading branch information
unkcpz authored and sphuber committed Aug 8, 2021
1 parent 2683d71 commit e02933c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
6 changes: 4 additions & 2 deletions aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,13 @@ def results(self) -> Optional['ExitCode']:

self.report(f'work chain completed after {self.ctx.iteration} iterations')

exposed_outputs = self.exposed_outputs(node, self.process_class)

for name, port in self.spec().outputs.items():

try:
output = node.get_outgoing(link_label_filter=name).one().node
except ValueError:
output = exposed_outputs[name]
except KeyError:
if port.required:
self.report(f"required output '{name}' was not an output of {self.ctx.process_name}<{node.pk}>")
else:
Expand Down
47 changes: 46 additions & 1 deletion tests/engine/processes/workchains/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# pylint: disable=invalid-name,no-self-use,no-member
import pytest

from aiida import engine
from aiida import engine, orm
from aiida.engine.processes.workchains.awaitable import Awaitable


Expand Down Expand Up @@ -146,3 +146,48 @@ def mock_submit(_, process_class, **kwargs):
assert isinstance(result, engine.ToContext)
assert isinstance(result['children'], Awaitable)
assert process.node.get_extra(SomeWorkChain._considered_handlers_extra) == [[]] # pylint: disable=protected-access


class OutputNamespaceWorkChain(engine.WorkChain):
"""A WorkChain has namespaced output"""

@classmethod
def define(cls, spec):
super().define(spec)
spec.output_namespace('sub', valid_type=orm.Int, dynamic=True)
spec.outline(cls.finalize)

def finalize(self):
self.out('sub.result', orm.Int(1).store())


class CustomBRWorkChain(engine.BaseRestartWorkChain):
"""`BaseRestartWorkChain` of `OutputNamespaceWorkChain`"""

_process_class = OutputNamespaceWorkChain

@classmethod
def define(cls, spec):
super().define(spec)
spec.expose_outputs(cls._process_class)
spec.output('extra', valid_type=orm.Int)

spec.outline(
cls.setup,
engine.while_(cls.should_run_process)(
cls.run_process,
cls.inspect_process,
),
cls.results,
)

def setup(self):
super().setup()
self.ctx.inputs = {}


@pytest.mark.requires_rmq
def test_results():
res, node = engine.launch.run_get_node(CustomBRWorkChain)
assert res['sub'].result.value == 1
assert node.exit_status == 11

0 comments on commit e02933c

Please sign in to comment.