From 3174472b1692cd743463cc3c87bf2da0b4b1032f Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Sun, 23 Jul 2023 14:59:57 +0200 Subject: [PATCH 1/3] Added naming convention check --- .github/workflows/test-backend.yml | 11 ++ backend/src/api.py | 61 ++++++---- .../src/{type_checking.py => node_check.py} | 114 +++++++++++++----- package.json | 2 +- 4 files changed, 137 insertions(+), 51 deletions(-) rename backend/src/{type_checking.py => node_check.py} (69%) diff --git a/.github/workflows/test-backend.yml b/.github/workflows/test-backend.yml index f3e4ca88f..28a502a0d 100644 --- a/.github/workflows/test-backend.yml +++ b/.github/workflows/test-backend.yml @@ -46,6 +46,17 @@ jobs: env: TYPE_CHECK_LEVEL: 'error' + backend-name-check-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + - run: python ./backend/src/run.py --close-after-start + env: + NAME_CHECK_LEVEL: 'error' + backend-bootstrap: runs-on: ubuntu-latest strategy: diff --git a/backend/src/api.py b/backend/src/api.py index a84f45785..d6e99969a 100644 --- a/backend/src/api.py +++ b/backend/src/api.py @@ -9,15 +9,17 @@ from base_types import InputId, OutputId from custom_types import NodeType, RunFn +from node_check import ( + NAME_CHECK_LEVEL, + TYPE_CHECK_LEVEL, + CheckFailedError, + CheckLevel, + check_naming_conventions, + check_schema_types, +) from nodes.base_input import BaseInput from nodes.base_output import BaseOutput from nodes.group import Group, GroupId, NestedGroup, NestedIdGroup -from type_checking import ( - TypeCheckLevel, - TypeMismatchError, - get_type_check_level, - typeValidateSchema, -) KB = 1024**1 MB = 1024**2 @@ -121,22 +123,35 @@ def register( if isinstance(see_also, str): see_also = [see_also] + def run_check(level: CheckLevel, run: Callable[[bool], None]): + if level == CheckLevel.NONE: + return + + try: + run(level == CheckLevel.FIX) + except CheckFailedError as e: + full_error_message = f"Error in {schema_id}: {e}" + if level == CheckLevel.ERROR: + # pylint: disable=raise-missing-from + raise CheckFailedError(full_error_message) + logger.warning(full_error_message) + def inner_wrapper(wrapped_func: T) -> T: p_inputs, group_layout = _process_inputs(inputs) p_outputs = _process_outputs(outputs) - TYPE_CHECK_LEVEL = get_type_check_level() - - if TYPE_CHECK_LEVEL != TypeCheckLevel.NONE: - try: - typeValidateSchema(wrapped_func, node_type, p_inputs, p_outputs) - except TypeMismatchError as e: - full_error_message = f"Error in {schema_id}: {e}" - if TYPE_CHECK_LEVEL == TypeCheckLevel.WARN: - logger.warning(full_error_message) - elif TYPE_CHECK_LEVEL == TypeCheckLevel.ERROR: - # pylint: disable=raise-missing-from - raise TypeMismatchError(full_error_message) + run_check( + TYPE_CHECK_LEVEL, + lambda _: check_schema_types( + wrapped_func, node_type, p_inputs, p_outputs + ), + ) + run_check( + NAME_CHECK_LEVEL, + lambda fix: check_naming_conventions( + wrapped_func, node_type, name, fix + ), + ) if decorators is not None: for decorator in decorators: @@ -268,7 +283,7 @@ def add(self, package: Package) -> Package: def load_nodes(self, current_file: str): import_errors: List[ImportError] = [] - type_errors: List[TypeMismatchError] = [] + failed_checks: List[CheckFailedError] = [] for package in list(self.packages.values()): for file_path in _iter_py_files(os.path.dirname(package.where)): @@ -285,12 +300,12 @@ def load_nodes(self, current_file: str): logger.warning(f"Failed to load {module} ({file_path}): {e}") except ValueError as e: logger.warning(f"Failed to load {module} ({file_path}): {e}") - except TypeMismatchError as e: + except CheckFailedError as e: logger.error(e) - type_errors.append(e) + failed_checks.append(e) - if len(type_errors) > 0: - raise RuntimeError(f"Type errors occurred in {len(type_errors)} node(s)") + if len(failed_checks) > 0: + raise RuntimeError(f"Checks failed in {len(failed_checks)} node(s)") logger.info(import_errors) self._refresh_nodes() diff --git a/backend/src/type_checking.py b/backend/src/node_check.py similarity index 69% rename from backend/src/type_checking.py rename to backend/src/node_check.py index f8b0a045c..27944427e 100644 --- a/backend/src/type_checking.py +++ b/backend/src/node_check.py @@ -3,6 +3,7 @@ import ast import inspect import os +import pathlib from enum import Enum from typing import Any, Callable, Dict, List, NewType, Set, Union, cast, get_args @@ -13,28 +14,42 @@ _Ty = NewType("_Ty", object) -class TypeMismatchError(Exception): +class CheckFailedError(Exception): pass -# Enum for type check level -class TypeCheckLevel(Enum): +class CheckLevel(Enum): NONE = "none" WARN = "warn" + FIX = "fix" ERROR = "error" + @staticmethod + def parse(s: str) -> CheckLevel: + s = s.strip().lower() + if s == CheckLevel.NONE.value: + return CheckLevel.NONE + elif s == CheckLevel.WARN.value: + return CheckLevel.WARN + elif s == CheckLevel.FIX.value: + return CheckLevel.FIX + elif s == CheckLevel.ERROR.value: + return CheckLevel.ERROR + else: + raise ValueError(f"Invalid check level: {s}") + + +def _get_check_level(name: str, default: CheckLevel) -> CheckLevel: + try: + s = os.environ.get(name, default.value) + return CheckLevel.parse(s) + except: + return default -# If it's stupid but it works, it's not stupid -def get_type_check_level() -> TypeCheckLevel: - type_check_level = os.environ.get("TYPE_CHECK_LEVEL", TypeCheckLevel.NONE.value) - if type_check_level.lower() == TypeCheckLevel.NONE.value: - return TypeCheckLevel.NONE - elif type_check_level.lower() == TypeCheckLevel.WARN.value: - return TypeCheckLevel.WARN - elif type_check_level.lower() == TypeCheckLevel.ERROR.value: - return TypeCheckLevel.ERROR - else: - return TypeCheckLevel.NONE + +CHECK_LEVEL = _get_check_level("CHECK_LEVEL", CheckLevel.NONE) +NAME_CHECK_LEVEL = _get_check_level("NAME_CHECK_LEVEL", CHECK_LEVEL) +TYPE_CHECK_LEVEL = _get_check_level("TYPE_CHECK_LEVEL", CHECK_LEVEL) class TypeTransformer(ast.NodeTransformer): @@ -120,7 +135,7 @@ def get_type_annotations(fn: Callable) -> Dict[str, _Ty]: def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]): if len(outputs) == 0: if return_type is not None: # type: ignore - raise TypeMismatchError( + raise CheckFailedError( f"Return type should be 'None' because there are no outputs" ) elif len(outputs) == 1: @@ -128,18 +143,18 @@ def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]): if o.associated_type is not None and not is_subset_of( return_type, o.associated_type ): - raise TypeMismatchError( + raise CheckFailedError( f"Return type '{return_type}' must be a subset of '{o.associated_type}'" ) else: if not str(return_type).startswith("typing.Tuple["): - raise TypeMismatchError( + raise CheckFailedError( f"Return type '{return_type}' must be a tuple because there are multiple outputs" ) return_args = get_args(return_type) if len(return_args) != len(outputs): - raise TypeMismatchError( + raise CheckFailedError( f"Return type '{return_type}' must have the same number of arguments as there are outputs" ) @@ -147,12 +162,12 @@ def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]): if o.associated_type is not None and not is_subset_of( return_arg, o.associated_type ): - raise TypeMismatchError( + raise CheckFailedError( f"Return type of {o.label} '{return_arg}' must be a subset of '{o.associated_type}'" ) -def typeValidateSchema( +def check_schema_types( wrapped_func: Callable, node_type: NodeType, inputs: list[BaseInput], @@ -173,7 +188,7 @@ def typeValidateSchema( arg_spec = inspect.getfullargspec(wrapped_func) for arg in arg_spec.args: if not arg in ann: - raise TypeMismatchError(f"Missing type annotation for '{arg}'") + raise CheckFailedError(f"Missing type annotation for '{arg}'") if node_type == "iteratorHelper": # iterator helpers have inputs that do not describe the arguments of the function, so we can't check them @@ -184,13 +199,13 @@ def typeValidateSchema( context = [*ann.keys()][-1] context_type = ann.pop(context) if str(context_type) != "": - raise TypeMismatchError( + raise CheckFailedError( f"Last argument of an iterator must be an IteratorContext, not '{context_type}'" ) if arg_spec.varargs is not None: if not arg_spec.varargs in ann: - raise TypeMismatchError(f"Missing type annotation for '{arg_spec.varargs}'") + raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'") va_type = ann.pop(arg_spec.varargs) # split inputs by varargs and non-varargs @@ -203,7 +218,7 @@ def typeValidateSchema( if associated_type is not None: if not is_subset_of(associated_type, va_type): - raise TypeMismatchError( + raise CheckFailedError( f"Input type of {i.label} '{associated_type}' is not assignable to varargs type '{va_type}'" ) @@ -217,17 +232,62 @@ def typeValidateSchema( if total is not None: total_type = union_types(total) if total_type != va_type: - raise TypeMismatchError( + raise CheckFailedError( f"Varargs type '{va_type}' should be equal to the union of all arguments '{total_type}'" ) if len(ann) != len(inputs): - raise TypeMismatchError( + raise CheckFailedError( f"Number of inputs and arguments don't match: {len(ann)=} != {len(inputs)=}" ) for (a_name, a_type), i in zip(ann.items(), inputs): associated_type = i.associated_type if associated_type is not None and a_type != associated_type: - raise TypeMismatchError( + raise CheckFailedError( f"Expected type of {i.label} ({a_name}) to be '{associated_type}' but found '{a_type}'" ) + + +def check_naming_conventions( + wrapped_func: Callable, + node_type: NodeType, + name: str, + fix: bool, +): + expected_name = ( + name.lower() + .replace(" (iterator)", "") + .replace(" ", "_") + .replace("-", "_") + .replace("(", "") + .replace(")", "") + .replace("&", "and") + ) + + if node_type == "iteratorHelper": + expected_name = "iterator_helper_" + expected_name + + func_name = wrapped_func.__name__ + file_path = pathlib.Path(inspect.getfile(wrapped_func)) + file_name = file_path.stem + + # check function name + if func_name != expected_name + "_node": + if not fix: + raise CheckFailedError( + f"Function name is '{func_name}', but it should be '{expected_name}_node'" + ) + + fixed_code = file_path.read_text(encoding="utf-8").replace( + f"def {func_name}(", f"def {expected_name}_node(" + ) + file_path.write_text(fixed_code, encoding="utf-8") + + # check file name + if node_type != "iteratorHelper" and file_name != expected_name: + if not fix: + raise CheckFailedError( + f"File name is '{file_name}.py', but it should be '{expected_name}.py'" + ) + + os.rename(file_path, file_path.with_name(expected_name + ".py")) diff --git a/package.json b/package.json index 2cd85e6aa..e6eea3d2f 100644 --- a/package.json +++ b/package.json @@ -7,7 +7,7 @@ "scripts": { "start": "electron-forge start -- --devtools", "frontend": "electron-forge start -- --remote-backend=http://127.0.0.1:8000 --devtools", - "dev": "concurrently \"cd backend/src && cross-env TYPE_CHECK_LEVEL=warn nodemon ./run.py 8000\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"", + "dev": "concurrently \"cd backend/src && cross-env CHECK_LEVEL=fix nodemon ./run.py 8000\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"", "debug": "concurrently \"npm run debug:py\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"", "debug:py": "cd backend/src && nodemon --exec \"python -m debugpy --listen 5678\" ./run.py 8000", "package": "cross-env NODE_ENV=production electron-forge package", From 66a53f8b735da4c269a3e05c690f3e58de1ceed7 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Sun, 23 Jul 2023 15:17:37 +0200 Subject: [PATCH 2/3] Invalid name --- .../chaiNNer_standard/image_utility/modification/shift.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py b/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py index 59480dc75..9a02c8303 100644 --- a/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py +++ b/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py @@ -32,7 +32,7 @@ ) ], ) -def shift_node( +def shift_amount_node( img: np.ndarray, amount_x: int, amount_y: int, From ba482ca2b326d14ceb39eaf023139474fc72fdc5 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Sun, 23 Jul 2023 15:19:54 +0200 Subject: [PATCH 3/3] Fixed node name --- .../chaiNNer_standard/image_utility/modification/shift.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py b/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py index 9a02c8303..59480dc75 100644 --- a/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py +++ b/backend/src/packages/chaiNNer_standard/image_utility/modification/shift.py @@ -32,7 +32,7 @@ ) ], ) -def shift_amount_node( +def shift_node( img: np.ndarray, amount_x: int, amount_y: int,