diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
index aea5d3d40..36ff99f23 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
@@ -280,6 +280,8 @@ public open class OnnxInferenceModel private constructor(private val modelSource
                                                 dataType: OnnxJavaType,
                                                 shape: LongArray
         ): OnnxTensor {
+            checkTensorMatchesInputShape(data, shape)
+
             val inputTensor = when (dataType) {
                 OnnxJavaType.FLOAT -> OnnxTensor.createTensor(this, FloatBuffer.wrap(data), shape)
                 OnnxJavaType.DOUBLE -> OnnxTensor.createTensor(
@@ -320,6 +322,29 @@ public open class OnnxInferenceModel private constructor(private val modelSource
             return inputTensor
         }
 
+        private fun checkTensorMatchesInputShape(data: FloatArray, inputShape: LongArray) {
+            val numOfElements = inputShape.reduce { acc, dim -> acc * dim }.toInt()
+
+            if (data.size == numOfElements) return
+
+            if (inputShape.size == 4 &&
+                inputShape[0] == 1L &&
+                (inputShape[1] == 3L || inputShape[3] == 3L) &&
+                (data.size * 3 == numOfElements)
+            ) {
+                throw IllegalArgumentException(
+                    "The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
+                        .plus("${inputShape.contentToString()}.")
+                        .plus(" It looks like you are trying to use a 1-channel (grayscale) image as an input, but the model expects a 3-channel image.")
+                )
+            }
+
+            throw IllegalArgumentException(
+                "The number of elements (N=${data.size}) in the input tensor does not match the model input shape - "
+                    .plus("${inputShape.contentToString()}.")
+            )
+        }
+
         /**
          * Loads model from serialized ONNX file.
          */