Skip to content

Commit

Permalink
Add validation of input shape (#385) (#467)
Browse files Browse the repository at this point in the history
* Check if the number of elements in the input matches the model's input shape

* Trying to guess if the user used a grayscale image when 3-channels expected
  • Loading branch information
ermolenkodev authored Oct 13, 2022
1 parent 8483b8a commit 8ba7f91
Showing 1 changed file with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ public open class OnnxInferenceModel private constructor(private val modelSource
dataType: OnnxJavaType,
shape: LongArray
): OnnxTensor {
checkTensorMatchesInputShape(data, shape)

val inputTensor = when (dataType) {
OnnxJavaType.FLOAT -> OnnxTensor.createTensor(this, FloatBuffer.wrap(data), shape)
OnnxJavaType.DOUBLE -> OnnxTensor.createTensor(
Expand Down Expand Up @@ -320,6 +322,29 @@ public open class OnnxInferenceModel private constructor(private val modelSource
return inputTensor
}

private fun checkTensorMatchesInputShape(data: FloatArray, inputShape: LongArray) {
val numOfElements = inputShape.reduce { acc, dim -> acc * dim }.toInt()

if (data.size == numOfElements) return

if (inputShape.size == 4 &&
inputShape[0] == 1L &&
(inputShape[1] == 3L || inputShape[3] == 3L) &&
(data.size * 3 == numOfElements)
) {
throw IllegalArgumentException(
"The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
.plus("${inputShape.contentToString()}.")
.plus(" It looks like you are trying to use a 1-channel (grayscale) image as an input, but the model expects a 3-channel image.")
)
}

throw IllegalArgumentException(
"The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
.plus("${inputShape.contentToString()}.")
)
}

/**
* Loads model from serialized ONNX file.
*/
Expand Down

0 comments on commit 8ba7f91

Please sign in to comment.