From 10be87ed082cbd8101930d1a6672f66e63f297fe Mon Sep 17 00:00:00 2001 From: Nikita Ermolenko Date: Thu, 29 Sep 2022 18:25:48 +0300 Subject: [PATCH] Refactoring of extension functions with ImageProxy input (#454) * Introduce doWithRotation extension function to reduce code duplication * Move extension functions to the base classes (e.g. SinglePoseDetectionModelBase instead SinglePoseDetectionModel) Co-authored-by: Julia Beliaeva --- .../inference/onnx/CameraXCompatibleModel.kt | 13 +++++++++ .../classification/ImageRecognitionModel.kt | 28 +++++++++++-------- .../onnx/objectdetection/SSDLikeModel.kt | 14 ++++++---- .../posedetection/SinglePoseDetectionModel.kt | 13 +++++---- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt index 6b16a77b0..18962b209 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/CameraXCompatibleModel.kt @@ -10,3 +10,16 @@ public interface CameraXCompatibleModel { */ public var targetRotation: Int } + +/** + * Convenience function to execute arbitrary code with a preliminary updated target rotation. + * After the code is executed, the target rotation is restored to its original value. + * + * @param rotation target rotation to be set for the duration of the code execution + * @param function arbitrary code to be executed + */ +public fun CameraXCompatibleModel.doWithRotation(rotation: Int, function: () -> R): R { + val currentRotation = targetRotation + targetRotation = rotation + return function().apply { targetRotation = currentRotation } +} diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt index 7ebb9cb9e..f904571d5 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/classification/ImageRecognitionModel.kt @@ -6,7 +6,9 @@ import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionM import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider +import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.detectPose import org.jetbrains.kotlinx.dl.dataset.Imagenet import org.jetbrains.kotlinx.dl.dataset.preprocessing.* import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap @@ -54,11 +56,13 @@ public open class ImageRecognitionModel( * * @return The label of the recognized object with the highest probability. */ -public fun ImageRecognitionModel.predictObject(imageProxy: ImageProxy): String { - val currentRotation = targetRotation - targetRotation = imageProxy.imageInfo.rotationDegrees - return predictObject(imageProxy.toBitmap()).also { targetRotation = currentRotation } -} +public fun ImageRecognitionModelBase.predictObject(imageProxy: ImageProxy): String = + when (this) { + is CameraXCompatibleModel -> { + doWithRotation(imageProxy.imageInfo.rotationDegrees) { predictObject(imageProxy.toBitmap()) } + } + else -> predictObject(imageProxy.toBitmap(applyRotation = true)) + } /** * Predicts [topK] objects for the given [imageProxy]. @@ -70,11 +74,13 @@ public fun ImageRecognitionModel.predictObject(imageProxy: ImageProxy): String { * * @return The list of pairs sorted from the most probable to the lowest probable. */ -public fun ImageRecognitionModel.predictTopKObjects( +public fun ImageRecognitionModelBase.predictTopKObjects( imageProxy: ImageProxy, topK: Int = 5 -): List> { - val currentRotation = targetRotation - targetRotation = imageProxy.imageInfo.rotationDegrees - return predictTopKObjects(imageProxy.toBitmap(), topK).also { targetRotation = currentRotation } -} +): List> = + when (this) { + is CameraXCompatibleModel -> { + doWithRotation(imageProxy.imageInfo.rotationDegrees) { predictTopKObjects(imageProxy.toBitmap(), topK) } + } + else -> predictTopKObjects(imageProxy.toBitmap(applyRotation = true), topK) + } diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt index a5efba075..b65b24379 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDLikeModel.kt @@ -7,6 +7,8 @@ import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation +import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.detectPose import org.jetbrains.kotlinx.dl.dataset.Coco import org.jetbrains.kotlinx.dl.dataset.preprocessing.* import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap @@ -52,8 +54,10 @@ public class SSDLikeModel(override val internalModel: OnnxInferenceModel, metada * @param [topK] The number of the detected objects with the highest score to be returned. * @return List of [DetectedObject] sorted by score. */ -public fun SSDLikeModel.detectObjects(imageProxy: ImageProxy, topK: Int = 3): List { - val currentRotation = targetRotation - targetRotation = imageProxy.imageInfo.rotationDegrees - return detectObjects(imageProxy.toBitmap(), topK).also { targetRotation = currentRotation } -} +public fun ObjectDetectionModelBase.detectObjects(imageProxy: ImageProxy, topK: Int = 3): List = + when (this) { + is CameraXCompatibleModel -> { + doWithRotation(imageProxy.imageInfo.rotationDegrees) { detectObjects(imageProxy.toBitmap(), topK) } + } + else -> detectObjects(imageProxy.toBitmap(applyRotation = true), topK) + } diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt index 39d56d14c..ebb546496 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModel.kt @@ -14,6 +14,7 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionP import org.jetbrains.kotlinx.dl.dataset.preprocessing.* import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap @@ -59,8 +60,10 @@ public class SinglePoseDetectionModel(override val internalModel: OnnxInferenceM * * @param [imageProxy] input image. */ -public fun SinglePoseDetectionModel.detectPose(imageProxy: ImageProxy): DetectedPose { - val currentRotation = targetRotation - targetRotation = imageProxy.imageInfo.rotationDegrees - return detectPose(imageProxy.toBitmap()).also { targetRotation = currentRotation } -} +public fun SinglePoseDetectionModelBase.detectPose(imageProxy: ImageProxy): DetectedPose = + when (this) { + is CameraXCompatibleModel -> { + doWithRotation(imageProxy.imageInfo.rotationDegrees) { detectPose(imageProxy.toBitmap()) } + } + else -> detectPose(imageProxy.toBitmap(applyRotation = true)) + }