Skip to content

Commit

Permalink
Add SSD model for Android (Kotlin#440)
Browse files Browse the repository at this point in the history
* Add onnxruntime-mobile dependency for androidMain

* Reduce code duplication for SSD model code on JVM and Android platforms

* Add support for .inferUsing API for OnnxHighLevelModel (Kotlin#434)

* Add support for zero indexed COCO labels
  • Loading branch information
ermolenkodev committed Sep 7, 2022
1 parent 6ad3ebc commit b21b786
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,23 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package org.jetbrains.kotlinx.dl.dataset.handler
package org.jetbrains.kotlinx.dl.dataset

/**
 * Exposes the category labels of a given [CocoVersion] of the COCO dataset.
 *
 * @property version the dataset release whose categories are exposed.
 * @param zeroIndexed when `true`, every category id is shifted down by one so the
 * label indices start at 0 instead of 1.
 */
public class Coco(public val version: CocoVersion, zeroIndexed: Boolean = false) {
    /** Mapping from category id to human-readable category name. */
    public val labels: Map<Int, String> = run {
        val categories = when (version) {
            CocoVersion.V2014 -> cocoCategories2014
            CocoVersion.V2017 -> cocoCategories2017
        }
        if (zeroIndexed) categories.mapKeys { (id, _) -> id - 1 } else categories
    }
}


/**
* 80 object categories in COCO dataset.
Expand All @@ -14,7 +30,7 @@ package org.jetbrains.kotlinx.dl.dataset.handler
* @see <a href="https://cocodataset.org/#home">
* COCO dataset</a>
*/
public val cocoCategoriesForSSD: Map<Int, String> = mapOf(
public val cocoCategories2014: Map<Int, String> = mapOf(
1 to "person",
2 to "bicycle",
3 to "car",
Expand Down Expand Up @@ -104,7 +120,7 @@ public val cocoCategoriesForSSD: Map<Int, String> = mapOf(
* @see <a href="https://cocodataset.org/#home">
* COCO dataset</a>
*/
public val cocoCategories: Map<Int, String> = mapOf(
public val cocoCategories2017: Map<Int, String> = mapOf(
1 to "person",
2 to "bicycle",
3 to "car",
Expand Down Expand Up @@ -186,3 +202,9 @@ public val cocoCategories: Map<Int, String> = mapOf(
89 to "hair drier",
90 to "toothbrush"
)


/**
 * Supported releases of the COCO dataset; the two releases use different
 * category maps (see [cocoCategories2014] and [cocoCategories2017]).
 */
public enum class CocoVersion {
    V2014,
    V2017
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fun runImageRecognitionPrediction(
}

return if (executionProviders.isNotEmpty()) {
model.inferAndCloseUsing(executionProviders) { inference(it) }
model.inferAndCloseUsing(*executionProviders.toTypedArray()) { inference(it) }
} else {
model.use { inference(it) }
}
Expand Down
3 changes: 3 additions & 0 deletions onnx/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ kotlin {
}
}
androidMain {
dependencies {
api 'com.microsoft.onnxruntime:onnxruntime-mobile:latest.release'
}
}
}
explicitApiWarning()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection

import android.graphics.Bitmap
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
import org.jetbrains.kotlinx.dl.dataset.Coco
import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape


// Output tensor names and box layout of the SSD-MobileNet model:
// boxes, class ids and scores come from the three "TFLite_Detection_PostProcess"
// outputs respectively, and inside each box yMin sits at index 0 and xMin at
// index 1 (yMax/xMax at those indices + 2 — see SSDModelMetadata).
private val SSD_MOBILENET_METADATA = SSDModelMetadata(
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    0, 1
)


/**
 * SSD-MobileNet based object detection model for Android, backed by an ONNX
 * inference session. Labels are the zero-indexed COCO 2017 categories.
 */
public class SSDMobileNetObjectDetectionModel(override val internalModel: OnnxInferenceModel) :
    SSDObjectDetectionModelBase<Bitmap>(SSD_MOBILENET_METADATA),
    InferenceModel by internalModel {

    /** COCO 2017 categories, shifted so indices start at 0. */
    override val classLabels: Map<Int, String> = Coco(V2017, zeroIndexed = true).labels

    // Rotation (in degrees) applied to the input bitmap during preprocessing.
    private var rotationDegrees = 0f

    /**
     * Creates the model from raw ONNX model bytes and initializes it with the
     * CPU execution provider.
     */
    public constructor (modelBytes: ByteArray) : this(OnnxInferenceModel(modelBytes)) {
        internalModel.initializeWith(CPU())
        preprocessing = createPreprocessing()
    }

    /** Pipeline turning an input [Bitmap] into the flat float tensor fed to the model. */
    override lateinit var preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
        private set

    /**
     * Sets the rotation applied to input images; the preprocessing pipeline is
     * rebuilt only when the value actually changes.
     */
    public fun setTargetRotation(targetRotation: Float) {
        if (rotationDegrees != targetRotation) {
            rotationDegrees = targetRotation
            preprocessing = createPreprocessing()
        }
    }

    // Resize to the model's input resolution, rotate, then flatten to NHWC floats.
    private fun createPreprocessing(): Operation<Bitmap, Pair<FloatArray, TensorShape>> =
        pipeline<Bitmap>()
            .resize {
                outputHeight = inputDimensions[0].toInt()
                outputWidth = inputDimensions[1].toInt()
            }
            .rotate { degrees = rotationDegrees }
            .toFloatArray { layout = TensorLayout.NHWC }

    override fun copy(
        copiedModelName: String?,
        saveOptimizerState: Boolean,
        copyWeights: Boolean
    ): SSDMobileNetObjectDetectionModel =
        SSDMobileNetObjectDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx

import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU

/**
 * Interface for models that can be (re)initialized with specific ONNX Runtime
 * execution providers.
 */
public interface ExecutionProviderCompatible {
    /**
     * Initializes the model with the given execution providers.
     * Defaults to a single CPU provider with `useArena = true`.
     */
    public fun initializeWith(vararg executionProviders: ExecutionProvider = arrayOf(CPU(true)))
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dl.api.inference.onnx

import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

Expand All @@ -14,7 +15,7 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
* @param [I] input type
* @param [R] output type
*/
public interface OnnxHighLevelModel<I, R> {
public interface OnnxHighLevelModel<I, R> : AutoCloseable, ExecutionProviderCompatible {
/**
* Model used to make predictions.
*/
Expand All @@ -38,4 +39,8 @@ public interface OnnxHighLevelModel<I, R> {
val output = internalModel.predictRaw(preprocessedInput.first)
return convert(output)
}
}

    /** Delegates execution-provider initialization to the underlying [internalModel]. */
    override fun initializeWith(vararg executionProviders: ExecutionProvider) {
        internalModel.initializeWith(*executionProviders)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private const val RESHAPE_MISSED_MESSAGE = "Model input shape is not defined. Ca
*
* @since 0.3
*/
public open class OnnxInferenceModel private constructor(private val modelSource: ModelSource) : InferenceModel {
public open class OnnxInferenceModel private constructor(private val modelSource: ModelSource) : InferenceModel, ExecutionProviderCompatible {
/**
* The host object for the onnx-runtime system. Can create [session] which encapsulate
* specific models.
Expand Down Expand Up @@ -104,7 +104,7 @@ public open class OnnxInferenceModel private constructor(private val modelSource
*
* @param executionProviders list of execution providers to use.
*/
public fun initializeWith(vararg executionProviders: ExecutionProvider = arrayOf(CPU(true))) {
public override fun initializeWith(vararg executionProviders: ExecutionProvider) {
val uniqueProviders = collectProviders(executionProviders)

if (::executionProvidersInUse.isInitialized && uniqueProviders == executionProvidersInUse) {
Expand Down Expand Up @@ -169,11 +169,13 @@ public open class OnnxInferenceModel private constructor(private val modelSource
0 -> {
uniqueProviders.add(CPU(true))
}

1 -> {
val cpu = uniqueProviders.first { it is CPU }
uniqueProviders.remove(cpu)
uniqueProviders.add(cpu)
}

else -> throw IllegalArgumentException("Unable to use CPU(useArena = true) and CPU(useArena = false) at the same time!")
}
return uniqueProviders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,26 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionP
* Convenience extension functions for inference of ONNX models using different execution providers.
*/

public inline fun <R> OnnxInferenceModel.inferAndCloseUsing(
public inline fun <reified M : AutoCloseable, R> M.inferAndCloseUsing(
vararg providers: ExecutionProvider,
block: (OnnxInferenceModel) -> R
block: (M) -> R
): R {
this.initializeWith(*providers)
return this.use(block)
}
when (this) {
is ExecutionProviderCompatible -> this.initializeWith(*providers)
else -> throw IllegalArgumentException("Unsupported model type: ${M::class.simpleName}")
}

public inline fun <R> OnnxInferenceModel.inferAndCloseUsing(
providers: List<ExecutionProvider>,
block: (OnnxInferenceModel) -> R
): R {
this.initializeWith(*providers.toTypedArray())
return this.use(block)
}

public inline fun <R> OnnxInferenceModel.inferUsing(
public inline fun <reified M, R> M.inferUsing(
vararg providers: ExecutionProvider,
block: (OnnxInferenceModel) -> R
block: (M) -> R
): R {
this.initializeWith(*providers)
return this.run(block)
}
when (this) {
is ExecutionProviderCompatible -> this.initializeWith(*providers)
else -> throw IllegalArgumentException("Unsupported model type: ${M::class.simpleName}")
}

public inline fun <R> OnnxInferenceModel.inferUsing(
providers: List<ExecutionProvider>,
block: (OnnxInferenceModel) -> R
): R {
this.initializeWith(*providers.toTypedArray())
return this.run(block)
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,71 +65,38 @@ public abstract class EfficientDetObjectDetectionModelBase<I> : ObjectDetectionM
}
}

/**
 * Base class for object detection models based on the SSD-MobileNet architecture.
 */
public abstract class SSDMobileNetObjectDetectionModelBase<I> : ObjectDetectionModelBase<I>() {

    /**
     * Converts the raw output tensors into a list of [DetectedObject]s.
     *
     * Reads the four standard detection outputs (boxes, class indices, scores,
     * detection count) and keeps only the first `num_detections` entries.
     * NOTE(review): the casts are unchecked and only element [0] of each output
     * is read — presumably batch size 1; confirm against callers.
     */
    override fun convert(output: Map<String, Any>): List<DetectedObject> {
        val foundObjects = mutableListOf<DetectedObject>()
        val boxes = (output[OUTPUT_BOXES] as Array<Array<FloatArray>>)[0]
        val classIndices = (output[OUTPUT_CLASSES] as Array<FloatArray>)[0]
        val probabilities = (output[OUTPUT_SCORES] as Array<FloatArray>)[0]
        val numberOfFoundObjects = (output[OUTPUT_NUMBER_OF_DETECTIONS] as FloatArray)[0].toInt()

        for (i in 0 until numberOfFoundObjects) {
            val detectedObject = DetectedObject(
                // Throws if the model emits a class index absent from classLabels.
                classLabel = classLabels[classIndices[i].toInt()]!!,
                probability = probabilities[i],
                // Box coordinate order emitted by the model: top, left, bottom, right.
                yMin = boxes[i][0],
                xMin = boxes[i][1],
                yMax = boxes[i][2],
                xMax = boxes[i][3]
            )
            foundObjects.add(detectedObject)
        }
        return foundObjects
    }

    // Tensor names of the four standard detection outputs.
    private companion object {
        private const val OUTPUT_BOXES = "detection_boxes:0"
        private const val OUTPUT_CLASSES = "detection_classes:0"
        private const val OUTPUT_SCORES = "detection_scores:0"
        private const val OUTPUT_NUMBER_OF_DETECTIONS = "num_detections:0"
    }
}

/**
* Base class for object detection model based on SSD architecture.
*/
public abstract class SSDObjectDetectionModelBase<I> : ObjectDetectionModelBase<I>() {
public abstract class SSDObjectDetectionModelBase<I>(private val metadata: SSDModelMetadata) : ObjectDetectionModelBase<I>() {

override fun convert(output: Map<String, Any>): List<DetectedObject> {
val boxes = (output[OUTPUT_BOXES] as Array<Array<FloatArray>>)[0]
val classIndices = (output[OUTPUT_LABELS] as Array<LongArray>)[0]
val probabilities = (output[OUTPUT_SCORES] as Array<FloatArray>)[0]
val boxes = (output[metadata.outputBoxesName] as Array<Array<FloatArray>>)[0]
val classIndices = (output[metadata.outputClassesName] as Array<FloatArray>)[0]
val probabilities = (output[metadata.outputScoresName] as Array<FloatArray>)[0]
val numberOfFoundObjects = boxes.size

val foundObjects = mutableListOf<DetectedObject>()
for (i in 0 until numberOfFoundObjects) {
val detectedObject = DetectedObject(
classLabel = classLabels[classIndices[i].toInt()]!!,
classLabel = if (classIndices[i].toInt() in classLabels.keys) classLabels[classIndices[i].toInt()]!! else "Unknown",
probability = probabilities[i],
// left, bot, right, top
xMin = boxes[i][0],
yMin = boxes[i][1],
xMax = boxes[i][2],
yMax = boxes[i][3]
xMin = boxes[i][metadata.xMinIdx],
yMin = boxes[i][metadata.yMinIdx],
xMax = boxes[i][metadata.xMinIdx + 2],
yMax = boxes[i][metadata.yMinIdx + 2]
)
foundObjects.add(detectedObject)
}
return foundObjects
}
}

private companion object {
private const val OUTPUT_BOXES = "bboxes"
private const val OUTPUT_LABELS = "labels"
private const val OUTPUT_SCORES = "scores"
}
}
/**
 * Describes an SSD model's output tensor names and its bounding-box coordinate layout.
 *
 * @property outputBoxesName name of the output tensor holding bounding boxes.
 * @property outputClassesName name of the output tensor holding class indices.
 * @property outputScoresName name of the output tensor holding detection scores.
 * @property yMinIdx index of yMin inside each box array (yMax is read at yMinIdx + 2).
 * @property xMinIdx index of xMin inside each box array (xMax is read at xMinIdx + 2).
 */
public data class SSDModelMetadata(
    public val outputBoxesName: String,
    public val outputClassesName: String,
    public val outputScoresName: String,
    public val yMinIdx: Int,
    public val xMinIdx: Int
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.dataset.handler.cocoCategories
import org.jetbrains.kotlinx.dl.dataset.Coco
import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
Expand Down Expand Up @@ -45,7 +46,7 @@ public class EfficientDetObjectDetectionModel(override val internalModel: OnnxIn
// model is quite sensitive for this
.convert { colorMode = ColorMode.RGB }
.toFloatArray { }
override val classLabels: Map<Int, String> = cocoCategories
override val classLabels: Map<Int, String> = Coco(V2017).labels

/**
* Constructs the object detection model from a given path.
Expand Down
Loading

0 comments on commit b21b786

Please sign in to comment.