diff --git a/backend/src/navi.py b/backend/src/navi.py index ccf5d6c69..6cd625c21 100644 --- a/backend/src/navi.py +++ b/backend/src/navi.py @@ -150,6 +150,10 @@ def intersect(*items: ExpressionJson) -> ExpressionJson: return {"type": "intersection", "items": list(items)} +def intersect_with_error(*items: ExpressionJson) -> ExpressionJson: + return union(intersect(*items), *[intersect("Error", item) for item in items]) + + def named(name: str, fields: dict[str, ExpressionJson] | None = None) -> ExpressionJson: return {"type": "named", "name": name, "fields": fields} diff --git a/backend/src/nodes/properties/outputs/file_outputs.py b/backend/src/nodes/properties/outputs/file_outputs.py index 8ac9e6dde..f91601548 100644 --- a/backend/src/nodes/properties/outputs/file_outputs.py +++ b/backend/src/nodes/properties/outputs/file_outputs.py @@ -18,7 +18,7 @@ def __init__( if of_input is None else f"splitFilePath(Input{of_input}.path).dir" ) - directory_type = navi.intersect(directory_type, output_type) + directory_type = navi.intersect_with_error(directory_type, output_type) super().__init__(directory_type, label, associated_type=str) def get_broadcast_type(self, value: str): diff --git a/backend/src/nodes/properties/outputs/generic_outputs.py b/backend/src/nodes/properties/outputs/generic_outputs.py index 794c86441..f4e53feb3 100644 --- a/backend/src/nodes/properties/outputs/generic_outputs.py +++ b/backend/src/nodes/properties/outputs/generic_outputs.py @@ -17,7 +17,7 @@ def __init__( output_type: navi.ExpressionJson = "number", ): super().__init__( - navi.intersect("number", output_type), + navi.intersect_with_error("number", output_type), label, associated_type=Union[int, float], ) @@ -36,7 +36,7 @@ def __init__( label: str, output_type: navi.ExpressionJson = "string", ): - super().__init__(navi.intersect("string", output_type), label) + super().__init__(navi.intersect_with_error("string", output_type), label) def get_broadcast_type(self, value: str): return navi.literal(value) @@ -73,7 +73,9 @@ def __init__( channels: int | None = None, ): super().__init__( - output_type=navi.intersect(color_type, navi.Color(channels=channels)), + output_type=navi.intersect_with_error( + color_type, navi.Color(channels=channels) + ), label=label, kind="generic", ) diff --git a/backend/src/nodes/properties/outputs/numpy_outputs.py b/backend/src/nodes/properties/outputs/numpy_outputs.py index cb840226d..51124cecb 100644 --- a/backend/src/nodes/properties/outputs/numpy_outputs.py +++ b/backend/src/nodes/properties/outputs/numpy_outputs.py @@ -53,7 +53,7 @@ def __init__( assume_normalized: bool = False, ): super().__init__( - navi.intersect(image_type, navi.Image(channels=channels)), + navi.intersect_with_error(image_type, navi.Image(channels=channels)), label, kind=kind, has_handle=has_handle, diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py index d9f16efb5..d4e8c47af 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py @@ -48,22 +48,24 @@ let source = Input0; let guide = Input1; - let valid = bool::and( - // guide image must be larger than source image - guide.width > source.width, + let kScale = bool::and( // guide image's size must be `k * source.size` for `k>1` guide.width / source.width == int, guide.width / source.width == guide.height / source.height ); - Image { - width: guide.width, - height: guide.height, - channels: source.channels, - } & if valid { any } else { never } + if guide.width <= source.width { + error("The guide image must be larger than the source image.") + } else if bool::not(kScale) { + error("The size of the guide image must be an integer multiple of the size of the source image (e.g. 2x, 3x, 4x, ...).") + } else { + Image { + width: guide.width, + height: guide.height, + channels: source.channels, + } + } """ - ).with_never_reason( - "The guide image must be larger than the source image, and the size of the guide image must be an integer multiple of the size of the source image (e.g. 2x, 3x, 4x, ...)." ), ], ) diff --git a/backend/src/packages/chaiNNer_standard/image_adjustment/adjustments/stretch_contrast.py b/backend/src/packages/chaiNNer_standard/image_adjustment/adjustments/stretch_contrast.py index d175b4a46..14e5ae3b8 100644 --- a/backend/src/packages/chaiNNer_standard/image_adjustment/adjustments/stretch_contrast.py +++ b/backend/src/packages/chaiNNer_standard/image_adjustment/adjustments/stretch_contrast.py @@ -52,14 +52,18 @@ class StretchMode(Enum): outputs=[ ImageOutput( image_type=""" - let valid: bool = match Input1 { + let minMaxRangeValid: bool = match Input1 { StretchMode::Manual => Input4 < Input5, _ => true, }; - if valid { Input0 } else { never } + if minMaxRangeValid { + Input0 + } else { + error("Minimum must be less than Maximum.") + } """, - ).with_never_reason("Minimum must be less than the Maximum."), + ), ], ) def stretch_contrast_node( diff --git a/backend/src/packages/chaiNNer_standard/image_channel/all/combine_rgba.py b/backend/src/packages/chaiNNer_standard/image_channel/all/combine_rgba.py index 8ad36b541..6d8995a31 100644 --- a/backend/src/packages/chaiNNer_standard/image_channel/all/combine_rgba.py +++ b/backend/src/packages/chaiNNer_standard/image_channel/all/combine_rgba.py @@ -34,23 +34,24 @@ outputs=[ ImageOutput( image_type=""" - let anyImages = bool::or(Input0 == Image, Input1 == Image, Input2 == Image, Input3 == Image); + def isImage(i: any) = match i { Image => true, _ => false }; + let anyImages = bool::or(isImage(Input0), isImage(Input1), isImage(Input2), isImage(Input3)); - def getWidth(i: any) = match i { Image => i.width, _ => Image.width }; - def getHeight(i: any) = match i { Image => i.height, _ => Image.height }; + if bool::not(anyImages) { + error("At least one channel must be an image.") + } else { + def getWidth(i: any) = match i { Image => i.width, _ => Image.width }; + def getHeight(i: any) = match i { Image => i.height, _ => Image.height }; - let valid = if anyImages { any } else { never }; - - valid & Image { - width: getWidth(Input0) & getWidth(Input1) & getWidth(Input2) & getWidth(Input3), - height: getHeight(Input0) & getHeight(Input1) & getHeight(Input2) & getHeight(Input3), + Image { + width: getWidth(Input0) & getWidth(Input1) & getWidth(Input2) & getWidth(Input3), + height: getHeight(Input0) & getHeight(Input1) & getHeight(Input2) & getHeight(Input3), + } } """, channels=4, assume_normalized=True, - ).with_never_reason( - "All input channels must have the same size, and at least one input channel must be an image." - ) + ).with_never_reason("All input channels must have the same size.") ], ) def combine_rgba_node( diff --git a/backend/src/packages/chaiNNer_standard/image_channel/misc/alpha_matting.py b/backend/src/packages/chaiNNer_standard/image_channel/misc/alpha_matting.py index 00d590f91..67f935dcd 100644 --- a/backend/src/packages/chaiNNer_standard/image_channel/misc/alpha_matting.py +++ b/backend/src/packages/chaiNNer_standard/image_channel/misc/alpha_matting.py @@ -47,22 +47,15 @@ let fg = Input2; let bg = Input3; - let valid = bool::and( - fg > bg, - image.width == trimap.width, - image.height == trimap.height, - ); - - if valid { - Image { width: image.width, height: image.height } + if fg <= bg { + error("The foreground threshold must be greater than the background threshold.") + } else if bool::or(image.width != trimap.width, image.height != trimap.height) { + error("The image and trimap must have the same size.") } else { - never + Image { width: image.width, height: image.height } } """, channels=4, - ).with_never_reason( - "The image and trimap must have the same size," - " and the foreground threshold must be greater than the background threshold." ), ], ) diff --git a/backend/src/packages/chaiNNer_standard/image_channel/transparency/merge_transparency.py b/backend/src/packages/chaiNNer_standard/image_channel/transparency/merge_transparency.py index 2727b80f5..1bae0b71a 100644 --- a/backend/src/packages/chaiNNer_standard/image_channel/transparency/merge_transparency.py +++ b/backend/src/packages/chaiNNer_standard/image_channel/transparency/merge_transparency.py @@ -22,23 +22,24 @@ outputs=[ ImageOutput( image_type=""" - let anyImages = bool::or(Input0 == Image, Input1 == Image); + def isImage(i: any) = match i { Image => true, _ => false }; + let anyImages = bool::or(isImage(Input0), isImage(Input1)); - def getWidth(i: any) = match i { Image => i.width, _ => Image.width }; - def getHeight(i: any) = match i { Image => i.height, _ => Image.height }; + if bool::not(anyImages) { + error("At least one input must be an image.") + } else { + def getWidth(i: any) = match i { Image => i.width, _ => Image.width }; + def getHeight(i: any) = match i { Image => i.height, _ => Image.height }; - let valid = if anyImages { any } else { never }; - - valid & Image { - width: getWidth(Input0) & getWidth(Input1), - height: getHeight(Input0) & getHeight(Input1), + Image { + width: getWidth(Input0) & getWidth(Input1), + height: getHeight(Input0) & getHeight(Input1), + } } """, channels=4, assume_normalized=True, - ).with_never_reason( - "RGB and Alpha must have the same size, and at least one must be an image." - ) + ).with_never_reason("RGB and Alpha must have the same size.") ], ) def merge_transparency_node( diff --git a/backend/src/packages/chaiNNer_standard/image_filter/quantize/quantize_to_reference.py b/backend/src/packages/chaiNNer_standard/image_filter/quantize/quantize_to_reference.py index 38194ee3b..070e7b72b 100644 --- a/backend/src/packages/chaiNNer_standard/image_filter/quantize/quantize_to_reference.py +++ b/backend/src/packages/chaiNNer_standard/image_filter/quantize/quantize_to_reference.py @@ -74,22 +74,21 @@ def quantize_image(image: np.ndarray, palette: np.ndarray): outputs=[ ImageOutput( image_type=""" - let valid = bool::and( - Input0.width >= Input1.width, - number::mod(Input0.width, Input1.width) == 0, - Input0.height >= Input1.height, - number::mod(Input0.height, Input1.height) == 0, - Input0.channels == Input1.channels, - ); - - Image { - width: max(Input0.width, Input1.width), - height: max(Input0.height, Input1.height), - channels: Input0.channels, - } & if valid { any } else { never }""", + if Input0.channels != Input1.channels { + error("The target image and reference image must have the same number of channels.") + } else if bool::or(Input0.width < Input1.width, Input0.height < Input1.height) { + error("The target image must be larger than the reference image.") + } else if bool::or(number::mod(Input0.width, Input1.width) != 0, number::mod(Input0.height, Input1.height) != 0) { + error("The size of the target image must be an integer multiple of the size of the reference image (e.g. 2x, 3x, 4x, 8x).") + } else { + Image { + width: max(Input0.width, Input1.width), + height: max(Input0.height, Input1.height), + channels: Input0.channels, + } + } + """, assume_normalized=True, - ).with_never_reason( - "Target image must be larger than reference image in both dimensions, must have dimensions that are a multiple of each other, and must have the same number of channels." ) ], ) diff --git a/package-lock.json b/package-lock.json index 59c3c1d0b..83e9ffeff 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,7 +11,7 @@ "license": "GPLv3", "dependencies": { "@babel/plugin-transform-react-jsx": "^7.17.12", - "@chainner/navi": "^0.6.2", + "@chainner/navi": "^0.7.1", "@chakra-ui/icons": "^2.1.1", "@chakra-ui/react": "^2.8.2", "@emotion/react": "^11.9.0", @@ -964,9 +964,9 @@ "dev": true }, "node_modules/@chainner/navi": { - "version": "0.6.2", - "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.6.2.tgz", - "integrity": "sha512-udqqqOjRrBs7AZ6+aW8oOVl2SPqmH1jhF8tHubyeV8QUUeKS0meMqbmbEtlXmwtc4QDN85/UQjzmaIjNR+NYig==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.7.1.tgz", + "integrity": "sha512-YQZngMQ28U6I5NONuKufUQPtWO6pPgztQ5Y6oDP+75oaH2aPDCI5A6AKVSH8fAgfCl/IOeff81RlhvfYvryvUQ==", "engines": { "node": ">=16.0.0" } @@ -25821,9 +25821,9 @@ "dev": true }, "@chainner/navi": { - "version": "0.6.2", - "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.6.2.tgz", - "integrity": "sha512-udqqqOjRrBs7AZ6+aW8oOVl2SPqmH1jhF8tHubyeV8QUUeKS0meMqbmbEtlXmwtc4QDN85/UQjzmaIjNR+NYig==" + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.7.1.tgz", + "integrity": "sha512-YQZngMQ28U6I5NONuKufUQPtWO6pPgztQ5Y6oDP+75oaH2aPDCI5A6AKVSH8fAgfCl/IOeff81RlhvfYvryvUQ==" }, "@chakra-ui/accordion": { "version": "2.3.1", diff --git a/package.json b/package.json index e715e32f8..2b17c7694 100644 --- a/package.json +++ b/package.json @@ -108,7 +108,7 @@ }, "dependencies": { "@babel/plugin-transform-react-jsx": "^7.17.12", - "@chainner/navi": "^0.6.2", + "@chainner/navi": "^0.7.1", "@chakra-ui/icons": "^2.1.1", "@chakra-ui/react": "^2.8.2", "@emotion/react": "^11.9.0", diff --git a/src/common/nodes/checkNodeValidity.ts b/src/common/nodes/checkNodeValidity.ts index 2b0359e9f..3d73efd19 100644 --- a/src/common/nodes/checkNodeValidity.ts +++ b/src/common/nodes/checkNodeValidity.ts @@ -84,12 +84,8 @@ export const checkNodeValidity = ({ } // eslint-disable-next-line no-unreachable-loop - for (const { outputId } of functionInstance.outputErrors) { - const output = schema.outputs.find((o) => o.id === outputId)!; - - return invalid( - `Some inputs are incompatible with each other. ${output.neverReason ?? ''}` - ); + for (const { message } of functionInstance.outputErrors) { + return invalid(`Some inputs are incompatible with each other. ${message ?? ''}`); } } diff --git a/src/common/types/chainner-scope.ts b/src/common/types/chainner-scope.ts index 69cb1e47b..7c9aabe51 100644 --- a/src/common/types/chainner-scope.ts +++ b/src/common/types/chainner-scope.ts @@ -23,6 +23,11 @@ import { const code = ` struct null; +struct Error { message: string } +def error(message: invStrSet("")): Error { + Error { message: message } +} + struct Seed; struct Directory { path: string } diff --git a/src/common/types/function.ts b/src/common/types/function.ts index dbfc7e320..53ff1c4b7 100644 --- a/src/common/types/function.ts +++ b/src/common/types/function.ts @@ -1,5 +1,6 @@ import { Expression, + NamedExpression, NeverType, NonNeverType, NumberType, @@ -17,7 +18,8 @@ import { without, } from '@chainner/navi'; import { Input, InputId, InputSchemaValue, NodeSchema, Output, OutputId } from '../common-types'; -import { EMPTY_MAP, lazyKeyed, topologicalSort } from '../util'; +import { EMPTY_MAP, lazy, lazyKeyed, topologicalSort } from '../util'; +import { getChainnerScope } from './chainner-scope'; import { fromJson } from './json'; const getConversionScope = lazyKeyed((parentScope: Scope) => { @@ -118,6 +120,48 @@ const evaluateInputs = ( return { ordered, defaults }; }; +const getErrorType = lazy(() => { + const scope = getChainnerScope(); + const errorType = evaluate(new NamedExpression('Error'), scope); + if (errorType.underlying !== 'struct' || errorType.type !== 'instance') { + throw new Error('Error type is not a struct'); + } + return errorType; +}); + +const splitOutputTypeAndError = ( + definition: FunctionDefinition, + type: Type +): [Type, string | undefined] => { + const errorType = getErrorType(); + const error = intersect(type, errorType); + if (error.type === 'never') { + // no error + return [type, undefined]; + } + + const pureType = without(type, errorType); + + // get the error message + if (error.underlying !== 'struct' || error.type !== 'instance') { + throw new Error('Error type is not a struct'); + } + + const messageType = error.getField('message')!; + const messageItems = messageType.underlying === 'union' ? messageType.items : [messageType]; + const messages: string[] = []; + for (const item of messageItems) { + if (item.underlying === 'string' && item.type === 'literal') { + messages.push(item.value); + } + } + if (messages.length === 0) { + messages.push('Unknown error'); + } + + return [pureType, messages.join(' ')]; +}; + interface OutputInfo { expression: Expression; inputRefs: Set; @@ -459,6 +503,7 @@ export interface FunctionInputAssignmentError { } export interface FunctionOutputError { outputId: OutputId; + message: string | undefined; } export class FunctionInstance { readonly definition: FunctionDefinition; @@ -563,8 +608,18 @@ export class FunctionInstance { let type: Type; if (definition.outputGenerics.has(id)) { type = evaluate(definition.outputExpressions.get(id)!, scope); + if (type.type === 'never') { - outputErrors.push({ outputId: id }); + const message = + definition.schema.outputs.find((o) => o.id === id)?.neverReason ?? + undefined; + outputErrors.push({ outputId: id, message }); + } else { + let message; + [type, message] = splitOutputTypeAndError(definition, type); + if (type.type === 'never') { + outputErrors.push({ outputId: id, message }); + } } } else { type = definition.outputDefaults.get(id)!; diff --git a/src/common/types/json.ts b/src/common/types/json.ts index 21caa07bd..8de34cf87 100644 --- a/src/common/types/json.ts +++ b/src/common/types/json.ts @@ -136,6 +136,7 @@ export const fromJson = (e: ExpressionJson): Expression => { } return new StructExpression( e.name, + undefined, Object.entries(e.fields ?? {}).map( ([name, type]) => new StructExpressionField(name, fromJson(type)) )