Skip to content

Commit

Permalink
Add support for SSD model on android (#446)
Browse files Browse the repository at this point in the history
* Add SSD model for android (#440)

* Add onnxruntime-mobile dependency for androidMain

* Reduce code duplication for SSD models code on JVM and Android platforms

* Add support for .inferUsing API for OnnxHighLevelModel (#434)

* Add support for zero indexed COCO labels

* Fix ssdLightAPITest (#440)

* Turn Coco from class to enum

* Fix version of onnxruntime dependency for Android (#440)

* Cleanup (#440)

* Add docs for ExecutionProviderCompatible interface

* Make ExecutionProviderCompatible extends AutoCloseable
  and simplify inferUsing functions

* Remove unnecessary implementation of InferenceModel interface
  from SSD model on Android
  • Loading branch information
ermolenkodev authored Sep 9, 2022
1 parent 55c38d8 commit 9ff47d7
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,28 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.dataset.handler
package org.jetbrains.kotlinx.dl.dataset


/**
 * Versions of the COCO dataset used for object detection labels.
 */
public enum class Coco {
    V2014,
    V2017;

    /**
     * Returns the COCO category labels for this dataset version.
     *
     * @param [zeroIndexed] if `true`, every label id is shifted down by one so the mapping starts at 0
     *                      (useful for models whose class outputs are zero-based).
     */
    public fun labels(zeroIndexed: Boolean = false): Map<Int, String> {
        // Pick the version-specific label map once, then apply the index shift uniformly.
        val labels = when (this) {
            V2014 -> cocoCategories2014
            V2017 -> cocoCategories2017
        }
        return if (zeroIndexed) labels.mapKeys { (id, _) -> id - 1 } else labels
    }
}

/**
* 80 object categories in COCO dataset.
Expand All @@ -14,7 +35,7 @@ package org.jetbrains.kotlinx.dl.dataset.handler
* @see <a href="https://cocodataset.org/#home">
* COCO dataset</a>
*/
public val cocoCategoriesForSSD: Map<Int, String> = mapOf(
public val cocoCategories2014: Map<Int, String> = mapOf(
1 to "person",
2 to "bicycle",
3 to "car",
Expand Down Expand Up @@ -104,7 +125,7 @@ public val cocoCategoriesForSSD: Map<Int, String> = mapOf(
* @see <a href="https://cocodataset.org/#home">
* COCO dataset</a>
*/
public val cocoCategories: Map<Int, String> = mapOf(
public val cocoCategories2017: Map<Int, String> = mapOf(
1 to "person",
2 to "bicycle",
3 to "car",
Expand Down Expand Up @@ -186,3 +207,9 @@ public val cocoCategories: Map<Int, String> = mapOf(
89 to "hair drier",
90 to "toothbrush"
)


/**
 * Versions of the COCO dataset.
 *
 * NOTE(review): this appears to duplicate the `Coco` enum declared in the same package —
 * confirm whether both types are intentional or one should be removed.
 */
public enum class CocoVersion {
    V2014,
    V2017
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fun runImageRecognitionPrediction(
}

return if (executionProviders.isNotEmpty()) {
model.inferAndCloseUsing(executionProviders) { inference(it) }
model.inferAndCloseUsing(*executionProviders.toTypedArray()) { inference(it) }
} else {
model.use { inference(it) }
}
Expand Down
3 changes: 3 additions & 0 deletions onnx/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ kotlin {
}
}
androidMain {
dependencies {
api 'com.microsoft.onnxruntime:onnxruntime-mobile:1.11.0'
}
}
}
explicitApiWarning()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection

import android.graphics.Bitmap
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
import org.jetbrains.kotlinx.dl.dataset.Coco
import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape


// Output tensor names and box-coordinate layout of the SSD-MobileNet model:
// the three names are the boxes, classes (":1") and scores (":2") outputs of the
// TFLite detection post-processing op; boxes store coordinates as [yMin, xMin, yMax, xMax],
// hence yMinIdx = 0 and xMinIdx = 1 (the corresponding max values sit two slots further).
private val SSD_MOBILENET_METADATA = SSDModelMetadata(
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    0, 1
)


/**
 * Object detection model based on the SSD-MobileNet architecture, running on Android
 * on top of an [OnnxInferenceModel].
 *
 * Input [Bitmap]s are resized to the model's expected input size, rotated by the
 * configured angle and converted to an NHWC float array before inference.
 */
public class SSDMobileNetObjectDetectionModel(override val internalModel: OnnxInferenceModel) :
    SSDObjectDetectionModelBase<Bitmap>(SSD_MOBILENET_METADATA) {

    // COCO 2017 labels shifted to start from 0, matching the model's zero-based class indices.
    override val classLabels: Map<Int, String> = Coco.V2017.labels(zeroIndexed = true)

    // Rotation (in degrees) currently baked into the preprocessing pipeline.
    private var targetRotation = 0f

    override lateinit var preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
        private set

    /**
     * Constructs the model from the serialized model file contents and initializes
     * it with the CPU execution provider.
     */
    public constructor (modelBytes: ByteArray) : this(OnnxInferenceModel(modelBytes)) {
        internalModel.initializeWith(CPU())
        preprocessing = buildPreprocessingPipeline()
    }

    /**
     * Updates the rotation applied to input images. The preprocessing pipeline is
     * rebuilt only when the value actually changes.
     */
    public fun setTargetRotation(targetRotation: Float) {
        if (targetRotation == this.targetRotation) return

        this.targetRotation = targetRotation
        preprocessing = buildPreprocessingPipeline()
    }

    // Resize to the model input size, rotate, then flatten to NHWC floats.
    private fun buildPreprocessingPipeline(): Operation<Bitmap, Pair<FloatArray, TensorShape>> {
        val inputDims = internalModel.inputDimensions
        return pipeline<Bitmap>()
            .resize {
                outputHeight = inputDims[0].toInt()
                outputWidth = inputDims[1].toInt()
            }
            .rotate { degrees = targetRotation }
            .toFloatArray { layout = TensorLayout.NHWC }
    }

    override fun close(): Unit = internalModel.close()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx

import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU

/**
 * Interface for different kinds of ONNX models which support execution on various execution providers.
 */
public interface ExecutionProviderCompatible : AutoCloseable {
    /**
     * Initialize the model with the specified execution providers.
     *
     * By default a single CPU execution provider (with memory arena allocation enabled) is used.
     */
    public fun initializeWith(vararg executionProviders: ExecutionProvider = arrayOf(CPU(true)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dl.api.inference.onnx

import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

Expand All @@ -14,7 +15,7 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
* @param [I] input type
* @param [R] output type
*/
public interface OnnxHighLevelModel<I, R> {
public interface OnnxHighLevelModel<I, R> : ExecutionProviderCompatible {
/**
* Model used to make predictions.
*/
Expand All @@ -38,4 +39,8 @@ public interface OnnxHighLevelModel<I, R> {
val output = internalModel.predictRaw(preprocessedInput.first)
return convert(output)
}
}

override fun initializeWith(vararg executionProviders: ExecutionProvider) {
internalModel.initializeWith(*executionProviders)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ private const val RESHAPE_MISSED_MESSAGE = "Model input shape is not defined. Ca
*
* @since 0.3
*/
public open class OnnxInferenceModel private constructor(private val modelSource: ModelSource) : InferenceModel {
public open class OnnxInferenceModel private constructor(private val modelSource: ModelSource) : InferenceModel,
ExecutionProviderCompatible {
/**
* The host object for the onnx-runtime system. Can create [session] which encapsulate
* specific models.
Expand Down Expand Up @@ -104,7 +105,7 @@ public open class OnnxInferenceModel private constructor(private val modelSource
*
* @param executionProviders list of execution providers to use.
*/
public fun initializeWith(vararg executionProviders: ExecutionProvider = arrayOf(CPU(true))) {
public override fun initializeWith(vararg executionProviders: ExecutionProvider) {
val uniqueProviders = collectProviders(executionProviders)

if (::executionProvidersInUse.isInitialized && uniqueProviders == executionProvidersInUse) {
Expand Down Expand Up @@ -169,11 +170,13 @@ public open class OnnxInferenceModel private constructor(private val modelSource
0 -> {
uniqueProviders.add(CPU(true))
}

1 -> {
val cpu = uniqueProviders.first { it is CPU }
uniqueProviders.remove(cpu)
uniqueProviders.add(cpu)
}

else -> throw IllegalArgumentException("Unable to use CPU(useArena = true) and CPU(useArena = false) at the same time!")
}
return uniqueProviders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,18 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionP
* Convenience extension functions for inference of ONNX models using different execution providers.
*/

public inline fun <R> OnnxInferenceModel.inferAndCloseUsing(
public inline fun <M : ExecutionProviderCompatible, R> M.inferAndCloseUsing(
vararg providers: ExecutionProvider,
block: (OnnxInferenceModel) -> R
block: (M) -> R
): R {
this.initializeWith(*providers)
return this.use(block)
}

public inline fun <R> OnnxInferenceModel.inferAndCloseUsing(
providers: List<ExecutionProvider>,
block: (OnnxInferenceModel) -> R
): R {
this.initializeWith(*providers.toTypedArray())
return this.use(block)
}

public inline fun <R> OnnxInferenceModel.inferUsing(
public inline fun <M : ExecutionProviderCompatible, R> M.inferUsing(
vararg providers: ExecutionProvider,
block: (OnnxInferenceModel) -> R
block: (M) -> R
): R {
this.initializeWith(*providers)
return this.run(block)
}

public inline fun <R> OnnxInferenceModel.inferUsing(
providers: List<ExecutionProvider>,
block: (OnnxInferenceModel) -> R
): R {
this.initializeWith(*providers.toTypedArray())
return this.run(block)
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,71 +65,37 @@ public abstract class EfficientDetObjectDetectionModelBase<I> : ObjectDetectionM
}
}

/**
* Base class for object detection model based on SSD-MobilNet architecture.
*/
public abstract class SSDMobileNetObjectDetectionModelBase<I> : ObjectDetectionModelBase<I>() {

override fun convert(output: Map<String, Any>): List<DetectedObject> {
val foundObjects = mutableListOf<DetectedObject>()
val boxes = (output[OUTPUT_BOXES] as Array<Array<FloatArray>>)[0]
val classIndices = (output[OUTPUT_CLASSES] as Array<FloatArray>)[0]
val probabilities = (output[OUTPUT_SCORES] as Array<FloatArray>)[0]
val numberOfFoundObjects = (output[OUTPUT_NUMBER_OF_DETECTIONS] as FloatArray)[0].toInt()

for (i in 0 until numberOfFoundObjects) {
val detectedObject = DetectedObject(
classLabel = classLabels[classIndices[i].toInt()]!!,
probability = probabilities[i],
// top, left, bottom, right
yMin = boxes[i][0],
xMin = boxes[i][1],
yMax = boxes[i][2],
xMax = boxes[i][3]
)
foundObjects.add(detectedObject)
}
return foundObjects
}

private companion object {
private const val OUTPUT_BOXES = "detection_boxes:0"
private const val OUTPUT_CLASSES = "detection_classes:0"
private const val OUTPUT_SCORES = "detection_scores:0"
private const val OUTPUT_NUMBER_OF_DETECTIONS = "num_detections:0"
}
}

/**
* Base class for object detection model based on SSD architecture.
*/
public abstract class SSDObjectDetectionModelBase<I> : ObjectDetectionModelBase<I>() {

public abstract class SSDObjectDetectionModelBase<I>(protected val metadata: SSDModelMetadata) : ObjectDetectionModelBase<I>() {
override fun convert(output: Map<String, Any>): List<DetectedObject> {
val boxes = (output[OUTPUT_BOXES] as Array<Array<FloatArray>>)[0]
val classIndices = (output[OUTPUT_LABELS] as Array<LongArray>)[0]
val probabilities = (output[OUTPUT_SCORES] as Array<FloatArray>)[0]
val boxes = (output[metadata.outputBoxesName] as Array<Array<FloatArray>>)[0]
val classIndices = (output[metadata.outputClassesName] as Array<FloatArray>)[0]
val probabilities = (output[metadata.outputScoresName] as Array<FloatArray>)[0]
val numberOfFoundObjects = boxes.size

val foundObjects = mutableListOf<DetectedObject>()
for (i in 0 until numberOfFoundObjects) {
val detectedObject = DetectedObject(
classLabel = classLabels[classIndices[i].toInt()]!!,
classLabel = classLabels[classIndices[i].toInt()] ?: "Unknown",
probability = probabilities[i],
// left, bot, right, top
xMin = boxes[i][0],
yMin = boxes[i][1],
xMax = boxes[i][2],
yMax = boxes[i][3]
xMin = boxes[i][metadata.xMinIdx],
yMin = boxes[i][metadata.yMinIdx],
xMax = boxes[i][metadata.xMinIdx + 2],
yMax = boxes[i][metadata.yMinIdx + 2]
)
foundObjects.add(detectedObject)
}
return foundObjects
}
}

private companion object {
private const val OUTPUT_BOXES = "bboxes"
private const val OUTPUT_LABELS = "labels"
private const val OUTPUT_SCORES = "scores"
}
}
/**
 * Metadata describing the output tensors of an SSD-like object detection model.
 *
 * @property [outputBoxesName] name of the output tensor containing bounding boxes
 * @property [outputClassesName] name of the output tensor containing predicted class indices
 * @property [outputScoresName] name of the output tensor containing detection scores
 * @property [yMinIdx] index of the yMin coordinate inside a single box array; yMax is read from [yMinIdx] + 2
 * @property [xMinIdx] index of the xMin coordinate inside a single box array; xMax is read from [xMinIdx] + 2
 */
public data class SSDModelMetadata(
    public val outputBoxesName: String,
    public val outputClassesName: String,
    public val outputScoresName: String,
    public val yMinIdx: Int,
    public val xMinIdx: Int
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.dataset.handler.cocoCategories
import org.jetbrains.kotlinx.dl.dataset.Coco
import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
Expand Down Expand Up @@ -45,7 +46,7 @@ public class EfficientDetObjectDetectionModel(override val internalModel: OnnxIn
// model is quite sensitive for this
.convert { colorMode = ColorMode.RGB }
.toFloatArray { }
override val classLabels: Map<Int, String> = cocoCategories
override val classLabels: Map<Int, String> = Coco.V2017.labels(zeroIndexed = false)

/**
* Constructs the object detection model from a given path.
Expand Down
Loading

0 comments on commit 9ff47d7

Please sign in to comment.