From b21b78632805056f2bcd06db4a3860be76cc2885 Mon Sep 17 00:00:00 2001
From: Nikita Ermolenko
Date: Wed, 7 Sep 2022 14:55:23 +0300
Subject: [PATCH] Add SSD model for Android (#440)

* Add onnxruntime-mobile dependency for androidMain

* Reduce code duplication in the SSD model code on the JVM and Android platforms

* Add support for the .inferUsing API for OnnxHighLevelModel (#434)

* Add support for zero-indexed COCO labels

---
 .../kotlinx/dl/dataset}/CocoUtils.kt          | 28 +++++++-
 .../examples/onnx/cv/predictionRunner.kt      |  2 +-
 onnx/build.gradle                             |  3 +
 .../SSDMobileNetObjectDetectionModel.kt       | 61 +++++++++++++++++
 .../onnx/ExecutionProviderCompatible.kt       |  8 +++
 .../api/inference/onnx/OnnxHighLevelModel.kt  |  9 ++-
 .../api/inference/onnx/OnnxInferenceModel.kt  |  6 +-
 .../inference/onnx/OnnxInferenceModelEx.kt    | 32 ++++-----
 .../ObjectDetectionModelBase.kt               | 67 +++++--------------
 .../EfficientDetObjectDetectionModel.kt       |  5 +-
 .../SSDMobileNetV1ObjectDetectionModel.kt     | 16 ++++-
 .../SSDObjectDetectionModel.kt                | 10 ++-
 12 files changed, 161 insertions(+), 86 deletions(-)
 rename {api/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/dataset/handler => dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset}/CocoUtils.kt (82%)
 create mode 100644 onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt
 create mode 100644 onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ExecutionProviderCompatible.kt

diff --git a/api/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/dataset/handler/CocoUtils.kt b/dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt
similarity index 82%
rename from api/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/dataset/handler/CocoUtils.kt
rename to dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt
index 49876d398..c2cd52538 100644
--- a/api/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/dataset/handler/CocoUtils.kt
+++ b/dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt
@@ -3,7 +3,23 @@
  * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
  */
 
-package org.jetbrains.kotlinx.dl.dataset.handler
+package org.jetbrains.kotlinx.dl.dataset
+
+public class Coco(public val version: CocoVersion, zeroIndexed: Boolean = false) {
+    public val labels: Map<Int, String> = when (version) {
+        CocoVersion.V2014 -> if (zeroIndexed) toZeroIndexed(cocoCategories2014) else cocoCategories2014
+        CocoVersion.V2017 -> if (zeroIndexed) toZeroIndexed(cocoCategories2017) else cocoCategories2017
+    }
+
+    private fun toZeroIndexed(labels: Map<Int, String>): Map<Int, String> {
+        val zeroIndexedLabels = mutableMapOf<Int, String>()
+        labels.forEach { (key, value) ->
+            zeroIndexedLabels[key - 1] = value
+        }
+        return zeroIndexedLabels
+    }
+}
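For reference, a minimal usage sketch of the new Coco wrapper (names taken from this patch; the sample values assume the standard COCO category maps below, where key 1 is "person"):

    // One-based labels, exactly as stored in the category maps.
    val labels = Coco(CocoVersion.V2017).labels                       // labels[1] == "person"
    // Zero-indexed labels for models whose class output starts at 0,
    // e.g. the TFLite-style SSD-MobileNet post-processing outputs.
    val shifted = Coco(CocoVersion.V2017, zeroIndexed = true).labels  // shifted[0] == "person"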
 
 /**
  * 80 object categories in COCO dataset.
@@ -14,7 +30,7 @@ package org.jetbrains.kotlinx.dl.dataset.handler
  * @see <a href="https://cocodataset.org/#home">
  *     COCO dataset</a>
  */
-public val cocoCategoriesForSSD: Map<Int, String> = mapOf(
+public val cocoCategories2014: Map<Int, String> = mapOf(
     1 to "person",
     2 to "bicycle",
     3 to "car",
@@ -104,7 +120,7 @@ public val cocoCategoriesForSSD: Map<Int, String> = mapOf(
  * @see <a href="https://cocodataset.org/#home">
  *     COCO dataset</a>
  */
-public val cocoCategories: Map<Int, String> = mapOf(
+public val cocoCategories2017: Map<Int, String> = mapOf(
     1 to "person",
     2 to "bicycle",
     3 to "car",
@@ -186,3 +202,9 @@ public val cocoCategories: Map<Int, String> = mapOf(
     89 to "hair drier",
     90 to "toothbrush"
 )
+
+
+public enum class CocoVersion {
+    V2014,
+    V2017
+}
diff --git a/examples/src/main/kotlin/examples/onnx/cv/predictionRunner.kt b/examples/src/main/kotlin/examples/onnx/cv/predictionRunner.kt
index a8e7d529b..33e8f6832 100644
--- a/examples/src/main/kotlin/examples/onnx/cv/predictionRunner.kt
+++ b/examples/src/main/kotlin/examples/onnx/cv/predictionRunner.kt
@@ -55,7 +55,7 @@ fun runImageRecognitionPrediction(
     }
 
     return if (executionProviders.isNotEmpty()) {
-        model.inferAndCloseUsing(executionProviders) { inference(it) }
+        model.inferAndCloseUsing(*executionProviders.toTypedArray()) { inference(it) }
     } else {
         model.use { inference(it) }
     }
diff --git a/onnx/build.gradle b/onnx/build.gradle
index 14464d06e..51152182c 100644
--- a/onnx/build.gradle
+++ b/onnx/build.gradle
@@ -39,6 +39,9 @@ kotlin {
         }
     }
     androidMain {
+        dependencies {
+            api 'com.microsoft.onnxruntime:onnxruntime-mobile:latest.release'
+        }
     }
 }
 explicitApiWarning()
diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt
new file mode 100644
index 000000000..50493df53
--- /dev/null
+++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt
@@ -0,0 +1,61 @@
+package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection
+
+import android.graphics.Bitmap
+import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
+import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
+import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
+import org.jetbrains.kotlinx.dl.dataset.Coco
+import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
+import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
+
+
+private val SSD_MOBILENET_METADATA = SSDModelMetadata(
+    "TFLite_Detection_PostProcess",
+    "TFLite_Detection_PostProcess:1",
+    "TFLite_Detection_PostProcess:2",
+    0, 1
+)
+
+
+public class SSDMobileNetObjectDetectionModel(override val internalModel: OnnxInferenceModel) :
+    SSDObjectDetectionModelBase<Bitmap>(SSD_MOBILENET_METADATA),
+    InferenceModel by internalModel {
+
+    override val classLabels: Map<Int, String> = Coco(V2017, zeroIndexed = true).labels
+
+    private var targetRotation = 0f
+
+    public constructor (modelBytes: ByteArray) : this(OnnxInferenceModel(modelBytes)) {
+        internalModel.initializeWith(CPU())
+        preprocessing = buildPreprocessingPipeline()
+    }
+
+    override lateinit var preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
+        private set
+
+    public fun setTargetRotation(targetRotation: Float) {
+        if (this.targetRotation == targetRotation) return
+
+        this.targetRotation = targetRotation
+        preprocessing = buildPreprocessingPipeline()
+    }
+
+    private fun buildPreprocessingPipeline(): Operation<Bitmap, Pair<FloatArray, TensorShape>> {
+        return pipeline<Bitmap>()
+            .resize {
+                outputHeight = inputDimensions[0].toInt()
+                outputWidth = inputDimensions[1].toInt()
+            }
+            .rotate { degrees = targetRotation }
+            .toFloatArray { layout = TensorLayout.NHWC }
+    }
+
+    override fun copy(
+        copiedModelName: String?,
+        saveOptimizerState: Boolean,
+        copyWeights: Boolean
+    ): SSDMobileNetObjectDetectionModel {
+        return SSDMobileNetObjectDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
+    }
+}
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ExecutionProviderCompatible.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ExecutionProviderCompatible.kt
new file mode 100644
index 000000000..3c398555c
--- /dev/null
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ExecutionProviderCompatible.kt
@@ -0,0 +1,8 @@
+package org.jetbrains.kotlinx.dl.api.inference.onnx
+
+import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
+import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
+
+public interface ExecutionProviderCompatible {
+    public fun initializeWith(vararg executionProviders: ExecutionProvider = arrayOf(CPU(true)))
+}
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt
index d24347449..476b35772 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt
@@ -5,6 +5,7 @@
 
 package org.jetbrains.kotlinx.dl.api.inference.onnx
 
+import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
 import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
 
@@ -14,7 +15,7 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
  * @param [I] input type
  * @param [R] output type
  */
-public interface OnnxHighLevelModel<I, R> {
+public interface OnnxHighLevelModel<I, R> : AutoCloseable, ExecutionProviderCompatible {
     /**
      * Model used to make predictions.
      */
@@ -38,4 +39,8 @@ public interface OnnxHighLevelModel<I, R> {
         val output = internalModel.predictRaw(preprocessedInput.first)
         return convert(output)
     }
-}
\ No newline at end of file
+
+    override fun initializeWith(vararg executionProviders: ExecutionProvider) {
+        internalModel.initializeWith(*executionProviders)
+    }
+}
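With ExecutionProviderCompatible extracted into commonMain, any ONNX-backed model can be (re)initialized with explicit execution providers. A minimal sketch using only the API in this patch (modelBytes is a placeholder for a serialized ONNX model):

    val model = OnnxInferenceModel(modelBytes)
    // Defaults to CPU(true) when called with no arguments; the flag is the
    // useArena option mentioned in the error message further down.
    model.initializeWith(CPU(true))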
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
index 515671613..4a2698c91 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt
@@ -23,7 +23,7 @@ private const val RESHAPE_MISSED_MESSAGE = "Model input shape is not defined. Ca
  *
  * @since 0.3
  */
-public open class OnnxInferenceModel private constructor(private val modelSource: ModelSource) : InferenceModel {
+public open class OnnxInferenceModel private constructor(private val modelSource: ModelSource) : InferenceModel, ExecutionProviderCompatible {
     /**
      * The host object for the onnx-runtime system. Can create [session] which encapsulate
      * specific models.
@@ -104,7 +104,7 @@ public open class OnnxInferenceModel private constructor(private val modelSource
      *
      * @param executionProviders list of execution providers to use.
      */
-    public fun initializeWith(vararg executionProviders: ExecutionProvider = arrayOf(CPU(true))) {
+    public override fun initializeWith(vararg executionProviders: ExecutionProvider) {
         val uniqueProviders = collectProviders(executionProviders)
 
         if (::executionProvidersInUse.isInitialized && uniqueProviders == executionProvidersInUse) {
@@ -169,11 +169,13 @@ public open class OnnxInferenceModel private constructor(private val modelSource
             0 -> {
                 uniqueProviders.add(CPU(true))
             }
+
             1 -> {
                 val cpu = uniqueProviders.first { it is CPU }
                 uniqueProviders.remove(cpu)
                 uniqueProviders.add(cpu)
             }
+
             else -> throw IllegalArgumentException("Unable to use CPU(useArena = true) and CPU(useArena = false) at the same time!")
         }
         return uniqueProviders
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModelEx.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModelEx.kt
index 60dd23afa..7829b8b87 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModelEx.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModelEx.kt
@@ -7,34 +7,26 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionP
 /**
  * Convenience extension functions for inference of ONNX models using different execution providers.
  */
-public inline fun <R> OnnxInferenceModel.inferAndCloseUsing(
+public inline fun <reified M : InferenceModel, R> M.inferAndCloseUsing(
     vararg providers: ExecutionProvider,
-    block: (OnnxInferenceModel) -> R
+    block: (M) -> R
 ): R {
-    this.initializeWith(*providers)
-    return this.use(block)
-}
+    when (this) {
+        is ExecutionProviderCompatible -> this.initializeWith(*providers)
+        else -> throw IllegalArgumentException("Unsupported model type: ${M::class.simpleName}")
+    }
 
-public inline fun <R> OnnxInferenceModel.inferAndCloseUsing(
-    providers: List<ExecutionProvider>,
-    block: (OnnxInferenceModel) -> R
-): R {
-    this.initializeWith(*providers.toTypedArray())
     return this.use(block)
 }
 
-public inline fun <R> OnnxInferenceModel.inferUsing(
+public inline fun <reified M : InferenceModel, R> M.inferUsing(
     vararg providers: ExecutionProvider,
-    block: (OnnxInferenceModel) -> R
+    block: (M) -> R
 ): R {
-    this.initializeWith(*providers)
-    return this.run(block)
-}
+    when (this) {
+        is ExecutionProviderCompatible -> this.initializeWith(*providers)
+        else -> throw IllegalArgumentException("Unsupported model type: ${M::class.simpleName}")
+    }
 
-public inline fun <R> OnnxInferenceModel.inferUsing(
-    providers: List<ExecutionProvider>,
-    block: (OnnxInferenceModel) -> R
-): R {
-    this.initializeWith(*providers.toTypedArray())
     return this.run(block)
 }
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt
index b09a54e52..5cc1c97dd 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt
@@ -65,71 +65,38 @@ public abstract class EfficientDetObjectDetectionModelBase<I> : ObjectDetectionM
     }
 }
 
-/**
- * Base class for object detection model based on SSD-MobilNet architecture.
- */
-public abstract class SSDMobileNetObjectDetectionModelBase<I> : ObjectDetectionModelBase<I>() {
-
-    override fun convert(output: Map<String, Any>): List<DetectedObject> {
-        val foundObjects = mutableListOf<DetectedObject>()
-        val boxes = (output[OUTPUT_BOXES] as Array<Array<FloatArray>>)[0]
-        val classIndices = (output[OUTPUT_CLASSES] as Array<FloatArray>)[0]
-        val probabilities = (output[OUTPUT_SCORES] as Array<FloatArray>)[0]
-        val numberOfFoundObjects = (output[OUTPUT_NUMBER_OF_DETECTIONS] as FloatArray)[0].toInt()
-
-        for (i in 0 until numberOfFoundObjects) {
-            val detectedObject = DetectedObject(
-                classLabel = classLabels[classIndices[i].toInt()]!!,
-                probability = probabilities[i],
-                // top, left, bottom, right
-                yMin = boxes[i][0],
-                xMin = boxes[i][1],
-                yMax = boxes[i][2],
-                xMax = boxes[i][3]
-            )
-            foundObjects.add(detectedObject)
-        }
-        return foundObjects
-    }
-
-    private companion object {
-        private const val OUTPUT_BOXES = "detection_boxes:0"
-        private const val OUTPUT_CLASSES = "detection_classes:0"
-        private const val OUTPUT_SCORES = "detection_scores:0"
-        private const val OUTPUT_NUMBER_OF_DETECTIONS = "num_detections:0"
-    }
-}
-
 /**
  * Base class for object detection model based on SSD architecture.
  */
-public abstract class SSDObjectDetectionModelBase<I> : ObjectDetectionModelBase<I>() {
+public abstract class SSDObjectDetectionModelBase<I>(private val metadata: SSDModelMetadata) : ObjectDetectionModelBase<I>() {
 
     override fun convert(output: Map<String, Any>): List<DetectedObject> {
-        val boxes = (output[OUTPUT_BOXES] as Array<Array<FloatArray>>)[0]
-        val classIndices = (output[OUTPUT_LABELS] as Array<LongArray>)[0]
-        val probabilities = (output[OUTPUT_SCORES] as Array<FloatArray>)[0]
+        val boxes = (output[metadata.outputBoxesName] as Array<Array<FloatArray>>)[0]
+        val classIndices = (output[metadata.outputClassesName] as Array<FloatArray>)[0]
+        val probabilities = (output[metadata.outputScoresName] as Array<FloatArray>)[0]
         val numberOfFoundObjects = boxes.size
 
         val foundObjects = mutableListOf<DetectedObject>()
         for (i in 0 until numberOfFoundObjects) {
             val detectedObject = DetectedObject(
-                classLabel = classLabels[classIndices[i].toInt()]!!,
+                classLabel = if (classIndices[i].toInt() in classLabels.keys) classLabels[classIndices[i].toInt()]!!
else "Unknown", probability = probabilities[i], // left, bot, right, top - xMin = boxes[i][0], - yMin = boxes[i][1], - xMax = boxes[i][2], - yMax = boxes[i][3] + xMin = boxes[i][metadata.xMinIdx], + yMin = boxes[i][metadata.yMinIdx], + xMax = boxes[i][metadata.xMinIdx + 2], + yMax = boxes[i][metadata.yMinIdx + 2] ) foundObjects.add(detectedObject) } return foundObjects } +} - private companion object { - private const val OUTPUT_BOXES = "bboxes" - private const val OUTPUT_LABELS = "labels" - private const val OUTPUT_SCORES = "scores" - } -} \ No newline at end of file +public data class SSDModelMetadata( + public val outputBoxesName: String, + public val outputClassesName: String, + public val outputScoresName: String, + public val yMinIdx: Int, + public val xMinIdx: Int +) diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt index ab5b8d6fe..d0c93a933 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt @@ -9,7 +9,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel -import org.jetbrains.kotlinx.dl.dataset.handler.cocoCategories +import org.jetbrains.kotlinx.dl.dataset.Coco +import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017 import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation @@ -45,7 +46,7 @@ public class EfficientDetObjectDetectionModel(override val internalModel: OnnxIn // model is quite sensitive for this .convert { colorMode = ColorMode.RGB } .toFloatArray { } - override val classLabels: Map = cocoCategories + override val classLabels: Map = Coco(V2017).labels /** * Constructs the object detection model from a given path. 
diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt
index 967dc53bc..c380f2113 100644
--- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt
+++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt
@@ -9,7 +9,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
 import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
 import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
 import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
-import org.jetbrains.kotlinx.dl.dataset.handler.cocoCategories
+import org.jetbrains.kotlinx.dl.dataset.Coco
+import org.jetbrains.kotlinx.dl.dataset.CocoVersion.V2017
 import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
 import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
@@ -23,6 +24,14 @@ import java.awt.image.BufferedImage
 import java.io.File
 import java.io.IOException
 
+
+private val SSD_MOBILENET_METADATA = SSDModelMetadata(
+    "detection_boxes:0",
+    "detection_classes:0",
+    "detection_scores:0",
+    0, 1
+)
+
 /**
  * Special model class for detection objects on images
  * with built-in preprocessing and post-processing.
@@ -34,7 +43,7 @@ import java.io.IOException
  * @since 0.4
  */
 public class SSDMobileNetV1ObjectDetectionModel(override val internalModel: OnnxInferenceModel) :
-    SSDMobileNetObjectDetectionModelBase<BufferedImage>(), InferenceModel by internalModel {
+    SSDObjectDetectionModelBase<BufferedImage>(SSD_MOBILENET_METADATA), InferenceModel by internalModel {
 
     override val preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>>
         get() = pipeline<BufferedImage>()
@@ -45,7 +54,8 @@ public class SSDMobileNetV1ObjectDetectionModel(override val internalModel: Onnx
             .convert { colorMode = ColorMode.RGB }
             .toFloatArray { }
             .call(ONNXModels.ObjectDetection.SSDMobileNetV1.preprocessor)
-    override val classLabels: Map<Int, String> = cocoCategories
+
+    override val classLabels: Map<Int, String> = Coco(V2017).labels
 
     /**
      * Constructs the object detection model from a given path.
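On the JVM, the reworked extensions accept any InferenceModel that is ExecutionProviderCompatible. A sketch, assuming the path constructor documented above and a detectObjects(File) helper on the JVM detection models (neither the helper nor the file names are part of this diff):

    val model = SSDMobileNetV1ObjectDetectionModel("ssd_mobilenet_v1.onnx")
    // inferAndCloseUsing initializes the providers, runs the block, and closes the model.
    val detections = model.inferAndCloseUsing(CPU()) { it.detectObjects(File("image.jpg")) }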
diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt
index 49269ef82..929d38615 100644
--- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt
+++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt
@@ -9,7 +9,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
 import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
 import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
 import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
-import org.jetbrains.kotlinx.dl.dataset.handler.cocoCategoriesForSSD
+import org.jetbrains.kotlinx.dl.dataset.Coco
+import org.jetbrains.kotlinx.dl.dataset.CocoVersion
 import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
 import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
@@ -25,6 +26,8 @@ import java.io.IOException
 
 private const val INPUT_SIZE = 1200
 
+private val SSD_RESNET_METADATA = SSDModelMetadata("bboxes", "labels", "scores", 1, 0)
+
 /**
  * Special model class for detection objects on images
  * with built-in preprocessing and post-processing.
@@ -43,7 +46,7 @@ private const val INPUT_SIZE = 1200
  * @since 0.3
 */
 public class SSDObjectDetectionModel(override val internalModel: OnnxInferenceModel) :
-    SSDObjectDetectionModelBase<BufferedImage>(), InferenceModel by internalModel {
+    SSDObjectDetectionModelBase<BufferedImage>(SSD_RESNET_METADATA), InferenceModel by internalModel {
 
     override val preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>>
         get() = pipeline<BufferedImage>()
@@ -54,7 +57,8 @@ public class SSDObjectDetectionModel(override val internalModel: OnnxInferenceMo
             .convert { colorMode = ColorMode.RGB }
             .toFloatArray { }
             .call(ONNXModels.ObjectDetection.SSD.preprocessor)
-    override val classLabels: Map<Int, String> = cocoCategoriesForSSD
+
+    override val classLabels: Map<Int, String> = Coco(CocoVersion.V2014).labels
 
     /**
      * Constructs the object detection model from a given path.
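Putting the Android pieces together, end to end. Here detectObjects(Bitmap) is assumed to come from ObjectDetectionModelBase (it is not shown in this diff), and bitmap/modelBytes are placeholders:

    // The ByteArray constructor initializes the model with the CPU provider
    // and builds the default preprocessing pipeline.
    val model = SSDMobileNetObjectDetectionModel(modelBytes)
    model.setTargetRotation(90f)  // rebuilds the preprocessing pipeline with rotation
    val detections = model.inferAndCloseUsing(CPU()) { it.detectObjects(bitmap) }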