diff --git a/dataset/build.gradle b/dataset/build.gradle
index 5590aa9ec..c5bb614f4 100644
--- a/dataset/build.gradle
+++ b/dataset/build.gradle
@@ -30,7 +30,11 @@ kotlin {
                 implementation project(":api")
             }
         }
-        androidMain {}
+        androidMain {
+            dependencies {
+                api 'androidx.camera:camera-core:1.0.0-rc03'
+            }
+        }
     }
     explicitApiWarning()
 }
diff --git a/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/Preprocessing.kt b/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/Preprocessing.kt
index 786d297cd..75fbe79e8 100644
--- a/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/Preprocessing.kt
+++ b/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/Preprocessing.kt
@@ -6,9 +6,11 @@
 package org.jetbrains.kotlinx.dl.dataset.preprocessing
 
 import android.graphics.Bitmap
+import androidx.camera.core.ImageProxy
 import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.bitmap.Resize
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.bitmap.Rotate
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.imageproxy.ConvertToBitmap
 
 /**
  * The data preprocessing pipeline presented as Kotlin DSL on receivers.
@@ -28,3 +30,11 @@ public fun <I> Operation<I, Bitmap>.resize(block: Resize.() -> Unit): Operation<I, Bitmap> {
 public fun <I> Operation<I, Bitmap>.rotate(block: Rotate.() -> Unit): Operation<I, Bitmap> {
     return PreprocessingPipeline(this, Rotate().apply(block))
 }
+
+/**
+ * Applies the [ConvertToBitmap] operation to convert an [ImageProxy] to a [Bitmap].
+ * The appropriate rotation is also applied to match the target rotation of the image.
+ */
+public fun <I> Operation<I, ImageProxy>.convertToBitmap(block: ConvertToBitmap.() -> Unit): Operation<I, Bitmap> {
+    return PreprocessingPipeline(this, ConvertToBitmap().apply(block))
+}
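Usage sketch (not part of the diff): the new `convertToBitmap` entry point is meant to compose with the existing `Bitmap` operations of this DSL; the 224x224 target size below is an arbitrary assumption for illustration.

```kotlin
// ImageProxy -> Bitmap -> resized Bitmap -> Pair<FloatArray, TensorShape>
val preprocessing = pipeline<ImageProxy>()
    .convertToBitmap { }                          // decodes YUV_420_888 and applies the frame's rotation
    .resize {
        outputWidth = 224                         // arbitrary example size
        outputHeight = 224
    }
    .toFloatArray { layout = TensorLayout.NHWC }  // tensor layout used by the models below
```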
diff --git a/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/imageproxy/ConvertToBitmap.kt b/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/imageproxy/ConvertToBitmap.kt
new file mode 100644
index 000000000..1d4ff27cf
--- /dev/null
+++ b/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/imageproxy/ConvertToBitmap.kt
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
+ */
+
+package org.jetbrains.kotlinx.dl.dataset.preprocessing.imageproxy
+
+import android.graphics.Bitmap
+import androidx.camera.core.ImageProxy
+import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.bitmap.Rotate
+
+/**
+ * Conversion of a CameraX [ImageProxy] to a [Bitmap].
+ * Decodes a YUV_420_888 image to an RGB bitmap.
+ * The appropriate rotation is also applied to match the target rotation of the image.
+ */
+public class ConvertToBitmap : Operation<ImageProxy, Bitmap> {
+    override fun apply(input: ImageProxy): Bitmap {
+        val bitmap = input.toBitmap()
+        check(bitmap != null) { "Failed to convert ImageProxy to Bitmap" }
+
+        val targetRotation = input.imageInfo.rotationDegrees.toFloat()
+
+        return Rotate(targetRotation).apply(bitmap)
+    }
+
+    /**
+     * The output shape cannot be known statically, because the rotation applied depends on the input image.
+     */
+    override fun getOutputShape(inputShape: TensorShape): TensorShape {
+        return when (inputShape.rank()) {
+            2 -> TensorShape(-1, -1)
+            3 -> TensorShape(-1, -1, inputShape[2])
+            else -> throw IllegalArgumentException("Input shape is expected to be 2D or 3D")
+        }
+    }
+}
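A quick illustration of the `getOutputShape` contract above (a sketch, not part of the diff): the spatial dimensions become unknown while the channel dimension is preserved.

```kotlin
// A 480x640x3 input maps to wildcard spatial dims, [-1, -1, 3],
// because the rotation ConvertToBitmap applies depends on each frame's metadata.
val outputShape = ConvertToBitmap().getOutputShape(TensorShape(480, 640, 3))
```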
diff --git a/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/imageproxy/imageUtils.kt b/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/imageproxy/imageUtils.kt
new file mode 100644
index 000000000..e4f2a8273
--- /dev/null
+++ b/dataset/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/dataset/preprocessing/imageproxy/imageUtils.kt
@@ -0,0 +1,153 @@
+package org.jetbrains.kotlinx.dl.dataset.preprocessing.imageproxy
+
+import android.graphics.Bitmap
+import android.graphics.YuvImage
+import android.graphics.ImageFormat
+import android.graphics.Rect
+import android.graphics.BitmapFactory
+import androidx.camera.core.ImageProxy
+import java.io.ByteArrayOutputStream
+
+/**
+ * Converts an [ImageProxy] to a [Bitmap].
+ * Currently only supports [ImageFormat.YUV_420_888].
+ */
+public fun ImageProxy.toBitmap(): Bitmap? {
+    val nv21 = yuv420888ToNv21(this)
+    val yuvImage = YuvImage(nv21, ImageFormat.NV21, width, height, null)
+    return yuvImage.toBitmap()
+}
+
+private fun YuvImage.toBitmap(): Bitmap? {
+    val out = ByteArrayOutputStream()
+    if (!compressToJpeg(Rect(0, 0, width, height), 100, out))
+        return null
+    val imageBytes: ByteArray = out.toByteArray()
+    return BitmapFactory.decodeByteArray(imageBytes, 0, imageBytes.size)
+}
+
+private fun yuv420888ToNv21(image: ImageProxy): ByteArray {
+    val pixelCount = image.cropRect.width() * image.cropRect.height()
+    val pixelSizeBits = ImageFormat.getBitsPerPixel(ImageFormat.YUV_420_888)
+    val outputBuffer = ByteArray(pixelCount * pixelSizeBits / 8)
+    imageToByteBuffer(image, outputBuffer, pixelCount)
+    return outputBuffer
+}
+
+/**
+ * Decodes a YUV_420_888 image into its NV21 byte representation.
+ */
+public fun imageToByteBuffer(image: ImageProxy, outputBuffer: ByteArray, pixelCount: Int) {
+    assert(image.format == ImageFormat.YUV_420_888)
+
+    val imageCrop = image.cropRect
+    val imagePlanes = image.planes
+
+    imagePlanes.forEachIndexed { planeIndex, plane ->
+        // How many input values are read for each output value written.
+        // Only the Y plane has a value for every pixel; U and V have half the resolution, i.e.
+        //
+        // Y Plane            U Plane    V Plane
+        // ===============    =======    =======
+        // Y Y Y Y Y Y Y Y    U U U U    V V V V
+        // Y Y Y Y Y Y Y Y    U U U U    V V V V
+        // Y Y Y Y Y Y Y Y    U U U U    V V V V
+        // Y Y Y Y Y Y Y Y    U U U U    V V V V
+        // Y Y Y Y Y Y Y Y
+        // Y Y Y Y Y Y Y Y
+        // Y Y Y Y Y Y Y Y
+        val outputStride: Int
+
+        // The index in the output buffer where the next value will be written.
+        // For Y it's zero; for U and V we start at the end of Y and interleave them, i.e.
+        //
+        // First chunk        Second chunk
+        // ===============    ===============
+        // Y Y Y Y Y Y Y Y    V U V U V U V U
+        // Y Y Y Y Y Y Y Y    V U V U V U V U
+        // Y Y Y Y Y Y Y Y    V U V U V U V U
+        // Y Y Y Y Y Y Y Y    V U V U V U V U
+        // Y Y Y Y Y Y Y Y
+        // Y Y Y Y Y Y Y Y
+        // Y Y Y Y Y Y Y Y
+        var outputOffset: Int
+
+        when (planeIndex) {
+            0 -> {
+                outputStride = 1
+                outputOffset = 0
+            }
+            1 -> {
+                outputStride = 2
+                // For NV21 format, U is in odd-numbered indices
+                outputOffset = pixelCount + 1
+            }
+            2 -> {
+                outputStride = 2
+                // For NV21 format, V is in even-numbered indices
+                outputOffset = pixelCount
+            }
+            else -> {
+                // Image contains more than 3 planes, something strange is going on
+                return@forEachIndexed
+            }
+        }
+
+        val planeBuffer = plane.buffer
+        val rowStride = plane.rowStride
+        val pixelStride = plane.pixelStride
+
+        // We have to divide the width and height by two if it's not the Y plane
+        val planeCrop = if (planeIndex == 0) {
+            imageCrop
+        } else {
+            Rect(
+                imageCrop.left / 2,
+                imageCrop.top / 2,
+                imageCrop.right / 2,
+                imageCrop.bottom / 2
+            )
+        }
+
+        val planeWidth = planeCrop.width()
+        val planeHeight = planeCrop.height()
+
+        // Intermediate buffer used to store the bytes of each row
+        val rowBuffer = ByteArray(plane.rowStride)
+
+        // Size of each row in bytes
+        val rowLength = if (pixelStride == 1 && outputStride == 1) {
+            planeWidth
+        } else {
+            // Take into account that the stride may include data from pixels other than this
+            // particular plane and row, and that could be between pixels and not after every
+            // pixel:
+            //
+            // |---- Pixel stride ----|                    Row ends here --> |
+            // | Pixel 1 | Other Data | Pixel 2 | Other Data | ... | Pixel N |
+            //
+            // We need to get (N-1) * (pixel stride bytes) per row + 1 byte for the last pixel
+            (planeWidth - 1) * pixelStride + 1
+        }
+
+        for (row in 0 until planeHeight) {
+            // Move the buffer position to the beginning of this row
+            planeBuffer.position(
+                (row + planeCrop.top) * rowStride + planeCrop.left * pixelStride)
+
+            if (pixelStride == 1 && outputStride == 1) {
+                // When there is a single stride value for pixel and output, we can just copy
+                // the entire row in a single step
+                planeBuffer.get(outputBuffer, outputOffset, rowLength)
+                outputOffset += rowLength
+            } else {
+                // When either pixel or output have a stride > 1 we must copy pixel by pixel
+                planeBuffer.get(rowBuffer, 0, rowLength)
+                for (col in 0 until planeWidth) {
+                    outputBuffer[outputOffset] = rowBuffer[col * pixelStride]
+                    outputOffset += outputStride
+                }
+            }
+        }
+    }
+}
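For reference (not part of the diff), the buffer sizing in `yuv420888ToNv21` follows from the format itself: YUV_420_888 carries 12 bits per pixel, 8 bits of Y per pixel plus 8 bits each of U and V shared across a 2x2 block.

```kotlin
// 1.5 bytes per pixel; for a 640x480 frame the NV21 buffer is 460800 bytes.
val pixelCount = 640 * 480
val nv21Size = pixelCount * ImageFormat.getBitsPerPixel(ImageFormat.YUV_420_888) / 8
check(nv21Size == pixelCount * 3 / 2)
```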
diff --git a/onnx/build.gradle b/onnx/build.gradle
index 0e13119b1..c401aa273 100644
--- a/onnx/build.gradle
+++ b/onnx/build.gradle
@@ -41,6 +41,7 @@ kotlin {
         androidMain {
             dependencies {
                 api 'com.microsoft.onnxruntime:onnxruntime-mobile:1.11.0'
+                api 'androidx.camera:camera-core:1.0.0-rc03'
             }
         }
     }
diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt
deleted file mode 100644
index 713afcd34..000000000
--- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt
+++ /dev/null
@@ -1,12 +0,0 @@
-package org.jetbrains.kotlinx.dl.api.inference.onnx
-
-/**
- * Interface represents models which can be used with CameraX API, i.e. support setting of target image rotation.
- */
-public interface CameraXCompatibleModel {
-    /**
-     * Target image rotation.
-     * @see [ImageInfo](https://developer.android.com/reference/androidx/camera/core/ImageInfo)
-     */
-    public var targetRotation: Float
-}
diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt
index 8ae492770..50936d36d 100644
--- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt
+++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt
@@ -1,8 +1,11 @@
 package org.jetbrains.kotlinx.dl.api.inference.onnx;
 
+import android.graphics.Bitmap
+import androidx.camera.core.ImageProxy
 import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
 import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.InputType
 import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelHub
+import org.jetbrains.kotlinx.dl.api.inference.onnx.camerax.CameraXCompatibleModelType
 import org.jetbrains.kotlinx.dl.api.inference.onnx.classification.ImageRecognitionModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModelMetadata
@@ -21,10 +24,8 @@ public object ONNXModels {
         override val modelRelativePath: String,
         override val channelsFirst: Boolean,
         override val inputColorMode: ColorMode = ColorMode.RGB,
-    ) : OnnxModelType<ImageRecognitionModel> {
-        override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
-            return ImageRecognitionModel(modelHub.loadModel(this) as OnnxInferenceModel, this)
-        }
+    ) : CameraXCompatibleModelType<ImageRecognitionModel<*>> {
+        protected open val classLabels: Map<Int, String> = Imagenet.V1k.labels()
 
         /**
          * Image classification model based on EfficientNet-Lite architecture.
@@ -75,13 +76,46 @@ public object ONNXModels {
                 channelsLast = !channelsFirst
             }
 
-            override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
-                return ImageRecognitionModel(
-                    modelHub.loadModel(this),
-                    this,
-                    Imagenet.V1001.labels()
-                )
-            }
+            override val classLabels: Map<Int, String> = Imagenet.V1001.labels()
+        }
+
+        override fun withImageProxyInput(modelHub: ModelHub): ImageRecognitionModel<ImageProxy> {
+            val internalModel = modelHub.loadModel(this) as OnnxInferenceModel
+
+            val (width, height) = if (channelsFirst)
+                Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
+            else
+                Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
+
+            val preprocessing = pipeline<ImageProxy>()
+                .convertToBitmap { }
+                .resize {
+                    outputHeight = height.toInt()
+                    outputWidth = width.toInt()
+                }
+                .toFloatArray { layout = if (channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
+                .call(preprocessor)
+
+            return ImageRecognitionModel(internalModel, preprocessing, classLabels)
+        }
+
+        override fun withBitmapInput(modelHub: ModelHub): ImageRecognitionModel<Bitmap> {
+            val internalModel = modelHub.loadModel(this) as OnnxInferenceModel
+
+            val (width, height) = if (channelsFirst)
+                Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
+            else
+                Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
+
+            val preprocessing = pipeline<Bitmap>()
+                .resize {
+                    outputHeight = height.toInt()
+                    outputWidth = width.toInt()
+                }
+                .toFloatArray { layout = if (channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
+                .call(preprocessor)
+
+            return ImageRecognitionModel(internalModel, preprocessing, classLabels)
+        }
     }
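Usage sketch (not part of the diff) of the two new factory methods from a CameraX analyzer. `EfficientNetLite4` is a stand-in name for one of the `CV` model types in this file, and `ONNXModelHub`/`predictObject` follow KotlinDL's published Android API, so treat the exact names and signatures as assumptions.

```kotlin
val hub = ONNXModelHub(context)  // `context` is an Android Context you provide
val model = ONNXModels.CV.EfficientNetLite4.withImageProxyInput(hub)

val analyzer = ImageAnalysis.Analyzer { frame ->
    val label = model.predictObject(frame)  // preprocessing runs on the ImageProxy itself
    frame.close()                           // always release the frame back to CameraX
}
```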
@@ -90,7 +124,7 @@ public object ONNXModels {
         override val modelRelativePath: String,
         override val channelsFirst: Boolean = true,
         override val inputColorMode: ColorMode = ColorMode.RGB
-    ) : OnnxModelType<T> {
+    ) : CameraXCompatibleModelType<T> {
         /**
          * This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person
          * (edges are available in [org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.edgeKeyPointsPairs]
@@ -112,9 +146,20 @@ public object ONNXModels {
          * TensorFlow Model Hub with the MoveNetLighting model converted to ONNX.
          */
         public object MoveNetSinglePoseLighting :
-            PoseDetection<SinglePoseDetectionModel>("movenet_singlepose_lighting_13") {
-            override fun pretrainedModel(modelHub: ModelHub): SinglePoseDetectionModel {
-                return SinglePoseDetectionModel(modelHub.loadModel(this))
+            PoseDetection<SinglePoseDetectionModel<*>>("movenet_singlepose_lighting_13") {
+
+            override fun withImageProxyInput(modelHub: ModelHub): SinglePoseDetectionModel<ImageProxy> {
+                val internalModel = modelHub.loadModel(this)
+                return SinglePoseDetectionModel(
+                    internalModel, imageProxyPreprocessing(internalModel.inputDimensions)
+                )
+            }
+
+            override fun withBitmapInput(modelHub: ModelHub): SinglePoseDetectionModel<Bitmap> {
+                val internalModel = modelHub.loadModel(this)
+                return SinglePoseDetectionModel(
+                    internalModel, bitmapPreprocessing(internalModel.inputDimensions)
+                )
             }
         }
 
@@ -139,11 +184,37 @@ public object ONNXModels {
          * TensorFlow Model Hub with the MoveNetThunder model converted to ONNX.
          */
         public object MoveNetSinglePoseThunder :
-            PoseDetection<SinglePoseDetectionModel>("movenet_thunder") {
-            override fun pretrainedModel(modelHub: ModelHub): SinglePoseDetectionModel {
-                return SinglePoseDetectionModel(modelHub.loadModel(this))
+            PoseDetection<SinglePoseDetectionModel<*>>("movenet_thunder") {
+
+            override fun withImageProxyInput(modelHub: ModelHub): SinglePoseDetectionModel<ImageProxy> {
+                val internalModel = modelHub.loadModel(this)
+                return SinglePoseDetectionModel(internalModel, imageProxyPreprocessing(internalModel.inputDimensions))
+            }
+
+            override fun withBitmapInput(modelHub: ModelHub): SinglePoseDetectionModel<Bitmap> {
+                val internalModel = modelHub.loadModel(this)
+                return SinglePoseDetectionModel(internalModel, bitmapPreprocessing(internalModel.inputDimensions))
             }
         }
+
+        protected fun bitmapPreprocessing(inputDimensions: LongArray): Operation<Bitmap, Pair<FloatArray, TensorShape>> {
+            return pipeline<Bitmap>()
+                .resize {
+                    outputHeight = inputDimensions[0].toInt()
+                    outputWidth = inputDimensions[1].toInt()
+                }
+                .toFloatArray { layout = TensorLayout.NHWC }
+        }
+
+        protected fun imageProxyPreprocessing(inputDimensions: LongArray): Operation<ImageProxy, Pair<FloatArray, TensorShape>> {
+            return pipeline<ImageProxy>()
+                .convertToBitmap { }
+                .resize {
+                    outputHeight = inputDimensions[0].toInt()
+                    outputWidth = inputDimensions[1].toInt()
+                }
+                .toFloatArray { layout = TensorLayout.NHWC }
+        }
     }
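The same pattern for pose detection (sketch, not part of the diff, reusing the `hub` from the earlier sketch); `detectPose` follows the single-pose base class API, so treat the exact signature as an assumption.

```kotlin
val poseModel = ONNXModels.PoseDetection.MoveNetSinglePoseLighting.withImageProxyInput(hub)
val analyzer = ImageAnalysis.Analyzer { frame ->
    val pose = poseModel.detectPose(frame)  // convertToBitmap, resize and NHWC conversion are built in
    frame.close()
}
```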
@@ -151,7 +222,7 @@ public object ONNXModels {
     /** Object detection models and preprocessing. */
     public sealed class ObjectDetection<T : InferenceModel>(
         override val modelRelativePath: String,
         override val channelsFirst: Boolean = true,
         override val inputColorMode: ColorMode = ColorMode.RGB
-    ) : OnnxModelType<T> {
+    ) : CameraXCompatibleModelType<T> {
         /**
          * This model is a real-time neural network for object detection that detects 90 different classes
          * (labels are available in [org.jetbrains.kotlinx.dl.dataset.Coco.V2017]).
@@ -178,7 +249,7 @@ public object ONNXModels {
          * Detailed description of SSD model and its pre- and postprocessing in onnx/models repository.
          */
         public object SSDMobileNetV1 :
-            ObjectDetection<SSDLikeModel>("ssd_mobilenet_v1") {
+            ObjectDetection<SSDLikeModel<*>>("ssd_mobilenet_v1") {
 
             private val METADATA = SSDLikeModelMetadata(
                 "TFLite_Detection_PostProcess",
@@ -187,8 +258,22 @@ public object ONNXModels {
                 "TFLite_Detection_PostProcess:1",
                 "TFLite_Detection_PostProcess:2",
                 0, 1
             )
 
-            override fun pretrainedModel(modelHub: ModelHub): SSDLikeModel {
-                return SSDLikeModel(modelHub.loadModel(this), METADATA)
+            override fun withImageProxyInput(modelHub: ModelHub): SSDLikeModel<ImageProxy> {
+                val internalModel = modelHub.loadModel(this)
+                return SSDLikeModel(
+                    internalModel,
+                    imageProxyPreprocessing(internalModel.inputDimensions),
+                    METADATA
+                )
+            }
+
+            override fun withBitmapInput(modelHub: ModelHub): SSDLikeModel<Bitmap> {
+                val internalModel = modelHub.loadModel(this)
+                return SSDLikeModel(
+                    internalModel,
+                    bitmapPreprocessing(internalModel.inputDimensions),
+                    METADATA
+                )
             }
         }
 
@@ -216,7 +301,7 @@ public object ONNXModels {
          * Tutorial which shows how to convert the EfficientDet models to ONNX using tf2onnx.
          */
         public object EfficientDetLite0 :
-            ObjectDetection<SSDLikeModel>("efficientdet_lite0") {
+            ObjectDetection<SSDLikeModel<*>>("efficientdet_lite0") {
 
             private val METADATA = SSDLikeModelMetadata(
                 "StatefulPartitionedCall:3",
@@ -225,9 +310,42 @@ public object ONNXModels {
                 "StatefulPartitionedCall:2",
                 "StatefulPartitionedCall:1",
                 0, 1
             )
 
-            override fun pretrainedModel(modelHub: ModelHub): SSDLikeModel {
-                return SSDLikeModel(modelHub.loadModel(this), METADATA)
+            override fun withImageProxyInput(modelHub: ModelHub): SSDLikeModel<ImageProxy> {
+                val internalModel = modelHub.loadModel(this)
+                return SSDLikeModel(
+                    internalModel,
+                    imageProxyPreprocessing(internalModel.inputDimensions),
+                    METADATA
+                )
+            }
+
+            override fun withBitmapInput(modelHub: ModelHub): SSDLikeModel<Bitmap> {
+                val internalModel = modelHub.loadModel(this)
+                return SSDLikeModel(
+                    internalModel,
+                    bitmapPreprocessing(internalModel.inputDimensions),
+                    METADATA
+                )
             }
         }
+
+        protected fun bitmapPreprocessing(inputDimensions: LongArray): Operation<Bitmap, Pair<FloatArray, TensorShape>> {
+            return pipeline<Bitmap>()
+                .resize {
+                    outputHeight = inputDimensions[0].toInt()
+                    outputWidth = inputDimensions[1].toInt()
+                }
+                .toFloatArray { layout = TensorLayout.NHWC }
+        }
+
+        protected fun imageProxyPreprocessing(inputDimensions: LongArray): Operation<ImageProxy, Pair<FloatArray, TensorShape>> {
+            return pipeline<ImageProxy>()
+                .convertToBitmap { }
+                .resize {
+                    outputHeight = inputDimensions[0].toInt()
+                    outputWidth = inputDimensions[1].toInt()
+                }
+                .toFloatArray { layout = TensorLayout.NHWC }
+        }
     }
 }
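And for object detection (sketch, not part of the diff, again reusing `hub`); `detectObjects` follows the detection base class API, exact signature assumed.

```kotlin
val detector = ONNXModels.ObjectDetection.EfficientDetLite0.withImageProxyInput(hub)
val analyzer = ImageAnalysis.Analyzer { frame ->
    val objects = detector.detectObjects(frame)
    frame.close()
}
```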
+ */ + public fun withImageProxyInput(modelHub: ModelHub) : U + + /** + * Loading model using [ModelHub] api and returns [InferenceModel] instance which accepts [Bitmap] input. + * + * @param modelHub [ModelHub] instance. + */ + public fun withBitmapInput(modelHub: ModelHub) : U + + /** + * The method is the same as [withBitmapInput] and should not be used. + * Please use [withBitmapInput] as it is more clear. + */ + override fun pretrainedModel(modelHub: ModelHub): U = withBitmapInput(modelHub) +} diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt index 2dfcbd4bb..a4a3ff2f2 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt @@ -1,45 +1,21 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.classification -import android.graphics.Bitmap -import org.jetbrains.kotlinx.dl.api.inference.InferenceModel import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModelBase -import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType -import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider import org.jetbrains.kotlinx.dl.dataset.Imagenet -import org.jetbrains.kotlinx.dl.dataset.imagenetLabels import org.jetbrains.kotlinx.dl.dataset.preprocessing.* import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape /** * The light-weight API for Classification task with one of the Model Hub models. 
diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt
index 2dfcbd4bb..a4a3ff2f2 100644
--- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt
+++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt
@@ -1,45 +1,21 @@
 package org.jetbrains.kotlinx.dl.api.inference.onnx.classification
 
-import android.graphics.Bitmap
-import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
 import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModelBase
-import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType
-import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible
 import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
 import org.jetbrains.kotlinx.dl.dataset.Imagenet
-import org.jetbrains.kotlinx.dl.dataset.imagenetLabels
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
 import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
 
 /**
  * The light-weight API for the classification task with one of the Model Hub models.
  */
-public open class ImageRecognitionModel(
+public open class ImageRecognitionModel<I>(
     internalModel: OnnxInferenceModel,
-    private val modelType: ModelType<*, *>,
+    override val preprocessing: Operation<I, Pair<FloatArray, TensorShape>>,
     override val classLabels: Map<Int, String> = Imagenet.V1k.labels()
-) : ImageRecognitionModelBase<Bitmap>(internalModel), ExecutionProviderCompatible, CameraXCompatibleModel {
-    override var targetRotation: Float = 0f
-
-    override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
-        get() {
-            val (width, height) = if (modelType.channelsFirst)
-                Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
-            else
-                Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
-
-            return pipeline<Bitmap>()
-                .resize {
-                    outputHeight = height.toInt()
-                    outputWidth = width.toInt()
-                }
-                .rotate { degrees = targetRotation }
-                .toFloatArray { layout = if (modelType.channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
-                .call(modelType.preprocessor)
-        }
-
+) : ImageRecognitionModelBase<I>(internalModel), ExecutionProviderCompatible {
     override fun initializeWith(vararg executionProviders: ExecutionProvider) {
         (internalModel as OnnxInferenceModel).initializeWith(*executionProviders)
     }
diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt
index 50581e920..5eba7fc9a 100644
--- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt
+++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt
@@ -1,8 +1,6 @@
 package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection
 
-import android.graphics.Bitmap
 import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
-import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
 import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
 import org.jetbrains.kotlinx.dl.dataset.Coco
@@ -19,22 +17,15 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
  *
  * @since 0.5
  */
-public class SSDLikeModel(override val internalModel: OnnxInferenceModel, metadata: SSDLikeModelMetadata) :
-    SSDLikeModelBase<Bitmap>(metadata), CameraXCompatibleModel, InferenceModel by internalModel {
+public class SSDLikeModel<I>(
+    override val internalModel: OnnxInferenceModel,
+    override val preprocessing: Operation<I, Pair<FloatArray, TensorShape>>,
+    metadata: SSDLikeModelMetadata
+) : SSDLikeModelBase<I>(metadata), InferenceModel by internalModel {
 
     override val classLabels: Map<Int, String> = Coco.V2017.labels(zeroIndexed = true)
 
-    override var targetRotation: Float = 0f
-
-    override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
-        get() = pipeline<Bitmap>()
-            .resize {
-                outputHeight = internalModel.inputDimensions[0].toInt()
-                outputWidth = internalModel.inputDimensions[1].toInt()
-            }
-            .rotate { degrees = targetRotation }
-            .toFloatArray { layout = TensorLayout.NHWC }
-
     override fun close() {
         internalModel.close()
     }
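With preprocessing now injected through the constructor, a recognizer over any input type can be assembled by hand (sketch, not part of the diff; `onnxModel` stands in for a loaded `OnnxInferenceModel`):

```kotlin
val recognizer = ImageRecognitionModel(
    onnxModel,
    pipeline<Bitmap>()
        .resize { outputWidth = 224; outputHeight = 224 }  // arbitrary example size
        .toFloatArray { layout = TensorLayout.NHWC },
    Imagenet.V1k.labels()
)
```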
diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt
index 05982e668..af9f5bc50 100644
--- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt
+++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt
@@ -5,9 +5,7 @@
 package org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection
 
-import android.graphics.Bitmap
 import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
-import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
 import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
@@ -23,24 +21,18 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
  *
  * @param internalModel model used to make predictions
  */
-public class SinglePoseDetectionModel(override val internalModel: OnnxInferenceModel) :
-    SinglePoseDetectionModelBase<Bitmap>(), InferenceModel by internalModel, CameraXCompatibleModel {
-    override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
-        get() = pipeline<Bitmap>()
-            .resize {
-                outputHeight = internalModel.inputDimensions[0].toInt()
-                outputWidth = internalModel.inputDimensions[1].toInt()
-            }
-            .rotate { degrees = targetRotation }
-            .toFloatArray { layout = TensorLayout.NHWC }
-
-    override var targetRotation: Float = 0f
-
+public class SinglePoseDetectionModel<I>(
+    override val internalModel: OnnxInferenceModel,
+    override val preprocessing: Operation<I, Pair<FloatArray, TensorShape>>
+) : SinglePoseDetectionModelBase<I>(), InferenceModel by internalModel {
     /**
      * Constructs the pose detection model from model bytes.
     * @param [modelBytes]
     */
-    public constructor (modelBytes: ByteArray) : this(OnnxInferenceModel(modelBytes)) {
+    public constructor (
+        modelBytes: ByteArray,
+        preprocessing: Operation<I, Pair<FloatArray, TensorShape>>
+    ) : this(OnnxInferenceModel(modelBytes), preprocessing) {
         internalModel.initializeWith(ExecutionProvider.CPU())
     }
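Finally, the new secondary constructor pairs raw model bytes with an explicit preprocessing pipeline (sketch, not part of the diff; `bytes` stands in for an ONNX model you provide):

```kotlin
val poseModel = SinglePoseDetectionModel(
    modelBytes = bytes,
    preprocessing = pipeline<ImageProxy>()
        .convertToBitmap { }
        .resize { outputWidth = 256; outputHeight = 256 }  // MoveNet Thunder's nominal input size
        .toFloatArray { layout = TensorLayout.NHWC }
)
```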