diff --git a/backend/src/api.py b/backend/src/api.py index 7d528b29f..98649c061 100644 --- a/backend/src/api.py +++ b/backend/src/api.py @@ -95,7 +95,6 @@ class NodeData: side_effects: bool deprecated: bool - default_nodes: List[DefaultNode] | None # For iterators only features: List[FeatureId] run: RunFn @@ -126,7 +125,6 @@ def register( node_type: NodeType = "regularNode", side_effects: bool = False, deprecated: bool = False, - default_nodes: List[DefaultNode] | None = None, decorators: List[Callable] | None = None, see_also: List[str] | str | None = None, features: List[FeatureId] | FeatureId | None = None, @@ -197,7 +195,6 @@ def inner_wrapper(wrapped_func: T) -> T: outputs=p_outputs, side_effects=side_effects, deprecated=deprecated, - default_nodes=default_nodes, features=features, run=wrapped_func, ) diff --git a/backend/src/process.py b/backend/src/process.py index 73376e016..9cbc05bc4 100644 --- a/backend/src/process.py +++ b/backend/src/process.py @@ -5,9 +5,8 @@ import gc import time import uuid -from collections.abc import Awaitable from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import Dict, Iterable, List, Optional, Set from sanic.log import logger @@ -19,6 +18,7 @@ from events import Event, EventQueue, InputsDict from nodes.base_output import BaseOutput from progress_controller import Aborted, ProgressController +from util import timed_supplier Output = List[object] @@ -144,9 +144,6 @@ def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]): return data, types -T = TypeVar("T") - - class NodeExecutionError(Exception): def __init__( self, @@ -161,23 +158,6 @@ def __init__( self.inputs: InputsDict = inputs -def timed_supplier(supplier: Callable[[], T]) -> Callable[[], Tuple[T, float]]: - def wrapper(): - start = time.time() - result = supplier() - duration = time.time() - start - return result, duration - - return wrapper - - -async def timed_supplier_async(supplier: Callable[[], Awaitable[T]]) -> Tuple[T, float]: - start = time.time() - result = await supplier() - duration = time.time() - start - return result, duration - - class Executor: """ Class for executing chaiNNer's processing logic @@ -412,7 +392,7 @@ def __get_upstream_nodes(self, node: NodeId) -> Set[NodeId]: upstream_nodes.extend(self.__get_upstream_nodes(upstream_node)) return set(upstream_nodes) - async def __process_nodes(self, nodes: List[NodeId]): + async def __process_nodes(self): await self.progress.suspend() iterator_node_set = set() @@ -448,7 +428,7 @@ async def __process_nodes(self, nodes: List[NodeId]): before_iteration_time = time.time() # Now run each of the iterators - for iterator_node in nodes: + for iterator_node in self.__get_iterator_nodes(): # Get all downstream nodes of the iterator # This excludes any nodes that are downstream of a collector, as well as collectors themselves downstream_nodes = [ @@ -576,7 +556,7 @@ async def __process_nodes(self, nodes: List[NodeId]): async def run(self): logger.debug(f"Running executor {self.execution_id}") try: - await self.__process_nodes(self.__get_iterator_nodes()) + await self.__process_nodes() finally: gc.collect() diff --git a/backend/src/server.py b/backend/src/server.py index 775a7ca4d..35b57177c 100644 --- a/backend/src/server.py +++ b/backend/src/server.py @@ -33,14 +33,7 @@ JsonExecutionOptions, set_execution_options, ) -from process import ( - Executor, - NodeExecutionError, - Output, - compute_broadcast, - run_node, - timed_supplier, -) +from process import Executor, NodeExecutionError, Output, compute_broadcast, run_node from progress_controller import Aborted from response import ( alreadyRunningResponse, @@ -50,6 +43,7 @@ ) from server_config import ServerConfig from system import is_arm_mac +from util import timed_supplier class AppContext: @@ -132,7 +126,6 @@ async def nodes(_request: Request): "nodeType": node.type, "hasSideEffects": node.side_effects, "deprecated": node.deprecated, - "defaultNodes": node.default_nodes, "features": node.features, } node_list.append(node_dict) diff --git a/backend/src/util.py b/backend/src/util.py new file mode 100644 index 000000000..4c7d3a9aa --- /dev/null +++ b/backend/src/util.py @@ -0,0 +1,14 @@ +import time +from typing import Callable, Tuple, TypeVar + +T = TypeVar("T") + + +def timed_supplier(supplier: Callable[[], T]) -> Callable[[], Tuple[T, float]]: + def wrapper(): + start = time.time() + result = supplier() + duration = time.time() - start + return result, duration + + return wrapper diff --git a/src/common/common-types.ts b/src/common/common-types.ts index 57d442f96..38fd9fdec 100644 --- a/src/common/common-types.ts +++ b/src/common/common-types.ts @@ -253,7 +253,6 @@ export interface NodeSchema { readonly inputs: readonly Input[]; readonly outputs: readonly Output[]; readonly groupLayout: readonly (InputId | Group)[]; - readonly defaultNodes?: readonly DefaultNode[] | null; readonly schemaId: SchemaId; readonly hasSideEffects: boolean; readonly deprecated: boolean; diff --git a/src/renderer/components/NodeDocumentation/NodeDocs.tsx b/src/renderer/components/NodeDocumentation/NodeDocs.tsx index eb47feca7..16a40889b 100644 --- a/src/renderer/components/NodeDocumentation/NodeDocs.tsx +++ b/src/renderer/components/NodeDocumentation/NodeDocs.tsx @@ -376,16 +376,13 @@ interface NodeDocsProps { schema: NodeSchema; } export const NodeDocs = memo(({ schema }: NodeDocsProps) => { - const { schemata, functionDefinitions, categories } = useContext(BackendContext); + const { functionDefinitions, categories } = useContext(BackendContext); const selectedAccentColor = getCategoryAccentColor(categories, schema.category); const [isLargerThan1200] = useMediaQuery('(min-width: 1200px)'); - const nodeDocsToShow = [ - schema, - ...(schema.defaultNodes?.map((n) => schemata.get(n.schemaId)) ?? []), - ]; + const nodeFunctionDefinition = functionDefinitions.get(schema.schemaId); return ( { textAlign="left" w="full" > - {nodeDocsToShow.map((nodeSchema) => { - const nodeFunctionDefinition = functionDefinitions.get( - nodeSchema.schemaId - ); - return ( - + +
+ - -
- - - -
- - ); - })} +
+
+
diff --git a/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx b/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx index 4f22f9bf7..3b4c51b01 100644 --- a/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx +++ b/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx @@ -59,16 +59,6 @@ const NodeDocumentationModal = memo(() => { const selectedSchema = schemata.get(selectedSchemaId); // search - const helperNodeMapping = useMemo(() => { - const mapping = new Map(); - for (const schema of schemata.schemata) { - for (const helper of schema.defaultNodes ?? []) { - mapping.set(helper.schemaId, schema.schemaId); - } - } - return mapping; - }, [schemata]); - const searchIndex = useMemo(() => createSearchIndex(schemata.schemata), [schemata.schemata]); const [searchQuery, setSearchQuery] = useState(''); const { searchScores, searchTerms } = useMemo(() => { @@ -100,16 +90,14 @@ const NodeDocumentationModal = memo(() => { useEffect(() => { if (searchScores && searchScores.size > 0) { const highestScore = Math.max(...searchScores.values()); - let highestScoreSchemaId = [...searchScores.entries()].find( + const highestScoreSchemaId = [...searchScores.entries()].find( ([, score]) => score === highestScore )?.[0]; if (highestScoreSchemaId) { - highestScoreSchemaId = - helperNodeMapping.get(highestScoreSchemaId) ?? highestScoreSchemaId; openNodeDocumentation(highestScoreSchemaId); } } - }, [searchScores, helperNodeMapping, openNodeDocumentation]); + }, [searchScores, openNodeDocumentation]); return (