Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added naming convention check #1969

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/test-backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ jobs:
env:
TYPE_CHECK_LEVEL: 'error'

backend-name-check-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.9'
- run: python ./backend/src/run.py --close-after-start
env:
NAME_CHECK_LEVEL: 'error'

backend-bootstrap:
runs-on: ubuntu-latest
strategy:
Expand Down
61 changes: 38 additions & 23 deletions backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@

from base_types import InputId, OutputId
from custom_types import NodeType, RunFn
from node_check import (
NAME_CHECK_LEVEL,
TYPE_CHECK_LEVEL,
CheckFailedError,
CheckLevel,
check_naming_conventions,
check_schema_types,
)
from nodes.base_input import BaseInput
from nodes.base_output import BaseOutput
from nodes.group import Group, GroupId, NestedGroup, NestedIdGroup
from type_checking import (
TypeCheckLevel,
TypeMismatchError,
get_type_check_level,
typeValidateSchema,
)

KB = 1024**1
MB = 1024**2
Expand Down Expand Up @@ -121,22 +123,35 @@ def register(
if isinstance(see_also, str):
see_also = [see_also]

def run_check(level: CheckLevel, run: Callable[[bool], None]):
if level == CheckLevel.NONE:
return

try:
run(level == CheckLevel.FIX)
except CheckFailedError as e:
full_error_message = f"Error in {schema_id}: {e}"
if level == CheckLevel.ERROR:
# pylint: disable=raise-missing-from
raise CheckFailedError(full_error_message)
logger.warning(full_error_message)

def inner_wrapper(wrapped_func: T) -> T:
p_inputs, group_layout = _process_inputs(inputs)
p_outputs = _process_outputs(outputs)

TYPE_CHECK_LEVEL = get_type_check_level()

if TYPE_CHECK_LEVEL != TypeCheckLevel.NONE:
try:
typeValidateSchema(wrapped_func, node_type, p_inputs, p_outputs)
except TypeMismatchError as e:
full_error_message = f"Error in {schema_id}: {e}"
if TYPE_CHECK_LEVEL == TypeCheckLevel.WARN:
logger.warning(full_error_message)
elif TYPE_CHECK_LEVEL == TypeCheckLevel.ERROR:
# pylint: disable=raise-missing-from
raise TypeMismatchError(full_error_message)
run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(
wrapped_func, node_type, p_inputs, p_outputs
),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(
wrapped_func, node_type, name, fix
),
)

if decorators is not None:
for decorator in decorators:
Expand Down Expand Up @@ -268,7 +283,7 @@ def add(self, package: Package) -> Package:

def load_nodes(self, current_file: str):
import_errors: List[ImportError] = []
type_errors: List[TypeMismatchError] = []
failed_checks: List[CheckFailedError] = []

for package in list(self.packages.values()):
for file_path in _iter_py_files(os.path.dirname(package.where)):
Expand All @@ -285,12 +300,12 @@ def load_nodes(self, current_file: str):
logger.warning(f"Failed to load {module} ({file_path}): {e}")
except ValueError as e:
logger.warning(f"Failed to load {module} ({file_path}): {e}")
except TypeMismatchError as e:
except CheckFailedError as e:
logger.error(e)
type_errors.append(e)
failed_checks.append(e)

if len(type_errors) > 0:
raise RuntimeError(f"Type errors occurred in {len(type_errors)} node(s)")
if len(failed_checks) > 0:
raise RuntimeError(f"Checks failed in {len(failed_checks)} node(s)")

logger.info(import_errors)
self._refresh_nodes()
Expand Down
114 changes: 87 additions & 27 deletions backend/src/type_checking.py → backend/src/node_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import inspect
import os
import pathlib
from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Set, Union, cast, get_args

Expand All @@ -13,28 +14,42 @@
_Ty = NewType("_Ty", object)


class TypeMismatchError(Exception):
class CheckFailedError(Exception):
pass


# Enum for type check level
class TypeCheckLevel(Enum):
class CheckLevel(Enum):
NONE = "none"
WARN = "warn"
FIX = "fix"
ERROR = "error"

@staticmethod
def parse(s: str) -> CheckLevel:
s = s.strip().lower()
if s == CheckLevel.NONE.value:
return CheckLevel.NONE
elif s == CheckLevel.WARN.value:
return CheckLevel.WARN
elif s == CheckLevel.FIX.value:
return CheckLevel.FIX
elif s == CheckLevel.ERROR.value:
return CheckLevel.ERROR
else:
raise ValueError(f"Invalid check level: {s}")


def _get_check_level(name: str, default: CheckLevel) -> CheckLevel:
try:
s = os.environ.get(name, default.value)
return CheckLevel.parse(s)
except:
return default

# If it's stupid but it works, it's not stupid
def get_type_check_level() -> TypeCheckLevel:
type_check_level = os.environ.get("TYPE_CHECK_LEVEL", TypeCheckLevel.NONE.value)
if type_check_level.lower() == TypeCheckLevel.NONE.value:
return TypeCheckLevel.NONE
elif type_check_level.lower() == TypeCheckLevel.WARN.value:
return TypeCheckLevel.WARN
elif type_check_level.lower() == TypeCheckLevel.ERROR.value:
return TypeCheckLevel.ERROR
else:
return TypeCheckLevel.NONE

CHECK_LEVEL = _get_check_level("CHECK_LEVEL", CheckLevel.NONE)
NAME_CHECK_LEVEL = _get_check_level("NAME_CHECK_LEVEL", CHECK_LEVEL)
TYPE_CHECK_LEVEL = _get_check_level("TYPE_CHECK_LEVEL", CHECK_LEVEL)


class TypeTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -120,39 +135,39 @@ def get_type_annotations(fn: Callable) -> Dict[str, _Ty]:
def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]):
if len(outputs) == 0:
if return_type is not None: # type: ignore
raise TypeMismatchError(
raise CheckFailedError(
f"Return type should be 'None' because there are no outputs"
)
elif len(outputs) == 1:
o = outputs[0]
if o.associated_type is not None and not is_subset_of(
return_type, o.associated_type
):
raise TypeMismatchError(
raise CheckFailedError(
f"Return type '{return_type}' must be a subset of '{o.associated_type}'"
)
else:
if not str(return_type).startswith("typing.Tuple["):
raise TypeMismatchError(
raise CheckFailedError(
f"Return type '{return_type}' must be a tuple because there are multiple outputs"
)

return_args = get_args(return_type)
if len(return_args) != len(outputs):
raise TypeMismatchError(
raise CheckFailedError(
f"Return type '{return_type}' must have the same number of arguments as there are outputs"
)

for o, return_arg in zip(outputs, return_args):
if o.associated_type is not None and not is_subset_of(
return_arg, o.associated_type
):
raise TypeMismatchError(
raise CheckFailedError(
f"Return type of {o.label} '{return_arg}' must be a subset of '{o.associated_type}'"
)


def typeValidateSchema(
def check_schema_types(
wrapped_func: Callable,
node_type: NodeType,
inputs: list[BaseInput],
Expand All @@ -173,7 +188,7 @@ def typeValidateSchema(
arg_spec = inspect.getfullargspec(wrapped_func)
for arg in arg_spec.args:
if not arg in ann:
raise TypeMismatchError(f"Missing type annotation for '{arg}'")
raise CheckFailedError(f"Missing type annotation for '{arg}'")

if node_type == "iteratorHelper":
# iterator helpers have inputs that do not describe the arguments of the function, so we can't check them
Expand All @@ -184,13 +199,13 @@ def typeValidateSchema(
context = [*ann.keys()][-1]
context_type = ann.pop(context)
if str(context_type) != "<class 'process.IteratorContext'>":
raise TypeMismatchError(
raise CheckFailedError(
f"Last argument of an iterator must be an IteratorContext, not '{context_type}'"
)

if arg_spec.varargs is not None:
if not arg_spec.varargs in ann:
raise TypeMismatchError(f"Missing type annotation for '{arg_spec.varargs}'")
raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'")
va_type = ann.pop(arg_spec.varargs)

# split inputs by varargs and non-varargs
Expand All @@ -203,7 +218,7 @@ def typeValidateSchema(

if associated_type is not None:
if not is_subset_of(associated_type, va_type):
raise TypeMismatchError(
raise CheckFailedError(
f"Input type of {i.label} '{associated_type}' is not assignable to varargs type '{va_type}'"
)

Expand All @@ -217,17 +232,62 @@ def typeValidateSchema(
if total is not None:
total_type = union_types(total)
if total_type != va_type:
raise TypeMismatchError(
raise CheckFailedError(
f"Varargs type '{va_type}' should be equal to the union of all arguments '{total_type}'"
)

if len(ann) != len(inputs):
raise TypeMismatchError(
raise CheckFailedError(
f"Number of inputs and arguments don't match: {len(ann)=} != {len(inputs)=}"
)
for (a_name, a_type), i in zip(ann.items(), inputs):
associated_type = i.associated_type
if associated_type is not None and a_type != associated_type:
raise TypeMismatchError(
raise CheckFailedError(
f"Expected type of {i.label} ({a_name}) to be '{associated_type}' but found '{a_type}'"
)


def check_naming_conventions(
wrapped_func: Callable,
node_type: NodeType,
name: str,
fix: bool,
):
expected_name = (
name.lower()
.replace(" (iterator)", "")
.replace(" ", "_")
.replace("-", "_")
.replace("(", "")
.replace(")", "")
.replace("&", "and")
)

if node_type == "iteratorHelper":
expected_name = "iterator_helper_" + expected_name

func_name = wrapped_func.__name__
file_path = pathlib.Path(inspect.getfile(wrapped_func))
file_name = file_path.stem

# check function name
if func_name != expected_name + "_node":
if not fix:
raise CheckFailedError(
f"Function name is '{func_name}', but it should be '{expected_name}_node'"
)

fixed_code = file_path.read_text(encoding="utf-8").replace(
f"def {func_name}(", f"def {expected_name}_node("
)
file_path.write_text(fixed_code, encoding="utf-8")

# check file name
if node_type != "iteratorHelper" and file_name != expected_name:
if not fix:
raise CheckFailedError(
f"File name is '{file_name}.py', but it should be '{expected_name}.py'"
)

os.rename(file_path, file_path.with_name(expected_name + ".py"))
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"scripts": {
"start": "electron-forge start -- --devtools",
"frontend": "electron-forge start -- --remote-backend=http://127.0.0.1:8000 --devtools",
"dev": "concurrently \"cd backend/src && cross-env TYPE_CHECK_LEVEL=warn nodemon ./run.py 8000\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"",
"dev": "concurrently \"cd backend/src && cross-env CHECK_LEVEL=fix nodemon ./run.py 8000\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"",
"debug": "concurrently \"npm run debug:py\" \"electron-forge start -- --remote-backend=http://127.0.0.1:8000 --refresh --devtools\"",
"debug:py": "cd backend/src && nodemon --exec \"python -m debugpy --listen 5678\" ./run.py 8000",
"package": "cross-env NODE_ENV=production electron-forge package",
Expand Down