diff --git a/backend/src/api/api.py b/backend/src/api/api.py index 2ede1f3bdf..5277c4bd00 100644 --- a/backend/src/api/api.py +++ b/backend/src/api/api.py @@ -24,7 +24,7 @@ check_naming_conventions, check_schema_types, ) -from .node_data import IteratorInputInfo, IteratorOutputInfo, NodeData +from .node_data import IteratorInputInfo, IteratorOutputInfo, KeyInfo, NodeData from .output import BaseOutput from .settings import Setting from .types import FeatureId, InputId, NodeId, NodeKind, OutputId, RunFn @@ -113,6 +113,7 @@ def register( iterator_inputs: list[IteratorInputInfo] | IteratorInputInfo | None = None, iterator_outputs: list[IteratorOutputInfo] | IteratorOutputInfo | None = None, node_context: bool = False, + key_info: KeyInfo | None = None, ): if not isinstance(description, str): description = "\n\n".join(description) @@ -181,6 +182,7 @@ def inner_wrapper(wrapped_func: T) -> T: outputs=p_outputs, iterator_inputs=iterator_inputs, iterator_outputs=iterator_outputs, + key_info=key_info, side_effects=side_effects, deprecated=deprecated, node_context=node_context, diff --git a/backend/src/api/node_data.py b/backend/src/api/node_data.py index 3644e2af9d..a05cb58786 100644 --- a/backend/src/api/node_data.py +++ b/backend/src/api/node_data.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Any import navi @@ -50,6 +51,22 @@ def to_dict(self): } +class KeyInfo: + def __init__(self, data: dict[str, Any]) -> None: + self._data = data + + @staticmethod + def enum(enum_input: InputId | int) -> KeyInfo: + return KeyInfo({"kind": "enum", "enum": enum_input}) + + @staticmethod + def type(expression: navi.ExpressionJson) -> KeyInfo: + return KeyInfo({"kind": "type", "expression": expression}) + + def to_dict(self): + return self._data + + @dataclass(frozen=True) class NodeData: schema_id: str @@ -66,6 +83,8 @@ class NodeData: iterator_inputs: list[IteratorInputInfo] iterator_outputs: list[IteratorOutputInfo] + key_info: KeyInfo | None + side_effects: bool deprecated: bool node_context: bool diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py index 40bccca28a..67ae666f02 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py @@ -6,7 +6,7 @@ from sanic.log import logger from spandrel import ImageModelDescriptor, ModelTiling -from api import NodeContext +from api import KeyInfo, NodeContext from nodes.groups import Condition, if_group from nodes.impl.pytorch.auto_split import pytorch_auto_split from nodes.impl.upscale.auto_split_tiles import ( @@ -215,6 +215,23 @@ def estimate(): assume_normalized=True, # pytorch_auto_split already does clipping internally ) ], + key_info=KeyInfo.type( + """ + let model = Input0; + let useCustomScale = Input4; + let customScale = Input5; + + let singleUpscale = convenientUpscale(model, img); + + let scale = if bool::and(useCustomScale, model.scale >= 2, model.inputChannels == model.outputChannels) { + customScale + } else { + model.scale + }; + + string::concat(toString(scale), "x") + """ + ), node_context=True, ) def upscale_image_node( diff --git a/backend/src/packages/chaiNNer_standard/image/io/save_image.py b/backend/src/packages/chaiNNer_standard/image/io/save_image.py index 1f941d386a..efcac0b8f3 100644 --- a/backend/src/packages/chaiNNer_standard/image/io/save_image.py +++ b/backend/src/packages/chaiNNer_standard/image/io/save_image.py @@ -10,6 +10,7 @@ from PIL import Image from sanic.log import logger +from api import KeyInfo from nodes.groups import Condition, if_enum_group, if_group from nodes.impl.dds.format import ( BC7_FORMATS, @@ -211,6 +212,7 @@ class TiffColorDepth(Enum): ), ], outputs=[], + key_info=KeyInfo.enum(4), side_effects=True, limited_to_8bpc="Image will be saved with 8 bits/channel by default. Some formats support higher bit depths.", ) diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/border/pad.py b/backend/src/packages/chaiNNer_standard/image_dimension/border/pad.py index 2a6f742967..94d2b94c85 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/border/pad.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/border/pad.py @@ -4,6 +4,7 @@ import numpy as np +from api import KeyInfo from nodes.groups import if_enum_group from nodes.impl.color.color import Color from nodes.impl.image_utils import BorderType, create_border @@ -95,6 +96,7 @@ class BorderMode(Enum): assume_normalized=True, ) ], + key_info=KeyInfo.enum(3), ) def pad_node( img: np.ndarray, diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/crop/crop.py b/backend/src/packages/chaiNNer_standard/image_dimension/crop/crop.py index 3f0ea6d1dd..53d4430e72 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/crop/crop.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/crop/crop.py @@ -4,6 +4,7 @@ import numpy as np +from api import KeyInfo from nodes.groups import if_enum_group from nodes.properties.inputs import EnumInput, ImageInput, NumberInput from nodes.properties.outputs import ImageOutput @@ -77,6 +78,7 @@ class CropMode(Enum): "The cropped area would result in an image with no width or no height." ) ], + key_info=KeyInfo.enum(1), ) def crop_node( img: np.ndarray, diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py index 40f54a8923..a2b7428d58 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize.py @@ -4,6 +4,7 @@ import numpy as np +from api import KeyInfo from nodes.groups import Condition, if_enum_group, if_group from nodes.impl.resize import ResizeFilter, resize from nodes.properties.inputs import ( @@ -89,6 +90,20 @@ class ImageResizeMode(Enum): assume_normalized=True, ) ], + key_info=KeyInfo.type( + """ + let mode = Input1; + + let scale = Input2; + let width = Input3; + let height = Input4; + + match mode { + ImageResizeMode::Percentage => string::concat(toString(scale), "%"), + ImageResizeMode::Absolute => string::concat(toString(width), "x", toString(height)), + } + """ + ), ) def resize_node( img: np.ndarray, diff --git a/backend/src/packages/chaiNNer_standard/utility/color/color_from.py b/backend/src/packages/chaiNNer_standard/utility/color/color_from.py index 2e5d57b880..72a739e056 100644 --- a/backend/src/packages/chaiNNer_standard/utility/color/color_from.py +++ b/backend/src/packages/chaiNNer_standard/utility/color/color_from.py @@ -2,6 +2,7 @@ from enum import Enum +from api import KeyInfo from nodes.groups import if_enum_group from nodes.impl.color.color import Color from nodes.properties.inputs import EnumInput, SliderInput @@ -22,7 +23,9 @@ class ColorType(Enum): description="Create a new color value from individual channels.", icon="MdColorLens", inputs=[ - EnumInput(ColorType, "Color Type", ColorType.RGBA, preferred_style="tabs"), + EnumInput( + ColorType, "Color Type", ColorType.RGBA, preferred_style="tabs" + ).with_id(0), if_enum_group(0, ColorType.GRAY)( SliderInput( "Luma", @@ -96,6 +99,7 @@ class ColorType(Enum): """ ) ], + key_info=KeyInfo.enum(0), ) def color_from_node( color_type: ColorType, diff --git a/backend/src/packages/chaiNNer_standard/utility/value/number.py b/backend/src/packages/chaiNNer_standard/utility/value/number.py index c90663a2ae..25e913dcc7 100644 --- a/backend/src/packages/chaiNNer_standard/utility/value/number.py +++ b/backend/src/packages/chaiNNer_standard/utility/value/number.py @@ -1,5 +1,6 @@ from __future__ import annotations +from api import KeyInfo from nodes.properties.inputs import NumberInput from nodes.properties.outputs import NumberOutput @@ -24,6 +25,7 @@ outputs=[ NumberOutput("Number", output_type="Input0").suggest(), ], + key_info=KeyInfo.type("""toString(Input0)"""), ) def number_node(number: float) -> float: return number diff --git a/backend/src/packages/chaiNNer_standard/utility/value/percent.py b/backend/src/packages/chaiNNer_standard/utility/value/percent.py index 2210851f17..3d59d81bf6 100644 --- a/backend/src/packages/chaiNNer_standard/utility/value/percent.py +++ b/backend/src/packages/chaiNNer_standard/utility/value/percent.py @@ -1,5 +1,6 @@ from __future__ import annotations +from api import KeyInfo from nodes.properties.inputs import SliderInput from nodes.properties.outputs import NumberOutput @@ -25,6 +26,7 @@ outputs=[ NumberOutput("Percent", output_type="Input0"), ], + key_info=KeyInfo.type("""string::concat(toString(Input0), "%")"""), ) def percent_node(number: int) -> int: return number diff --git a/backend/src/packages/chaiNNer_standard/utility/value/switch.py b/backend/src/packages/chaiNNer_standard/utility/value/switch.py index 293cf39eba..c22079e897 100644 --- a/backend/src/packages/chaiNNer_standard/utility/value/switch.py +++ b/backend/src/packages/chaiNNer_standard/utility/value/switch.py @@ -2,6 +2,7 @@ from enum import Enum +from api import KeyInfo from nodes.groups import optional_list_group from nodes.properties.inputs import AnyInput, EnumInput from nodes.properties.outputs import BaseOutput @@ -29,7 +30,7 @@ class ValueIndex(Enum): description="Allows you to pass in multiple inputs and then change which one passes through to the output.", icon="BsShuffle", inputs=[ - EnumInput(ValueIndex), + EnumInput(ValueIndex).with_id(0), AnyInput(label="Value A"), AnyInput(label="Value B"), optional_list_group( @@ -64,6 +65,7 @@ class ValueIndex(Enum): label="Value", ).with_never_reason("The selected value should have a connection.") ], + key_info=KeyInfo.enum(0), see_also=["chainner:utility:pass_through"], ) def switch_node(selection: ValueIndex, *args: object | None) -> object: diff --git a/backend/src/server.py b/backend/src/server.py index ac23232d7a..07c806c437 100644 --- a/backend/src/server.py +++ b/backend/src/server.py @@ -120,6 +120,7 @@ async def nodes(_request: Request): ], "iteratorInputs": [x.to_dict() for x in node.iterator_inputs], "iteratorOutputs": [x.to_dict() for x in node.iterator_outputs], + "keyInfo": node.key_info.to_dict() if node.key_info else None, "description": node.description, "seeAlso": node.see_also, "icon": node.icon, diff --git a/src/common/common-types.ts b/src/common/common-types.ts index 11761e3c86..b647db99c2 100644 --- a/src/common/common-types.ts +++ b/src/common/common-types.ts @@ -286,6 +286,16 @@ export interface IteratorOutputInfo { readonly lengthType: ExpressionJson; } +export type KeyInfo = EnumKeyInfo | TypeKeyInfo; +export interface EnumKeyInfo { + readonly kind: 'enum'; + readonly enum: InputId; +} +export interface TypeKeyInfo { + readonly kind: 'type'; + readonly expression: ExpressionJson; +} + export interface NodeSchema { readonly name: string; readonly category: CategoryId; @@ -299,6 +309,7 @@ export interface NodeSchema { readonly groupLayout: readonly (InputId | Group)[]; readonly iteratorInputs: readonly IteratorInputInfo[]; readonly iteratorOutputs: readonly IteratorOutputInfo[]; + readonly keyInfo?: KeyInfo | null; readonly schemaId: SchemaId; readonly hasSideEffects: boolean; readonly deprecated: boolean; diff --git a/src/common/nodes/keyInfo.ts b/src/common/nodes/keyInfo.ts new file mode 100644 index 0000000000..be018adc43 --- /dev/null +++ b/src/common/nodes/keyInfo.ts @@ -0,0 +1,92 @@ +import { + ParameterDefinition, + Scope, + ScopeBuilder, + StringType, + evaluate, + isStringLiteral, + isSubsetOf, +} from '@chainner/navi'; +import { InputData, KeyInfo, NodeSchema, OfKind } from '../common-types'; +import { + FunctionDefinition, + FunctionInstance, + getInputParamName, + getOutputParamName, +} from '../types/function'; +import { fromJson } from '../types/json'; +import { lazyKeyed } from '../util'; + +const getKeyInfoScopeTemplate = lazyKeyed((definition: FunctionDefinition): Scope => { + const builder = new ScopeBuilder('key info', definition.scope); + + // assign inputs and outputs + definition.inputDefaults.forEach((input, inputId) => { + builder.add(new ParameterDefinition(getInputParamName(inputId), input)); + }); + definition.outputDefaults.forEach((output, outputId) => { + builder.add(new ParameterDefinition(getOutputParamName(outputId), output)); + }); + + return builder.createScope(); +}); +const getKeyInfoScope = (instance: FunctionInstance): Scope => { + const scope = getKeyInfoScopeTemplate(instance.definition); + + // assign inputs and outputs + instance.inputs.forEach((input, inputId) => { + scope.assignParameter(getInputParamName(inputId), input); + }); + instance.outputs.forEach((output, outputId) => { + scope.assignParameter(getOutputParamName(outputId), output); + }); + + return scope; +}; + +const accessors: { + [kind in KeyInfo['kind']]: ( + info: OfKind, + node: NodeSchema, + inputData: InputData, + types: FunctionInstance | undefined + ) => string | undefined; +} = { + enum: (info, node, inputData) => { + const input = node.inputs.find((i) => i.id === info.enum); + if (!input) throw new Error(`Input ${info.enum} not found`); + if (input.kind !== 'dropdown') throw new Error(`Input ${info.enum} is not a dropdown`); + + const value = inputData[input.id]; + const option = input.options.find((o) => o.value === value); + return option?.option; + }, + type: (info, node, inputData, types) => { + if (!types) return undefined; + + const expression = fromJson(info.expression); + const scope = getKeyInfoScope(types); + const result = evaluate(expression, scope); + + if (isStringLiteral(result)) return result.value; + + // check that the expression actually evaluates to a string + if (!isSubsetOf(result, StringType.instance)) { + throw new Error( + `Key info expression must evaluate to a string, but got ${result.toString()}` + ); + } + + return undefined; + }, +}; + +export const getKeyInfo = ( + node: NodeSchema, + inputData: InputData, + types: FunctionInstance | undefined +): string | undefined => { + const { keyInfo } = node; + if (!keyInfo) return undefined; + return accessors[keyInfo.kind](keyInfo as never, node, inputData, types); +}; diff --git a/src/common/types/function.ts b/src/common/types/function.ts index 3f86f87bee..2357a7fe94 100644 --- a/src/common/types/function.ts +++ b/src/common/types/function.ts @@ -49,8 +49,8 @@ const getParamRefs =

( return refs; }; -const getInputParamName = (inputId: InputId) => `Input${inputId}` as const; -const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const; +export const getInputParamName = (inputId: InputId) => `Input${inputId}` as const; +export const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const; interface InputInfo { expression: Expression; diff --git a/src/renderer/components/node/NodeHeader.tsx b/src/renderer/components/node/NodeHeader.tsx index 2c9f62c1ec..e55f678e58 100644 --- a/src/renderer/components/node/NodeHeader.tsx +++ b/src/renderer/components/node/NodeHeader.tsx @@ -1,8 +1,11 @@ import { ChevronDownIcon, ChevronRightIcon } from '@chakra-ui/icons'; import { Box, Center, HStack, Heading, IconButton, Spacer, Text, VStack } from '@chakra-ui/react'; -import { memo } from 'react'; +import { memo, useEffect, useMemo } from 'react'; import ReactTimeAgo from 'react-time-ago'; +import { useContext } from 'use-context-selector'; +import { getKeyInfo } from '../../../common/nodes/keyInfo'; import { Validity } from '../../../common/Validity'; +import { AlertBoxContext, AlertType } from '../../contexts/AlertBoxContext'; import { NodeProgress } from '../../contexts/ExecutionContext'; import { interpolateColor } from '../../helpers/colorTools'; import { NodeState } from '../../helpers/nodeState'; @@ -72,6 +75,52 @@ const IteratorProcess = memo(({ nodeProgress, progressColor }: IteratorProcessPr ); }); +interface KeyInfoLabelProps { + nodeState: NodeState; +} + +const KeyInfoLabel = memo(({ nodeState }: KeyInfoLabelProps) => { + const { sendAlert } = useContext(AlertBoxContext); + + const { schema, inputData, type } = nodeState; + const [info, error] = useMemo((): [string | undefined, unknown] => { + try { + return [getKeyInfo(schema, inputData, type.instance), undefined]; + } catch (e) { + return [undefined, e]; + } + }, [schema, inputData, type.instance]); + + useEffect(() => { + if (error) { + sendAlert({ + type: AlertType.ERROR, + title: 'Implementation Error', + message: `Unable to determine key info for node ${schema.name} (${ + schema.schemaId + }) due to an error in the implementation of the key info:\n\n${String(error)}`, + }); + } + }, [schema, error, sendAlert]); + + // eslint-disable-next-line react/jsx-no-useless-fragment + if (!info) return <>; + + return ( + + {info} + + ); +}); + interface NodeHeaderProps { nodeState: NodeState; accentColor: string; @@ -150,22 +199,27 @@ export const NodeHeader = memo( />

- - {nodeState.schema.name} - + + + {nodeState.schema.name} + + {isCollapsed && nodeState.schema.keyInfo && ( + + )} +