Skip to content

Commit

Permalink
Fix ssdLightAPITest (Kotlin#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
ermolenkodev committed Sep 7, 2022
1 parent b21b786 commit 67af5c6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection

import android.graphics.Bitmap
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
import org.jetbrains.kotlinx.dl.dataset.Coco
Expand All @@ -26,13 +27,36 @@ public class SSDMobileNetObjectDetectionModel(override val internalModel: OnnxIn

private var targetRotation = 0f

override lateinit var preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
private set

public constructor (modelBytes: ByteArray) : this(OnnxInferenceModel(modelBytes)) {
internalModel.initializeWith(CPU())
preprocessing = buildPreprocessingPipeline()
}

override lateinit var preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
private set
// TODO remove code duplication due to different type of class labels array
override fun convert(output: Map<String, Any>): List<DetectedObject> {
val boxes = (output[metadata.outputBoxesName] as Array<Array<FloatArray>>)[0]
val classIndices = (output[metadata.outputClassesName] as Array<FloatArray>)[0]
val probabilities = (output[metadata.outputScoresName] as Array<FloatArray>)[0]
val numberOfFoundObjects = boxes.size

val foundObjects = mutableListOf<DetectedObject>()
for (i in 0 until numberOfFoundObjects) {
val detectedObject = DetectedObject(
classLabel = if (classIndices[i].toInt() in classLabels.keys) classLabels[classIndices[i].toInt()]!! else "Unknown",
probability = probabilities[i],
// left, bot, right, top
xMin = boxes[i][metadata.xMinIdx],
yMin = boxes[i][metadata.yMinIdx],
xMax = boxes[i][metadata.xMinIdx + 2],
yMax = boxes[i][metadata.yMinIdx + 2]
)
foundObjects.add(detectedObject)
}
return foundObjects
}

public fun setTargetRotation(targetRotation: Float) {
if (this.targetRotation == targetRotation) return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ public abstract class EfficientDetObjectDetectionModelBase<I> : ObjectDetectionM
/**
* Base class for object detection model based on SSD architecture.
*/
public abstract class SSDObjectDetectionModelBase<I>(private val metadata: SSDModelMetadata) : ObjectDetectionModelBase<I>() {
public abstract class SSDObjectDetectionModelBase<I>(protected val metadata: SSDModelMetadata) : ObjectDetectionModelBase<I>() {

override fun convert(output: Map<String, Any>): List<DetectedObject> {
val boxes = (output[metadata.outputBoxesName] as Array<Array<FloatArray>>)[0]
val classIndices = (output[metadata.outputClassesName] as Array<FloatArray>)[0]
val classIndices = (output[metadata.outputClassesName] as Array<LongArray>)[0]
val probabilities = (output[metadata.outputScoresName] as Array<FloatArray>)[0]
val numberOfFoundObjects = boxes.size

Expand Down

0 comments on commit 67af5c6

Please sign in to comment.