Skip to content

Commit

Permalink
Merge pull request #177 from pyiron/decouple_draw_and_status
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhuber authored Apr 4, 2023
2 parents ae23b31 + c8700dc commit 7553b2d
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 71 deletions.
8 changes: 4 additions & 4 deletions ironflow/gui/workflows/boxes/node_interface/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _input_field_list(self) -> list[list[widgets.Widget]]:
"val": "Serialization error -- please reconnect an input"
}
if inp.val is None:
inp.val = dtype_state["val"]
inp.update(dtype_state["val"])

try:
if dtype_state["batched"]:
Expand Down Expand Up @@ -185,8 +185,8 @@ def _input_field_list(self) -> list[list[widgets.Widget]]:
def _input_change_i(self, i_c) -> Callable:
def input_change(change: dict) -> None:
# Todo: Test this in exec mode
self.node.inputs[i_c].val = change["new"]
self.node.update(i_c)
self.node.inputs[i_c].update(change["new"])
# self.node.update(i_c)
self.screen.redraw_active_flow_canvas()

return input_change
Expand All @@ -212,7 +212,7 @@ def toggle_batching(change: dict) -> None:
def _input_reset_i(self, i_c, associated_input_field) -> Callable:
def input_reset(button: widgets.Button) -> None:
default = self.node.inputs[i_c].dtype.default
self.node.inputs[i_c].val = default
self.node.inputs[i_c].update(default)
InfoMsgs.write(
f"Value for {self.node.title}.{self.node.inputs[i_c].label_str} "
f"reset to {default}"
Expand Down
2 changes: 1 addition & 1 deletion ironflow/gui/workflows/canvas_widgets/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def on_click(
def _current_color(self):
if self.highlighted:
color = self.layout.highlight_color
elif self.port.valid_val:
elif self.port.ready:
if self.selected:
color = self.layout.valid_selected_color
else:
Expand Down
6 changes: 5 additions & 1 deletion ironflow/model/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __getitem__(self, item):
def __iter__(self):
return self._filtered_port_list.__iter__()

def __len__(self):
return len(self._port_list)


class ValueFinder(PortFinder):
def __getattr__(self, key):
Expand Down Expand Up @@ -248,6 +251,7 @@ def setup_ports(self, inputs_data=None, outputs_data=None):
# widget in the front end which has probably overridden the
# Node.input() method
self.inputs[-1].val = deserialize(inp["val"])
self.inputs[-1].set_dtype_ok()

for out in outputs_data:
dtype = dtypes.DType.from_str(out["dtype"])(
Expand All @@ -262,7 +266,7 @@ def setup_ports(self, inputs_data=None, outputs_data=None):

@property
def all_input_is_valid(self):
return all(p.valid_val for p in self.inputs.ports)
return all(p.ready for p in self.inputs.ports)

def place_event(self):
# place_event() is executed *before* the connections are built
Expand Down
123 changes: 85 additions & 38 deletions ironflow/model/port.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Optional, TYPE_CHECKING

from numpy import argwhere
from ryvencore.InfoMsgs import InfoMsgs
from ryvencore.NodePort import NodeInput as NodeInputCore, NodeOutput as NodeOutputCore
from ryvencore.NodePortBP import (
NodeOutputBP as NodeOutputBPCore,
Expand All @@ -24,66 +25,80 @@
from ironflow.model.node import Node


class TypeHaver:
"""
A parent class for the has-type classes to facilitate super calls, regardless of
the order these havers appear as mixins to other classes.
"""

@property
def valid_val(self):
try:
other_type_checks = super().valid_val
except AttributeError:
other_type_checks = True
return other_type_checks


class HasDType(TypeHaver):
class HasDType:
"""A mixin to add the valid value check property"""

@property
def valid_val(self):
return self._dtype_ok and super().valid_val

@property
def _dtype_ok(self):
def set_dtype_ok(self):
if self.dtype is not None:
if self.val is not None:
return self.dtype.valid_val(self.val)
self._dtype_ok = self.dtype.valid_val(self.val)
else:
return self.dtype.allow_none
self._dtype_ok = self.dtype.allow_none
else:
return True
self._dtype_ok = True

@property
def dtype_ok(self):
try:
return self._dtype_ok
except AttributeError:
self.set_dtype_ok()
return self._dtype_ok


class HasOType(TypeHaver):
class HasOType:
"""A mixin to add the valid value check to properties with an ontology type"""

@property
def valid_val(self):
return self._otype_ok and super().valid_val
def recalculate_otype_checks(self, ignore=None):
self.set_otype_ok()
if self.otype is not None:
# Along connections
for con in self.connections:
if isinstance(self, NodeInput):
other = con.out
else:
other = con.inp

@property
def _otype_ok(self):
if other != ignore and other.otype is not None:
other.recalculate_otype_checks(ignore=self)

# Across the node
if isinstance(self, NodeInput):
ports = self.node.outputs.ports
else:
ports = self.node.inputs.ports
if ignore not in ports:
for port in ports:
if port.otype is not None:
port.recalculate_otype_checks(ignore=self)

def set_otype_ok(self):
if self.otype is not None:
if isinstance(self, NodeInput):
input_tree = self.otype.get_source_tree(
additional_requirements=self.get_downstream_requirements()
)
return all(
self._otype_ok = all(
con.out.all_connections_found_in(input_tree)
for con in self.connections
if con.out.otype is not None
)
else:
return all(
self._otype_ok = all(
con.inp.workflow_tree_contains_connections_of(self)
for con in self.connections
if con.inp.otype is not None
)
else:
return True
self._otype_ok = True

@property
def otype_ok(self):
try:
return self._otype_ok
except AttributeError:
self.set_otype_ok()
return self._otype_ok

def _output_graph_is_represented_in_workflow_tree(self, output_port, input_tree):
try:
Expand Down Expand Up @@ -131,7 +146,13 @@ def get_downstream_requirements(self):
return list(set(downstream_requirements))


class NodeInput(NodeInputCore, HasDType, HasOType):
class HasTypes(HasOType, HasDType):
@property
def ready(self):
return self.dtype_ok and self.otype_ok


class NodeInput(NodeInputCore, HasTypes):
def __init__(
self,
node: Node,
Expand All @@ -157,14 +178,14 @@ def batch(self):
if self.dtype is not None and not self.dtype.batched:
self.dtype.batched = True
if len(self.connections) == 0:
self.val = [self.val]
self.update([self.val])
self._update_node()

def unbatch(self):
if self.dtype is not None and self.dtype.batched:
self.dtype.batched = False
if len(self.connections) == 0:
self.val = self.val[-1]
self.update(self.val[-1])
self._update_node()

def data(self) -> dict:
Expand All @@ -182,8 +203,30 @@ def workflow_tree_contains_connections_of(self, port: NodeOutput):
)
return self._output_graph_is_represented_in_workflow_tree(port, tree)

def update(self, data=None):
# super().update(data=data)
# We need to add the dtype update _between_ the val update and node update
if self.type_ == "data":
self.val = data # self.get_val()
InfoMsgs.write("Data in input set to", data)

self.set_dtype_ok()

self.node.update(inp=self.node.inputs.index(self))

class NodeOutput(NodeOutputCore, HasDType, HasOType):
def connected(self):
super().connected()
self.set_dtype_ok()
self.recalculate_otype_checks() # Note: Only need to call or one of input or
# output since Flow.add_connection calls .connected on both inp and out

def disconnected(self):
super().disconnected()
self.recalculate_otype_checks() # Note: Only need to call or one of input or
# output since Flow.add_connection calls .connected on both inp and out


class NodeOutput(NodeOutputCore, HasTypes):
def __init__(
self,
node,
Expand Down Expand Up @@ -216,6 +259,10 @@ def all_connections_found_in(self, tree):
"""
return self._output_graph_is_represented_in_workflow_tree(self, tree)

def set_val(self, val):
super().set_val(val)
self.set_dtype_ok()


class NodeInputBP(NodeInputBPCore):
def __init__(
Expand Down
15 changes: 9 additions & 6 deletions ironflow/nodes/pyiron/atomistics_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from matplotlib.axes import Axes
from nglview import NGLWidget
from pandas import DataFrame
from ryvencore.InfoMsgs import InfoMsgs

import pyiron_base
import pyiron_ontology
Expand Down Expand Up @@ -51,6 +50,7 @@
PortList,
)
from ironflow.nodes.std.special_nodes import DualNodeBase
from ryvencore.InfoMsgs import InfoMsgs

if TYPE_CHECKING:
from pyiron_base import HasGroups
Expand Down Expand Up @@ -797,22 +797,25 @@ def _update_potential_choices(self):
available_potentials = self._get_potentials()

if len(available_potentials) == 0:
self.inputs.ports.potential.val = None
self.inputs.ports.potential.update(None)
self.inputs.ports.potential.dtype.items = ["No valid potential"]
else:
if (
last_potential not in available_potentials
and len(self.inputs.ports.potential.connections) == 0
):
if self.inputs.ports.potential.dtype.batched:
self.inputs.ports.potential.val = available_potentials
self.inputs.ports.potential.update(available_potentials)
else:
self.inputs.ports.potential.val = available_potentials[0]
self.inputs.ports.potential.update(available_potentials[0])
self.inputs.ports.potential.dtype.items = available_potentials
self.inputs.ports.potential.set_dtype_ok()

def update_event(self, inp=-1):
if inp == 1 and self.inputs.ports.structure.valid_val:
self._update_potential_choices()
if inp == 1:
self.inputs.ports.structure.set_dtype_ok()
if self.inputs.ports.structure.ready:
self._update_potential_choices()
super().update_event(inp=inp)

def node_function(self, project, structure, potential, **kwargs) -> dict:
Expand Down
56 changes: 35 additions & 21 deletions tests/unit/model/test_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@
from ironflow.model.port import NodeInput, NodeOutput, NodeOutputBP


class DummyInputs:
def index(self, node):
pass


class DummyNode:
def update(self, inp):
pass

@property
def inputs(self):
return DummyInputs()


class TestPorts(TestCase):
def test_for_dtype(self):
"""The `dtype` attribute should always be present, although it may be None"""
Expand All @@ -17,28 +31,28 @@ def test_for_dtype(self):

def test_validity(self):
with self.subTest("Without dtype set"):
p0 = NodeInput(node=None, dtype=None)
p0.val = None
self.assertTrue(p0.valid_val)
p0.val = "foo"
self.assertTrue(p0.valid_val)
p0.val = 42
self.assertTrue(p0.valid_val)
p0 = NodeInput(node=DummyNode(), dtype=None)
p0.update(None)
self.assertTrue(p0.ready)
p0.update("foo")
self.assertTrue(p0.ready)
p0.update(42)
self.assertTrue(p0.ready)

with self.subTest("With allow none"):
p0 = NodeInput(node=None, dtype=String(allow_none=True))
p0.val = None
self.assertTrue(p0.valid_val)
p0.val = "foo"
self.assertTrue(p0.valid_val)
p0.val = 42
self.assertFalse(p0.valid_val, msg="Should be wrong type")
p0 = NodeInput(node=DummyNode(), dtype=String(allow_none=True))
p0.update(None)
self.assertTrue(p0.ready)
p0.update("foo")
self.assertTrue(p0.ready)
p0.update(42)
self.assertFalse(p0.ready, msg="Should be wrong type")

with self.subTest("With allow none"):
p0 = NodeInput(node=None, dtype=String(allow_none=False))
p0.val = None
self.assertFalse(p0.valid_val, msg="None should be disallowed")
p0.val = "foo"
self.assertTrue(p0.valid_val)
p0.val = 42
self.assertFalse(p0.valid_val, msg="Should be wrong type")
p0 = NodeInput(node=DummyNode(), dtype=String(allow_none=False))
p0.update(None)
self.assertFalse(p0.ready, msg="None should be disallowed")
p0.update("foo")
self.assertTrue(p0.ready)
p0.update(42)
self.assertFalse(p0.ready, msg="Should be wrong type")

0 comments on commit 7553b2d

Please sign in to comment.