From e4907ea24c2ab9191c5b3104e52819bc95c5c458 Mon Sep 17 00:00:00 2001 From: Nikita Ermolenko Date: Wed, 7 Sep 2022 15:26:30 +0300 Subject: [PATCH] Fix ssdLightAPITest (#440) * Turn Coco from class to enum --- .../jetbrains/kotlinx/dl/dataset/CocoUtils.kt | 15 +++++++---- .../SSDMobileNetObjectDetectionModel.kt | 9 ++++--- .../ObjectDetectionModelBase.kt | 3 +-- .../EfficientDetObjectDetectionModel.kt | 2 +- .../SSDMobileNetV1ObjectDetectionModel.kt | 2 +- .../SSDObjectDetectionModel.kt | 26 ++++++++++++++++++- 6 files changed, 43 insertions(+), 14 deletions(-) diff --git a/dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt b/dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt index c2cd52538..c5a81fae1 100644 --- a/dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt +++ b/dataset/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/dataset/CocoUtils.kt @@ -5,10 +5,16 @@ package org.jetbrains.kotlinx.dl.dataset -public class Coco(public val version: CocoVersion, zeroIndexed: Boolean = false) { - public val labels: Map = when (version) { - CocoVersion.V2014 -> if (zeroIndexed) toZeroIndexed(cocoCategories2014) else cocoCategories2014 - CocoVersion.V2017 -> if (zeroIndexed) toZeroIndexed(cocoCategories2017) else cocoCategories2017 + +public enum class Coco { + V2014, + V2017; + + public fun labels(zeroIndexed: Boolean = false) : Map { + return when (this) { + V2014 -> if (zeroIndexed) toZeroIndexed(cocoCategories2014) else cocoCategories2014 + V2017 -> if (zeroIndexed) toZeroIndexed(cocoCategories2017) else cocoCategories2017 + } } private fun toZeroIndexed(labels: Map) : Map { @@ -20,7 +26,6 @@ public class Coco(public val version: CocoVersion, zeroIndexed: Boolean = false) } } - /** * 80 object categories in COCO dataset. * diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt index 50493df53..9d67b3fd4 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetObjectDetectionModel.kt @@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection import android.graphics.Bitmap import org.jetbrains.kotlinx.dl.api.inference.InferenceModel +import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU import org.jetbrains.kotlinx.dl.dataset.Coco @@ -22,18 +23,18 @@ public class SSDMobileNetObjectDetectionModel(override val internalModel: OnnxIn SSDObjectDetectionModelBase(SSD_MOBILENET_METADATA), InferenceModel by internalModel { - override val classLabels: Map = Coco(V2017, zeroIndexed = true).labels + override val classLabels: Map = Coco.V2017.labels(zeroIndexed = true) private var targetRotation = 0f + override lateinit var preprocessing: Operation> + private set + public constructor (modelBytes: ByteArray) : this(OnnxInferenceModel(modelBytes)) { internalModel.initializeWith(CPU()) preprocessing = buildPreprocessingPipeline() } - override lateinit var preprocessing: Operation> - private set - public fun setTargetRotation(targetRotation: Float) { if (this.targetRotation == targetRotation) return diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt index 5cc1c97dd..3427326e2 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt @@ -68,8 +68,7 @@ public abstract class EfficientDetObjectDetectionModelBase : ObjectDetectionM /** * Base class for object detection model based on SSD architecture. */ -public abstract class SSDObjectDetectionModelBase(private val metadata: SSDModelMetadata) : ObjectDetectionModelBase() { - +public abstract class SSDObjectDetectionModelBase(protected val metadata: SSDModelMetadata) : ObjectDetectionModelBase() { override fun convert(output: Map): List { val boxes = (output[metadata.outputBoxesName] as Array>)[0] val classIndices = (output[metadata.outputClassesName] as Array)[0] diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt index d0c93a933..76f22d520 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/EfficientDetObjectDetectionModel.kt @@ -46,7 +46,7 @@ public class EfficientDetObjectDetectionModel(override val internalModel: OnnxIn // model is quite sensitive for this .convert { colorMode = ColorMode.RGB } .toFloatArray { } - override val classLabels: Map = Coco(V2017).labels + override val classLabels: Map = Coco.V2017.labels(zeroIndexed = false) /** * Constructs the object detection model from a given path. diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt index c380f2113..74fe0f2de 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDMobileNetV1ObjectDetectionModel.kt @@ -55,7 +55,7 @@ public class SSDMobileNetV1ObjectDetectionModel(override val internalModel: Onnx .toFloatArray { } .call(ONNXModels.ObjectDetection.SSDMobileNetV1.preprocessor) - override val classLabels: Map = Coco(V2017).labels + override val classLabels: Map = Coco.V2017.labels() /** * Constructs the object detection model from a given path. diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt index 929d38615..500f8c298 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt @@ -58,7 +58,7 @@ public class SSDObjectDetectionModel(override val internalModel: OnnxInferenceMo .toFloatArray { } .call(ONNXModels.ObjectDetection.SSD.preprocessor) - override val classLabels: Map = Coco(CocoVersion.V2014).labels + override val classLabels: Map = Coco.V2014.labels() /** * Constructs the object detection model from a given path. @@ -80,6 +80,30 @@ public class SSDObjectDetectionModel(override val internalModel: OnnxInferenceMo return detectObjects(ImageConverter.toBufferedImage(imageFile), topK) } + // TODO remove code duplication due to different type of class labels array + override fun convert(output: Map): List { + val boxes = (output[metadata.outputBoxesName] as Array>)[0] + val classIndices = (output[metadata.outputClassesName] as Array)[0] + val probabilities = (output[metadata.outputScoresName] as Array)[0] + val numberOfFoundObjects = boxes.size + + val foundObjects = mutableListOf() + for (i in 0 until numberOfFoundObjects) { + val detectedObject = DetectedObject( + classLabel = if (classIndices[i].toInt() in classLabels.keys) classLabels[classIndices[i].toInt()]!! else "Unknown", + probability = probabilities[i], + // left, bot, right, top + xMin = boxes[i][metadata.xMinIdx], + yMin = boxes[i][metadata.yMinIdx], + xMax = boxes[i][metadata.xMinIdx + 2], + yMax = boxes[i][metadata.yMinIdx + 2] + ) + foundObjects.add(detectedObject) + } + return foundObjects + } + + override fun copy( copiedModelName: String?, saveOptimizerState: Boolean,