diff --git a/backend/src/api.py b/backend/src/api.py
index 7d528b29f..98649c061 100644
--- a/backend/src/api.py
+++ b/backend/src/api.py
@@ -95,7 +95,6 @@ class NodeData:
side_effects: bool
deprecated: bool
- default_nodes: List[DefaultNode] | None # For iterators only
features: List[FeatureId]
run: RunFn
@@ -126,7 +125,6 @@ def register(
node_type: NodeType = "regularNode",
side_effects: bool = False,
deprecated: bool = False,
- default_nodes: List[DefaultNode] | None = None,
decorators: List[Callable] | None = None,
see_also: List[str] | str | None = None,
features: List[FeatureId] | FeatureId | None = None,
@@ -197,7 +195,6 @@ def inner_wrapper(wrapped_func: T) -> T:
outputs=p_outputs,
side_effects=side_effects,
deprecated=deprecated,
- default_nodes=default_nodes,
features=features,
run=wrapped_func,
)
diff --git a/backend/src/process.py b/backend/src/process.py
index 73376e016..9cbc05bc4 100644
--- a/backend/src/process.py
+++ b/backend/src/process.py
@@ -5,9 +5,8 @@
import gc
import time
import uuid
-from collections.abc import Awaitable
from concurrent.futures import ThreadPoolExecutor
-from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import Dict, Iterable, List, Optional, Set
from sanic.log import logger
@@ -19,6 +18,7 @@
from events import Event, EventQueue, InputsDict
from nodes.base_output import BaseOutput
from progress_controller import Aborted, ProgressController
+from util import timed_supplier
Output = List[object]
@@ -144,9 +144,6 @@ def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]):
return data, types
-T = TypeVar("T")
-
-
class NodeExecutionError(Exception):
def __init__(
self,
@@ -161,23 +158,6 @@ def __init__(
self.inputs: InputsDict = inputs
-def timed_supplier(supplier: Callable[[], T]) -> Callable[[], Tuple[T, float]]:
- def wrapper():
- start = time.time()
- result = supplier()
- duration = time.time() - start
- return result, duration
-
- return wrapper
-
-
-async def timed_supplier_async(supplier: Callable[[], Awaitable[T]]) -> Tuple[T, float]:
- start = time.time()
- result = await supplier()
- duration = time.time() - start
- return result, duration
-
-
class Executor:
"""
Class for executing chaiNNer's processing logic
@@ -412,7 +392,7 @@ def __get_upstream_nodes(self, node: NodeId) -> Set[NodeId]:
upstream_nodes.extend(self.__get_upstream_nodes(upstream_node))
return set(upstream_nodes)
- async def __process_nodes(self, nodes: List[NodeId]):
+ async def __process_nodes(self):
await self.progress.suspend()
iterator_node_set = set()
@@ -448,7 +428,7 @@ async def __process_nodes(self, nodes: List[NodeId]):
before_iteration_time = time.time()
# Now run each of the iterators
- for iterator_node in nodes:
+ for iterator_node in self.__get_iterator_nodes():
# Get all downstream nodes of the iterator
# This excludes any nodes that are downstream of a collector, as well as collectors themselves
downstream_nodes = [
@@ -576,7 +556,7 @@ async def __process_nodes(self, nodes: List[NodeId]):
async def run(self):
logger.debug(f"Running executor {self.execution_id}")
try:
- await self.__process_nodes(self.__get_iterator_nodes())
+ await self.__process_nodes()
finally:
gc.collect()
diff --git a/backend/src/server.py b/backend/src/server.py
index 775a7ca4d..35b57177c 100644
--- a/backend/src/server.py
+++ b/backend/src/server.py
@@ -33,14 +33,7 @@
JsonExecutionOptions,
set_execution_options,
)
-from process import (
- Executor,
- NodeExecutionError,
- Output,
- compute_broadcast,
- run_node,
- timed_supplier,
-)
+from process import Executor, NodeExecutionError, Output, compute_broadcast, run_node
from progress_controller import Aborted
from response import (
alreadyRunningResponse,
@@ -50,6 +43,7 @@
)
from server_config import ServerConfig
from system import is_arm_mac
+from util import timed_supplier
class AppContext:
@@ -132,7 +126,6 @@ async def nodes(_request: Request):
"nodeType": node.type,
"hasSideEffects": node.side_effects,
"deprecated": node.deprecated,
- "defaultNodes": node.default_nodes,
"features": node.features,
}
node_list.append(node_dict)
diff --git a/backend/src/util.py b/backend/src/util.py
new file mode 100644
index 000000000..4c7d3a9aa
--- /dev/null
+++ b/backend/src/util.py
@@ -0,0 +1,14 @@
+import time
+from typing import Callable, Tuple, TypeVar
+
+T = TypeVar("T")
+
+
+def timed_supplier(supplier: Callable[[], T]) -> Callable[[], Tuple[T, float]]:
+ def wrapper():
+ start = time.time()
+ result = supplier()
+ duration = time.time() - start
+ return result, duration
+
+ return wrapper
diff --git a/src/common/common-types.ts b/src/common/common-types.ts
index 57d442f96..38fd9fdec 100644
--- a/src/common/common-types.ts
+++ b/src/common/common-types.ts
@@ -253,7 +253,6 @@ export interface NodeSchema {
readonly inputs: readonly Input[];
readonly outputs: readonly Output[];
readonly groupLayout: readonly (InputId | Group)[];
- readonly defaultNodes?: readonly DefaultNode[] | null;
readonly schemaId: SchemaId;
readonly hasSideEffects: boolean;
readonly deprecated: boolean;
diff --git a/src/renderer/components/NodeDocumentation/NodeDocs.tsx b/src/renderer/components/NodeDocumentation/NodeDocs.tsx
index eb47feca7..16a40889b 100644
--- a/src/renderer/components/NodeDocumentation/NodeDocs.tsx
+++ b/src/renderer/components/NodeDocumentation/NodeDocs.tsx
@@ -376,16 +376,13 @@ interface NodeDocsProps {
schema: NodeSchema;
}
export const NodeDocs = memo(({ schema }: NodeDocsProps) => {
- const { schemata, functionDefinitions, categories } = useContext(BackendContext);
+ const { functionDefinitions, categories } = useContext(BackendContext);
const selectedAccentColor = getCategoryAccentColor(categories, schema.category);
const [isLargerThan1200] = useMediaQuery('(min-width: 1200px)');
- const nodeDocsToShow = [
- schema,
- ...(schema.defaultNodes?.map((n) => schemata.get(n.schemaId)) ?? []),
- ];
+ const nodeFunctionDefinition = functionDefinitions.get(schema.schemaId);
return (
{
textAlign="left"
w="full"
>
- {nodeDocsToShow.map((nodeSchema) => {
- const nodeFunctionDefinition = functionDefinitions.get(
- nodeSchema.schemaId
- );
- return (
-
+
+
+
-
-
-
-
-
-
-
- );
- })}
+
+
+
diff --git a/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx b/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx
index 4f22f9bf7..3b4c51b01 100644
--- a/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx
+++ b/src/renderer/components/NodeDocumentation/NodeDocumentationModal.tsx
@@ -59,16 +59,6 @@ const NodeDocumentationModal = memo(() => {
const selectedSchema = schemata.get(selectedSchemaId);
// search
- const helperNodeMapping = useMemo(() => {
- const mapping = new Map();
- for (const schema of schemata.schemata) {
- for (const helper of schema.defaultNodes ?? []) {
- mapping.set(helper.schemaId, schema.schemaId);
- }
- }
- return mapping;
- }, [schemata]);
-
const searchIndex = useMemo(() => createSearchIndex(schemata.schemata), [schemata.schemata]);
const [searchQuery, setSearchQuery] = useState('');
const { searchScores, searchTerms } = useMemo(() => {
@@ -100,16 +90,14 @@ const NodeDocumentationModal = memo(() => {
useEffect(() => {
if (searchScores && searchScores.size > 0) {
const highestScore = Math.max(...searchScores.values());
- let highestScoreSchemaId = [...searchScores.entries()].find(
+ const highestScoreSchemaId = [...searchScores.entries()].find(
([, score]) => score === highestScore
)?.[0];
if (highestScoreSchemaId) {
- highestScoreSchemaId =
- helperNodeMapping.get(highestScoreSchemaId) ?? highestScoreSchemaId;
openNodeDocumentation(highestScoreSchemaId);
}
}
- }, [searchScores, helperNodeMapping, openNodeDocumentation]);
+ }, [searchScores, openNodeDocumentation]);
return (