Skip to content

Commit

Permalink
More convenient API for inference of ModelHub models with ImageProxy …
Browse files Browse the repository at this point in the history
…input (Kotlin#454)

* Extend the ModelHub's models with an API that accepts ImageProxy as an input.

* Now the user is not required to set the targetRotation property by hand.

The rotation operation is quite expensive, so it's beneficial to perform the rotation after resizing, although this makes the API less intuitive.
  • Loading branch information
ermolenkodev committed Sep 29, 2022
1 parent edfe8c4 commit 0d80a31
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public class Rotate(
public var degrees: Float = 0.0f,
) : Operation<Bitmap, Bitmap> {
override fun apply(input: Bitmap): Bitmap {
    // Fast path: a zero rotation is a no-op, so avoid the expensive bitmap copy entirely.
    if (degrees == 0f) return input

    // Build the rotation transform and render the source bitmap through it.
    val rotation = Matrix().also { it.postRotate(degrees) }
    return Bitmap.createBitmap(input, 0, 0, input.width, input.height, rotation, true)
}
Expand Down
3 changes: 2 additions & 1 deletion onnx/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ kotlin {
androidMain {
dependencies {
api 'com.microsoft.onnxruntime:onnxruntime-mobile:1.12.1'
api 'androidx.camera:camera-core:1.0.0-rc03'
}
}
}
Expand All @@ -58,4 +59,4 @@ android {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ public interface CameraXCompatibleModel {
* Target image rotation.
* @see [ImageInfo](https://developer.android.com/reference/androidx/camera/core/ImageInfo)
*/
public var targetRotation: Float
public var targetRotation: Int
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.classification

import android.graphics.Bitmap
import androidx.camera.core.ImageProxy
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModelBase
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/**
Expand All @@ -19,7 +21,7 @@ public open class ImageRecognitionModel(
private val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> = Identity(),
override val classLabels: Map<Int, String> = Imagenet.V1k.labels()
) : ImageRecognitionModelBase<Bitmap>(internalModel), ExecutionProviderCompatible, CameraXCompatibleModel {
override var targetRotation: Float = 0f
override var targetRotation: Int = 0

override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
get() {
Expand All @@ -33,7 +35,7 @@ public open class ImageRecognitionModel(
outputHeight = height.toInt()
outputWidth = width.toInt()
}
.rotate { degrees = targetRotation }
.rotate { degrees = targetRotation.toFloat() }
.toFloatArray { layout = if (channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
.call(preprocessor)
}
Expand All @@ -42,3 +44,37 @@ public open class ImageRecognitionModel(
(internalModel as OnnxInferenceModel).initializeWith(*executionProviders)
}
}

/**
 * Predicts the object on the given [imageProxy].
 *
 * The model's internal preprocessing is temporarily updated to rotate the image to match
 * the target orientation reported by [ImageProxy.getImageInfo]. The original [targetRotation]
 * is restored afterwards, even if the prediction throws.
 *
 * @param [imageProxy] Input image.
 *
 * @return The label of the recognized object with the highest probability.
 */
public fun ImageRecognitionModel.predictObject(imageProxy: ImageProxy): String {
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    try {
        return predictObject(imageProxy.toBitmap())
    } finally {
        // Restore in a finally block so a failed prediction cannot leave the model
        // with a stale rotation applied to all subsequent inputs.
        targetRotation = previousRotation
    }
}

/**
 * Predicts [topK] objects for the given [imageProxy].
 *
 * The model's internal preprocessing is temporarily updated to rotate the image to match
 * the target orientation reported by [ImageProxy.getImageInfo]. The original [targetRotation]
 * is restored afterwards, even if the prediction throws.
 *
 * @param [imageProxy] Input image.
 * @param [topK] Number of top-ranked predictions to return.
 *
 * @return The list of pairs <label, probability> sorted from the most probable to the least probable.
 */
public fun ImageRecognitionModel.predictTopKObjects(
    imageProxy: ImageProxy,
    topK: Int = 5
): List<Pair<String, Float>> {
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    try {
        return predictTopKObjects(imageProxy.toBitmap(), topK)
    } finally {
        // Restore in a finally block so a failed prediction cannot leave the model
        // with a stale rotation applied to all subsequent inputs.
        targetRotation = previousRotation
    }
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection

import android.graphics.Bitmap
import androidx.camera.core.ImageProxy
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.dataset.Coco
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/**
Expand All @@ -24,18 +27,33 @@ public class SSDLikeModel(override val internalModel: OnnxInferenceModel, metada

override val classLabels: Map<Int, String> = Coco.V2017.labels(zeroIndexed = true)

override var targetRotation: Float = 0f
override var targetRotation: Int = 0

override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
get() = pipeline<Bitmap>()
.resize {
outputHeight = internalModel.inputDimensions[0].toInt()
outputWidth = internalModel.inputDimensions[1].toInt()
}
.rotate { degrees = targetRotation }
.rotate { degrees = targetRotation.toFloat() }
.toFloatArray { layout = TensorLayout.NHWC }

override fun close() {
internalModel.close()
}
}

/**
 * Returns the detected objects for the given [imageProxy], sorted by score.
 *
 * The model's internal preprocessing is temporarily updated to rotate the image to match
 * the target orientation reported by [ImageProxy.getImageInfo]. The original [targetRotation]
 * is restored afterwards, even if the detection throws.
 *
 * @param [imageProxy] Input image.
 * @param [topK] The number of the detected objects with the highest score to be returned.
 * @return List of [DetectedObject] sorted by score.
 */
public fun SSDLikeModel.detectObjects(imageProxy: ImageProxy, topK: Int = 3): List<DetectedObject> {
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    try {
        return detectObjects(imageProxy.toBitmap(), topK)
    } finally {
        // Restore in a finally block so a failed detection cannot leave the model
        // with a stale rotation applied to all subsequent inputs.
        targetRotation = previousRotation
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection

import android.graphics.Bitmap
import androidx.camera.core.ImageProxy
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap


/**
Expand All @@ -31,10 +34,10 @@ public class SinglePoseDetectionModel(override val internalModel: OnnxInferenceM
outputHeight = internalModel.inputDimensions[0].toInt()
outputWidth = internalModel.inputDimensions[1].toInt()
}
.rotate { degrees = targetRotation }
.rotate { degrees = targetRotation.toFloat() }
.toFloatArray { layout = TensorLayout.NHWC }

override var targetRotation: Float = 0f
override var targetRotation: Int = 0

/**
* Constructs the pose detection model from a model bytes.
Expand All @@ -48,3 +51,16 @@ public class SinglePoseDetectionModel(override val internalModel: OnnxInferenceM
internalModel.close()
}
}

/**
 * Detects a pose for the given [imageProxy].
 *
 * The model's internal preprocessing is temporarily updated to rotate the image to match
 * the target orientation reported by [ImageProxy.getImageInfo]. The original [targetRotation]
 * is restored afterwards, even if the detection throws.
 *
 * @param [imageProxy] Input image.
 */
public fun SinglePoseDetectionModel.detectPose(imageProxy: ImageProxy): DetectedPose {
    val previousRotation = targetRotation
    targetRotation = imageProxy.imageInfo.rotationDegrees
    try {
        return detectPose(imageProxy.toBitmap())
    } finally {
        // Restore in a finally block so a failed detection cannot leave the model
        // with a stale rotation applied to all subsequent inputs.
        targetRotation = previousRotation
    }
}

0 comments on commit 0d80a31

Please sign in to comment.