diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
index aea5d3d40..36ff99f23 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
@@ -280,6 +280,8 @@ public open class OnnxInferenceModel private constructor(private val modelSource
         dataType: OnnxJavaType,
         shape: LongArray
     ): OnnxTensor {
+        checkTensorMatchesInputShape(data, shape)
+
         val inputTensor = when (dataType) {
             OnnxJavaType.FLOAT -> OnnxTensor.createTensor(this, FloatBuffer.wrap(data), shape)
             OnnxJavaType.DOUBLE -> OnnxTensor.createTensor(
@@ -320,6 +322,29 @@ public open class OnnxInferenceModel private constructor(private val modelSource
         return inputTensor
     }
 
+    private fun checkTensorMatchesInputShape(data: FloatArray, inputShape: LongArray) {
+        val numOfElements = inputShape.reduce { acc, dim -> acc * dim }.toInt()
+
+        if (data.size == numOfElements) return
+
+        if (inputShape.size == 4 &&
+            inputShape[0] == 1L &&
+            (inputShape[1] == 3L || inputShape[3] == 3L) &&
+            (data.size * 3 == numOfElements)
+        ) {
+            throw IllegalArgumentException(
+                "The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
+                    .plus("${inputShape.contentToString()}.")
+                    .plus(" It looks like you are trying to use a 1-channel (grayscale) image as an input, but the model expects a 3-channel image.")
+            )
+        }
+
+        throw IllegalArgumentException(
+            "The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
+                .plus("${inputShape.contentToString()}.")
+        )
+    }
+
     /**
      * Loads model from serialized ONNX file.
      */
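
For reference, here is a minimal standalone sketch (not library code; the `main` function, the NCHW shape, and the grayscale buffer are hypothetical) that mirrors the validation added above and shows the grayscale-specific error message being triggered from the caller's side:

```kotlin
// Sketch of the shape check added in this diff; string concatenation is used
// instead of .plus(...) for brevity, the behavior is the same.
fun checkTensorMatchesInputShape(data: FloatArray, inputShape: LongArray) {
    // Total number of elements the model expects, e.g. 1 * 3 * 224 * 224.
    val numOfElements = inputShape.reduce { acc, dim -> acc * dim }.toInt()

    if (data.size == numOfElements) return

    // Special-case a 1-channel image passed to a model expecting 3 channels (NCHW or NHWC).
    if (inputShape.size == 4 &&
        inputShape[0] == 1L &&
        (inputShape[1] == 3L || inputShape[3] == 3L) &&
        (data.size * 3 == numOfElements)
    ) {
        throw IllegalArgumentException(
            "The number of elements (N=${data.size}) in the input tensor does not match the model input shape - " +
                "${inputShape.contentToString()}." +
                " It looks like you are trying to use a 1-channel (grayscale) image as an input, but the model expects a 3-channel image."
        )
    }

    throw IllegalArgumentException(
        "The number of elements (N=${data.size}) in the input tensor does not match the model input shape - " +
            "${inputShape.contentToString()}."
    )
}

fun main() {
    val modelInputShape = longArrayOf(1, 3, 224, 224) // hypothetical 3-channel NCHW model input
    val grayscaleImage = FloatArray(224 * 224)        // 1-channel buffer: three times too few elements

    runCatching { checkTensorMatchesInputShape(grayscaleImage, modelInputShape) }
        .onFailure { println(it.message) }            // prints the grayscale-specific hint
}
```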