Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hinting throughout our classes #7

Merged
merged 1 commit into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions ryven/ironflow/CanvasObject.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from __future__ import annotations

from ipycanvas import Canvas, hold_canvas
import ipywidgets as widgets
import numpy as np
from IPython.display import display

from .NodeWidget import CanvasLayout, NodeWidget, PortWidget, ButtonNodeWidget
from .NodeWidget import CanvasLayout, NodeWidget, PortWidget, BaseCanvasWidget, ButtonNodeWidget
from .NodeWidgets import NodeWidgets
from .has_session import HasSession

from typing import TYPE_CHECKING, Optional, Union, List
if TYPE_CHECKING:
from Gui import GUI
from ryven.NENV import Node
Number = Union[int, float]

__author__ = "Joerg Neugebauer"
__copyright__ = (
"Copyright 2020, Max-Planck-Institut für Eisenforschung GmbH - "
Expand All @@ -24,7 +32,7 @@


class CanvasObject(HasSession):
def __init__(self, gui=None, width=2000, height=1000):
def __init__(self, gui: Optional[GUI] = None, width: int = 2000, height: int = 1000):
self.gui = gui
super().__init__(self.gui.session)
self._width, self._height = width, height
Expand Down Expand Up @@ -61,7 +69,7 @@ def __init__(self, gui=None, width=2000, height=1000):

self._object_to_gui_dict = {}

def draw_connection(self, port_1, port_2):
def draw_connection(self, port_1: int, port_2: int) -> None:
# i_out, i_in = path
# out = self.objects_to_draw[i_out]
# inp = self.objects_to_draw[i_in]
Expand All @@ -75,20 +83,20 @@ def draw_connection(self, port_1, port_2):
canvas.line_to(inp.x, inp.y)
canvas.stroke()

def _built_object_to_gui_dict(self):
def _built_object_to_gui_dict(self) -> None:
self._object_to_gui_dict = {}
for n in self.objects_to_draw:
self._object_to_gui_dict[n.node] = n
for p in n.objects_to_draw:
if hasattr(p, "port"):
self._object_to_gui_dict[p.port] = p

def canvas_restart(self):
def canvas_restart(self) -> None:
self._canvas.clear()
self._canvas.fill_style = self._col_background
self._canvas.fill_rect(0, 0, self._width, self._height)

def handle_keyboard_event(self, key, shift_key, ctrl_key, meta_key):
def handle_keyboard_event(self, key: str, shift_key, ctrl_key, meta_key) -> None:
if key == "Delete":
self.delete_selected()
elif key == "m":
Expand All @@ -99,7 +107,7 @@ def handle_keyboard_event(self, key, shift_key, ctrl_key, meta_key):
elif key == "n":
self.gui.mode_dropdown.value = mode_none

def set_connection(self, ind_node):
def set_connection(self, ind_node: int) -> None:
if self._connection_in is None:
self._connection_in = ind_node
else:
Expand All @@ -114,11 +122,11 @@ def set_connection(self, ind_node):
self._connection_in = None
self.deselect_all()

def deselect_all(self):
def deselect_all(self) -> None:
[o.set_selected(False) for o in self.objects_to_draw if o.selected]
self.redraw()

def handle_mouse_down(self, x, y):
def handle_mouse_down(self, x: Number, y: Number):
sel_object = self.get_element_at_xy(x, y)
self._selected_object = sel_object
if sel_object is not None:
Expand All @@ -145,30 +153,30 @@ def handle_mouse_down(self, x, y):
self._y0_mouse = y
self.redraw()

def _handle_node_select(self, sel_object):
def _handle_node_select(self, sel_object: NodeWidget) -> None:
self._node_widget = NodeWidgets(sel_object.node, self.gui).draw()
with self.gui.out_status:
self.gui.out_status.clear_output()
display(self._node_widget)
display(self._node_widget) # PyCharm nit is invalid, display takes *args is why it claims to want a tuple

def _handle_port_select(self, sel_object):
def _handle_port_select(self, sel_object: PortWidget) -> None:
if self._last_selected_port is None:
self._last_selected_port = sel_object.port
else:
self.flow.connect_nodes(self._last_selected_port, sel_object.port)
self._last_selected_port = None
self.deselect_all()

def get_element_at_xy(self, x_in, y_in):
def get_element_at_xy(self, x_in: Number, y_in: Number) -> Union[BaseCanvasWidget, None]:
for o in self.objects_to_draw:
if o.is_selected(x_in, y_in):
return o.get_element_at_xy(x_in, y_in)
return None

def get_selected_objects(self):
def get_selected_objects(self) -> List[BaseCanvasWidget]:
return [o for o in self.objects_to_draw if o.selected]

def handle_mouse_move(self, x, y):
def handle_mouse_move(self, x: Number, y: Number) -> None:
if self.gui.mode_dropdown.value == mode_move:
# dx = x - self._x0_mouse
# dy = y - self._y0_mouse
Expand All @@ -180,15 +188,15 @@ def handle_mouse_move(self, x, y):
[o.set_x_y(x, y) for o in self.objects_to_draw if o.selected]
self.redraw()

def redraw(self):
def redraw(self) -> None:
self.canvas_restart()
with hold_canvas(self._canvas):
self.canvas_restart()
[o.draw() for o in self.objects_to_draw]
for c in self.flow.connections:
self.draw_connection(c.inp, c.out)

def load_node(self, x, y, node):
def load_node(self, x: Number, y: Number, node: Node) -> NodeWidget:
# print ('node: ', node.identifier, node.GLOBAL_ID)

layout = CanvasLayout(
Expand All @@ -213,21 +221,21 @@ def load_node(self, x, y, node):
self.objects_to_draw.append(s)
return s

def add_node(self, x, y, node):
def add_node(self, x: Number, y: Number, node: Node):
n = self.flow.create_node(node)
print("node: ", n.identifier, n.GLOBAL_ID)
self.load_node(x, y, n)

self.redraw()

def delete_selected(self):
def delete_selected(self) -> None:
for o in self.objects_to_draw:
if o.selected:
self.objects_to_draw.remove(o)
self._remove_node_from_flow(o.node)
self.redraw()

def _remove_node_from_flow(self, node):
def _remove_node_from_flow(self, node: Node) -> None:
for c in self.flow.connections[::-1]: # Reverse to make sure we traverse whole thing even if we delete
# TODO: Can we be more efficient than looping over all nodes?
if (c.inp.node == node) or (c.out.node == node):
Expand Down
30 changes: 17 additions & 13 deletions ryven/ironflow/Gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import ryven.NENV as NENV
from pathlib import Path

from typing import Optional, Dict

__author__ = "Joerg Neugebauer"
__copyright__ = (
"Copyright 2020, Max-Planck-Institut für Eisenforschung GmbH - "
Expand All @@ -35,7 +37,7 @@


class GUI(HasSession):
def __init__(self, script_title="test", session=None): # , onto_dic=onto_dic):
def __init__(self, script_title: str = "test", session: Optional[rc.Session] = None): # , onto_dic=onto_dic):
super().__init__(session=rc.Session() if session is None else session)
self._script_title = script_title
self.session.create_script(title=self.script_title)
Expand All @@ -62,13 +64,13 @@ def __init__(self, script_title="test", session=None): # , onto_dic=onto_dic):
def script_title(self) -> str:
return self._script_title

def save(self, file_path):
def save(self, file_path: str) -> None:
data = self.serialize()

with open(file_path, "w") as f:
f.write(json.dumps(data, indent=4))

def serialize(self):
def serialize(self) -> Dict:
data = self.session.serialize()
i_script = 0
all_data = data["scripts"][i_script]["flow"]["nodes"]
Expand All @@ -77,13 +79,13 @@ def serialize(self):
all_data[i]["pos y"] = node_widget.y
return data

def load(self, file_path):
def load(self, file_path: str) -> None:
with open(file_path, "r") as f:
data = json.loads(f.read())

self.load_from_data(data)

def load_from_data(self, data):
def load_from_data(self, data: Dict) -> None:
i_script = 0
self.session.delete_script(self.script)
self.session.load(data)
Expand All @@ -105,14 +107,14 @@ def load_from_data(self, data):
self.out_plot.clear_output()
self.out_log.clear_output()

def _print(self, text):
def _print(self, text: str) -> None:
with self.out_log:
self.gui.out_log.clear_output()

print(text)

@debug_view.capture(clear_output=True)
def draw(self):
def draw(self) -> widgets.VBox:
self.out_plot = widgets.Output(
layout={"width": "50%", "border": "1px solid black"}
)
Expand Down Expand Up @@ -198,20 +200,22 @@ def draw(self):
]
)

def on_file_save(self, change):
# Type hinting for unused `change` argument in callbacks taken from ipywidgets docs:
# https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Events.html#Traitlet-events
def on_file_save(self, change: Dict) -> None:
self.save(f"{self.script_title}.json")

def on_file_load(self, change):
def on_file_load(self, change: Dict) -> None:
self.load(f"{self.script_title}.json")

def on_delete_node(self, change):
def on_delete_node(self, change: Dict) -> None:
self.canvas_widget.delete_selected()

def on_value_change(self, change):
def on_value_change(self, change: Dict) -> None:
self.node_selector.options = self._nodes_dict[self.modules_dropdown.value].keys()

def on_nodes_change(self, change):
def on_nodes_change(self, change: Dict) -> None:
self._selected_node = self._nodes_dict[self.modules_dropdown.value][self.node_selector.value]

def on_alg_mode_change(self, change):
def on_alg_mode_change(self, change: Dict) -> None:
self.canvas_widget.script.flow.set_algorithm_mode(self.alg_mode_dropdown.value)
Loading